# Source code for vbi.optional_deps

"""
Optional dependency handling for VBI.

This module provides utilities for gracefully handling optional dependencies:
importing them when they are available, and raising informative error
messages when a required one is missing.
"""

import importlib
import functools
from typing import Optional, Any, Callable


class OptionalDependencyError(ImportError):
    """
    Error signalling that an optional dependency is needed but cannot be imported.

    Because it derives from ``ImportError``, existing ``except ImportError``
    handlers also catch it.
    """
def optional_import(module_name: str, install_name: Optional[str] = None) -> Any:
    """
    Try to import a module, yielding ``None`` instead of raising if it is absent.

    Parameters
    ----------
    module_name : str
        Importable (possibly dotted) name of the module.
    install_name : str, optional
        Not used by this function; accepted so the signature mirrors
        ``require_optional``.

    Returns
    -------
    module or None
        The imported module, or ``None`` if importing raised ``ImportError``
        (including ``ModuleNotFoundError``).
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        module = None
    return module
def require_optional(
    module_name: str,
    install_name: Optional[str] = None,
    extra: Optional[str] = None,
) -> Any:
    """
    Import a module that the calling functionality cannot work without.

    Parameters
    ----------
    module_name : str
        Importable (possibly dotted) name of the module.
    install_name : str, optional
        Package name shown in the ``pip install`` hint; falls back to
        ``module_name`` when not given (or empty).
    extra : str, optional
        Name of the VBI extra providing this dependency; when given, the
        error message also suggests ``pip install vbi[<extra>]``.

    Returns
    -------
    module
        The imported module.

    Raises
    ------
    OptionalDependencyError
        If the import raises ``ImportError``; the original error is chained
        as ``__cause__``.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        package = install_name if install_name else module_name
        hint = "" if not extra else f" or 'pip install vbi[{extra}]'"
        message = (
            f"The '{package}' package is required for this functionality. "
            f"Install it with 'pip install {package}'{hint}"
        )
        raise OptionalDependencyError(message) from exc
    return module
def requires_optional(*dependencies):
    """
    Decorator to check for optional dependencies before function execution.

    On every call of the decorated function, each dependency is imported
    via ``require_optional`` before the function body runs, so a missing
    package surfaces as an ``OptionalDependencyError`` with install hints.

    Parameters
    ----------
    *dependencies : str or tuple
        Each entry is either a bare module name (``'torch'``) or a sequence
        of one to three items ``(module_name[, install_name[, extra]])``.
        Omitted items default to ``None``.

    Returns
    -------
    callable
        A decorator that wraps a function (preserving its metadata via
        ``functools.wraps``).

    Raises
    ------
    ValueError
        At decoration time, if a dependency spec is empty or has more than
        three items.

    Examples
    --------
    >>> @requires_optional(('torch', 'torch', 'inference'))
    ... def inference_function():
    ...     import torch
    ...     # function implementation
    ...     pass
    """
    normalized = []
    for dep in dependencies:
        # Treat a bare string as a module name. Without this, a string would
        # be indexed/unpacked character by character ('sbi' -> 's', 'b', 'i').
        parts = (dep,) if isinstance(dep, str) else tuple(dep)
        if not 1 <= len(parts) <= 3:
            raise ValueError(
                "Dependency spec must have 1 to 3 items "
                f"(module_name[, install_name[, extra]]), got {dep!r}"
            )
        # Pad missing install_name / extra with None.
        normalized.append(parts + (None,) * (3 - len(parts)))
    # Specs are parsed once here instead of on every call of the wrapper.
    specs = tuple(normalized)

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for module_name, install_name, extra in specs:
                require_optional(module_name, install_name, extra)
            return func(*args, **kwargs)
        return wrapper
    return decorator
# Pre-import commonly used optional dependencies once, when this module loads.
# Each name is bound to the imported module, or to None if importing it
# raised ImportError (that is what ``optional_import`` returns on failure).
torch, cupy, sbi = (optional_import(_name) for _name in ("torch", "cupy", "sbi"))
def check_torch_available():
    """
    Report whether PyTorch was importable.

    Returns
    -------
    bool
        False if the module-level ``torch`` binding is ``None`` (the import
        attempted when this module loaded failed), True otherwise.
    """
    if torch is None:
        return False
    return True
def check_sbi_available():
    """
    Report whether SBI was importable.

    Returns
    -------
    bool
        False if the module-level ``sbi`` binding is ``None`` (the import
        attempted when this module loaded failed), True otherwise.
    """
    if sbi is None:
        return False
    return True
def check_cupy_available():
    """
    Report whether CuPy was importable.

    Note that this only reflects whether ``import cupy`` succeeded; it does
    not probe for a usable GPU.

    Returns
    -------
    bool
        False if the module-level ``cupy`` binding is ``None`` (the import
        attempted when this module loaded failed), True otherwise.
    """
    if cupy is None:
        return False
    return True