Source code for scitex_decorators._numpy_fn

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-04-30 15:29:53 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/decorators/_numpy_fn.py
# ----------------------------------------
import os

# Relative path of this module, hard-coded as a string literal.
__FILE__ = "./src/scitex/decorators/_numpy_fn.py"
# Directory component of the literal above; os.path.dirname is a pure string
# operation, so the result is not resolved against the filesystem.
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------

import numpy as np

# Hard-coded absolute path of this module (presumably the author's checkout --
# TODO confirm it is still needed; nothing in the visible code reads it).
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/decorators/_numpy_fn.py"

from functools import wraps
from typing import Any as _Any
from typing import Callable

from ._converters import _return_always, is_nested_decorator, to_numpy


def numpy_fn(func: Callable) -> Callable:
    """Decorate ``func`` so its arguments are passed through ``to_numpy`` first.

    On each call the wrapper:

    1. Calls ``func`` directly with the untouched arguments when
       ``is_nested_decorator()`` reports a nested decorator context.
    2. Otherwise passes all positional and keyword arguments through
       ``to_numpy(..., return_fn=_return_always)``.
    3. Makes a final best-effort ``np.array`` attempt on any positional
       argument that is still not an ndarray, a plain scalar/string/None, or
       a list of ndarrays. Arguments that cannot be converted are forwarded
       unchanged rather than rejected.
    4. If ``func`` returns an ndarray, converts it back according to the type
       of the *first positional argument* of the original call: ``list`` ->
       ``ndarray.tolist()``, a class named ``Tensor`` -> ``torch.tensor``,
       ``DataFrame`` -> ``pandas.DataFrame``, ``Series`` -> ``pandas.Series``.
       Any other result (or any other first-argument type) is returned as is.

    Parameters
    ----------
    func : Callable
        The function to wrap.

    Returns
    -------
    Callable
        The wrapping function. It carries ``_is_wrapper = True`` and
        ``_decorator_type = "numpy_fn"`` so other code can detect it.
    """

    @wraps(func)
    def wrapper(*args: _Any, **kwargs: _Any) -> _Any:
        # Skip conversion if already in a nested decorator context
        if is_nested_decorator():
            return func(*args, **kwargs)

        # Set the current decorator context.
        # NOTE(review): this attribute is never reset here; confirm against
        # is_nested_decorator() whether that is intentional.
        wrapper._current_decorator = "numpy_fn"

        # Remember the first positional argument so the result can be
        # converted back to the caller's container type afterwards.
        original_object = args[0] if args else None

        converted_args, converted_kwargs = to_numpy(
            *args, return_fn=_return_always, **kwargs
        )

        # No strict "everything must be an ndarray" assertion: convert what
        # we can and pass through what we can't.
        validated_args = []
        for arg in converted_args:
            passes_through = (
                isinstance(arg, np.ndarray)
                # Scalars, strings and None are forwarded unchanged.
                or isinstance(arg, (int, float, str, type(None)))
                # A list consisting solely of arrays is forwarded as a list.
                or (
                    isinstance(arg, list)
                    and all(isinstance(item, np.ndarray) for item in arg)
                )
            )
            if passes_through:
                validated_args.append(arg)
                continue

            # One more conversion attempt for anything else.
            try:
                validated_args.append(np.array(arg))
            except Exception:
                # Best effort: forward the value unchanged. Only Exception is
                # caught so KeyboardInterrupt/SystemExit still propagate.
                validated_args.append(arg)

        results = func(*validated_args, **converted_kwargs)

        # Only ndarray results are converted back, and only when there was a
        # first positional argument to take the target type from.
        if not isinstance(results, np.ndarray) or original_object is None:
            return results

        if isinstance(original_object, list):
            return results.tolist()

        # Match on the class name so torch/pandas are imported lazily and
        # only when such an object was actually passed in.
        type_name = original_object.__class__.__name__
        if type_name == "Tensor":
            import torch

            return torch.tensor(results)
        if type_name == "DataFrame":
            import pandas as pd

            return pd.DataFrame(results)
        if type_name == "Series":
            import pandas as pd

            return pd.Series(results)

        return results

    # Mark as a wrapper for detection
    wrapper._is_wrapper = True
    wrapper._decorator_type = "numpy_fn"
    return wrapper
# EOF