Source code for scitex_decorators._xarray_fn

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-04-30 15:41:19 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/decorators/_xarray_fn.py
# ----------------------------------------
import os

__FILE__ = "./src/scitex/decorators/_xarray_fn.py"
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------
from functools import wraps
from typing import Any as _Any
from typing import Callable

import numpy as np

from ._converters import is_nested_decorator


def _to_xarray(data):
    """Return *data* wrapped in an ``xarray.DataArray``.

    ``DataArray`` inputs are returned unchanged.  NumPy arrays and lists are
    passed straight to the ``DataArray`` constructor.  Objects whose class is
    named ``Tensor`` are detached, moved to CPU and converted through NumPy;
    objects whose class is named ``DataFrame`` or ``Series`` contribute their
    ``.values``.  Anything else is wrapped in a one-element list first.
    """
    # Imported lazily so that importing this module does not require xarray.
    import xarray as xr

    if isinstance(data, xr.DataArray):
        return data
    if isinstance(data, (np.ndarray, list)):
        return xr.DataArray(data)

    # Tensors / pandas objects are recognised by class name so that neither
    # torch nor pandas has to be imported just to convert the inputs.
    class_name = data.__class__.__name__
    if class_name == "Tensor":
        return xr.DataArray(data.detach().cpu().numpy())
    if class_name in ("DataFrame", "Series"):
        return xr.DataArray(data.values)
    return xr.DataArray([data])


def _from_xarray(result, original):
    """Convert a ``DataArray`` *result* back to the container type of *original*.

    *result* is returned untouched when it is not a ``DataArray``, when
    *original* is ``None``, or when *original* is not one of the recognised
    container types (list, NumPy array, ``Tensor``, ``DataFrame``, ``Series``).
    """
    import xarray as xr

    if original is None or not isinstance(result, xr.DataArray):
        return result
    if isinstance(original, list):
        return result.values.tolist()
    if isinstance(original, np.ndarray):
        return result.values

    class_name = original.__class__.__name__
    if class_name == "Tensor":
        import torch

        return torch.tensor(result.values)
    if class_name == "DataFrame":
        import pandas as pd

        # Only the raw values are carried over; a fresh default index and
        # default column labels are created by the DataFrame constructor.
        return pd.DataFrame(result.values)
    if class_name == "Series":
        import pandas as pd

        return pd.Series(result.values.flatten())
    return result


def xarray_fn(func: Callable) -> Callable:
    """Decorate *func* so that it receives ``xarray.DataArray`` arguments.

    Every positional and keyword argument is converted to a ``DataArray``
    before *func* is called (values that are not arrays, lists, tensors or
    pandas objects are wrapped as one-element arrays).  If *func* returns a
    ``DataArray``, the result is converted back to the type of the first
    positional argument when that type is a list, NumPy array, ``Tensor``,
    ``DataFrame`` or ``Series``; otherwise the result is returned as is.

    When ``is_nested_decorator()`` reports a nested decorator context, *func*
    is called with the arguments untouched and its result is returned
    directly.

    Parameters
    ----------
    func : Callable
        The function to wrap.

    Returns
    -------
    Callable
        The wrapping function, tagged with ``_is_wrapper = True`` and
        ``_decorator_type = "xarray_fn"``.
    """

    @wraps(func)
    def wrapper(*args: _Any, **kwargs: _Any) -> _Any:
        # Skip conversion if already in a nested decorator context.
        if is_nested_decorator():
            return func(*args, **kwargs)

        # Record the active decorator on the wrapper object.
        # NOTE(review): presumably read by the nested-decorator detection in
        # ._converters; it is never reset here -- confirm that is intended.
        wrapper._current_decorator = "xarray_fn"

        # The first positional argument decides the type of the return value.
        original_object = args[0] if args else None

        converted_args = [_to_xarray(arg) for arg in args]
        converted_kwargs = {
            key: _to_xarray(value) for key, value in kwargs.items()
        }

        results = func(*converted_args, **converted_kwargs)
        return _from_xarray(results, original_object)

    # Mark as a wrapper for detection.
    wrapper._is_wrapper = True
    wrapper._decorator_type = "xarray_fn"
    return wrapper
# EOF