Source code for scitex_decorators._batch_fn

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-05-01 09:18:26 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/decorators/_batch_fn.py
# ----------------------------------------
import os

# Hard-coded, repository-relative path of this module (not derived from the
# interpreter-provided ``__file__``).
__FILE__ = "./src/scitex/decorators/_batch_fn.py"
# Directory component of ``__FILE__``, i.e. "./src/scitex/decorators".
__DIR__ = os.path.dirname(__FILE__)
from functools import wraps

# ----------------------------------------
from typing import Any as _Any
from typing import Callable

import numpy as np
from tqdm import tqdm as _tqdm

from ._converters import is_nested_decorator


def _accepts_batch_size(func: Callable) -> bool:
    """Return True if ``inspect.signature(func)`` lists a ``batch_size`` parameter.

    Callables whose signature cannot be introspected are treated as not
    accepting ``batch_size``.
    """
    import inspect

    try:
        return "batch_size" in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # inspect.signature raises these for callables it cannot introspect.
        return False


def _call_on_batch(
    func: Callable,
    x: _Any,
    args: tuple,
    kwargs: dict,
    batch_size: int,
    pass_batch_size: bool,
) -> _Any:
    """Call ``func`` on ``x``, forwarding ``batch_size`` only when requested."""
    if not pass_batch_size:
        return func(x, *args, **kwargs)
    try:
        return func(x, *args, **kwargs, batch_size=batch_size)
    except TypeError:
        # Best-effort fallback kept from the original implementation: the
        # introspected signature advertised ``batch_size`` but the call was
        # rejected, so retry without it. Only TypeError triggers the retry;
        # every other exception propagates without a hidden second call.
        return func(x, *args, **kwargs)


def _to_cpu(result: _Any) -> _Any:
    """Move a tensor, or the tensors directly inside a tuple, to the CPU."""
    import torch

    if isinstance(result, torch.Tensor):
        return result.cpu()
    if isinstance(result, tuple):
        return tuple(
            val.cpu() if isinstance(val, torch.Tensor) else val for val in result
        )
    return result


def _stack_tensors(tensors: list) -> _Any:
    """Combine per-batch tensors: ``stack`` for 0-D results, ``vstack`` otherwise."""
    import torch

    if tensors[0].ndim == 0:
        return torch.stack(tensors)
    return torch.vstack(tensors)


def _merge_batches(results: list) -> _Any:
    """Combine per-batch results; the type of the first result picks the strategy."""
    import torch

    first = results[0]

    if isinstance(first, tuple):
        merged = []
        for i_var, first_elem in enumerate(first):
            if isinstance(first_elem, torch.Tensor):
                merged.append(_stack_tensors([res[i_var] for res in results]))
            elif isinstance(first_elem, np.ndarray):
                merged.append(np.vstack([res[i_var] for res in results]))
            else:
                # Non tensor/array elements (e.g. lists) are not merged: the
                # value from the first batch is used, assuming it is the same
                # across batches.
                merged.append(first_elem)
        return tuple(merged)

    if isinstance(first, torch.Tensor):
        return _stack_tensors(results)

    if isinstance(first, np.ndarray):
        # 0-D arrays cannot be row-stacked; collect them into a 1-D array.
        return np.array(results) if first.ndim == 0 else np.vstack(results)

    if isinstance(first, (int, float)):
        return np.array(results) if len(results) > 1 else first

    # Lists (and anything else that supports ``list + result``) are concatenated.
    return sum(results, [])


def batch_fn(func: Callable) -> Callable:
    """Decorate ``func`` so that its first argument is processed in batches.

    The returned wrapper accepts an optional ``batch_size`` keyword
    (default 4, coerced with ``int``). If ``len(x) <= batch_size`` the
    function is called once on the whole input and its result is returned
    unchanged. Otherwise ``x`` is sliced into consecutive chunks of at most
    ``batch_size`` items, ``func`` is called on each chunk (with a tqdm
    progress bar), tensor results are moved to the CPU, and the per-batch
    results are combined:

    * ``torch.Tensor``: ``torch.stack`` for 0-D results, else ``torch.vstack``.
    * ``numpy.ndarray``: ``np.array`` for 0-D results, else ``np.vstack``.
    * ``int`` / ``float``: a NumPy array of the per-batch values.
    * ``tuple``: combined position by position; tensors as above, arrays with
      ``np.vstack``, and any other element taken from the first batch only.
    * anything else: concatenated with ``sum(results, [])``.

    ``batch_size`` is forwarded to ``func`` only when ``func``'s signature
    declares a parameter of that name.

    Parameters
    ----------
    func : Callable
        Function whose first positional argument supports ``len()`` and
        slicing.

    Returns
    -------
    Callable
        The batching wrapper, tagged with ``_is_wrapper = True`` and
        ``_decorator_type = "batch_fn"``.

    Raises
    ------
    ValueError
        From the wrapper, when batching is required but ``batch_size`` is
        smaller than 1.
    """

    @wraps(func)
    def wrapper(x: _Any, *args: _Any, **kwargs: _Any) -> _Any:
        # Skip batching if in a nested decorator context and batch_size is
        # already set.
        if is_nested_decorator() and "batch_size" in kwargs:
            return func(x, *args, **kwargs)

        # Set the current decorator context.
        wrapper._current_decorator = "batch_fn"

        # Mark that batch_fn has been applied (one entry is appended per call).
        if not hasattr(wrapper, "_decorator_order"):
            wrapper._decorator_order = []
        wrapper._decorator_order.append("batch_fn")

        batch_size = int(kwargs.pop("batch_size", 4))
        # Inspect the signature once per call instead of once per batch.
        pass_batch_size = _accepts_batch_size(func)

        n_items = len(x)
        if n_items <= batch_size:
            return _call_on_batch(func, x, args, kwargs, batch_size, pass_batch_size)

        if batch_size < 1:
            # Previously surfaced as ZeroDivisionError / IndexError further down.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        n_batches = (n_items + batch_size - 1) // batch_size  # ceil division
        results = []
        for i_batch in _tqdm(range(n_batches)):
            start = i_batch * batch_size
            end = min(start + batch_size, n_items)
            batch_result = _call_on_batch(
                func, x[start:end], args, kwargs, batch_size, pass_batch_size
            )
            results.append(_to_cpu(batch_result))

        return _merge_batches(results)

    # Mark as a wrapper for detection.
    wrapper._is_wrapper = True
    wrapper._decorator_type = "batch_fn"
    return wrapper
# EOF