Source code for scitex_decorators._torch_fn

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-04-30 15:40:43 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/decorators/_torch_fn.py
# ----------------------------------------
import os

# Hard-coded, repository-relative path of this module. NOTE: this is a literal
# string, not derived from the interpreter-provided ``__file__``.
__FILE__ = "./src/scitex/decorators/_torch_fn.py"
# Directory component of ``__FILE__`` (evaluates to "./src/scitex/decorators").
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------

from functools import wraps
from typing import Any as _Any
from typing import Callable

import numpy as np

from ._converters import _return_always, is_nested_decorator, to_torch


def _coerce_positional_args(converted_args: _Any) -> list:
    """Make a best-effort pass turning leftover positional args into tensors.

    Tensors, ``int``/``float``/``str``/``None`` values (``bool`` included, as
    it subclasses ``int``) and lists made up solely of tensors are kept as
    they are. Anything else gets one more ``torch.tensor(...).float()``
    attempt and is passed through unchanged if that attempt fails.
    """
    # Imported lazily so that importing this module does not require torch.
    import torch

    passthrough_types = (torch.Tensor, int, float, str, type(None))

    validated_args = []
    for arg in converted_args:
        if isinstance(arg, passthrough_types):
            validated_args.append(arg)
        elif isinstance(arg, list) and all(
            isinstance(item, torch.Tensor) for item in arg
        ):
            # List of tensors (or an empty list) - pass through as is.
            validated_args.append(arg)
        else:
            try:
                validated_args.append(torch.tensor(arg).float())
            except Exception:
                # Deliberate best-effort: hand the value to ``func`` untouched.
                # ``Exception`` (rather than a bare ``except``) keeps
                # KeyboardInterrupt / SystemExit propagating.
                validated_args.append(arg)
    return validated_args


def _restore_input_type(results: _Any, original_object: _Any) -> _Any:
    """Convert a tensor result back to the container type of ``original_object``.

    Non-tensor results, and results for inputs of an unrecognised type, are
    returned unchanged. pandas / xarray inputs are detected by class name so
    those libraries are only imported when actually needed.
    """
    import torch

    if not isinstance(results, torch.Tensor) or original_object is None:
        return results

    def _as_numpy():
        # Evaluated only in a matching branch so that unrecognised input
        # types get the tensor back untouched (still on its device).
        return results.detach().cpu().numpy()

    if isinstance(original_object, list):
        return _as_numpy().tolist()
    if isinstance(original_object, np.ndarray):
        return _as_numpy()

    class_name = original_object.__class__.__name__
    if class_name == "DataFrame":
        import pandas as pd

        return pd.DataFrame(_as_numpy())
    if class_name == "Series":
        import pandas as pd

        return pd.Series(_as_numpy().flatten())
    if class_name == "DataArray":
        import xarray as xr

        return xr.DataArray(_as_numpy())
    return results


def torch_fn(func: Callable) -> Callable:
    """Decorator for PyTorch function compatibility.

    Converts the inputs of ``func`` to PyTorch tensors where possible and
    converts a tensor result back to the type of the first positional
    argument.

    Features
    --------
    - Input conversion is delegated to ``to_torch`` (defined in
      ``._converters``, not in this module). Per the original documentation
      it also preserves dimension tuples like ``dim=(0, 1)``, converts
      ``axis`` to ``dim`` and applies ``device="cuda"`` if available;
      confirm against ``to_torch``.
    - Positional scalars (int, float, bool, str) and ``None`` are passed
      through unchanged.
    - Positional values that are still not tensors after ``to_torch`` get one
      more ``torch.tensor(...).float()`` attempt; on failure they are passed
      through unchanged.
    - A tensor result is returned as list / ``numpy.ndarray`` /
      ``pandas.DataFrame`` / ``pandas.Series`` / ``xarray.DataArray`` when
      the first positional argument was of that type. DataFrame, Series and
      DataArray results are rebuilt from the raw values only (Series results
      are flattened). Any other case returns the result unchanged.
    - When called inside another decorator context
      (``is_nested_decorator()`` is true), ``func`` is called directly with
      no conversion at all.

    Parameters
    ----------
    func : Callable
        The function to decorate

    Returns
    -------
    Callable
        The decorated function. It carries ``_is_wrapper = True`` and
        ``_decorator_type = "torch_fn"`` for detection by other decorators.

    Examples
    --------
    >>> @torch_fn
    ... def mean_squared(x, dim=None):
    ...     return (x ** 2).mean(dim=dim)
    >>>
    >>> # Works with numpy arrays
    >>> result = mean_squared(np.array([1, 2, 3]))
    >>>
    >>> # Works with nested lists
    >>> result = mean_squared([[1, 2], [3, 4]])
    >>>
    >>> # Preserves dimension tuples
    >>> result = mean_squared(data, dim=(0, 1))

    Notes
    -----
    For optimal performance with batch processing, apply torch_fn before
    batch_fn:

        @batch_fn
        @torch_fn
        def my_function(x):
            ...

    Or use auto-ordering to handle this automatically.
    """

    @wraps(func)
    def wrapper(*args: _Any, **kwargs: _Any) -> _Any:
        # Skip conversion if already in a nested decorator context
        if is_nested_decorator():
            return func(*args, **kwargs)

        # Set the current decorator context
        wrapper._current_decorator = "torch_fn"

        # Only the first positional argument drives output type restoration.
        original_object = args[0] if args else None

        converted_args, converted_kwargs = to_torch(
            *args, return_fn=_return_always, **kwargs
        )

        # No strict assertion here: convert what we can, pass through the rest.
        validated_args = _coerce_positional_args(converted_args)

        results = func(*validated_args, **converted_kwargs)

        # Convert results back to the original input type
        return _restore_input_type(results, original_object)

    # Mark as a wrapper for detection
    wrapper._is_wrapper = True
    wrapper._decorator_type = "torch_fn"
    return wrapper
# EOF