Source code for vbi.models.cupy.base

"""
Base class for VBI CuPy models providing a unified interface for parameter management.

It gives every CuPy-based neural mass model the same API for accessing,
documenting, and displaying its parameters.
"""
from typing import Any, Dict, List, Optional

import numpy as np


class BaseCupyModel:
    """
    Abstract base class for all VBI CuPy models.

    This class provides a unified interface for model parameter management,
    ensuring consistency across different CuPy model implementations.

    CuPy models typically store parameters directly as instance attributes
    (e.g., self.G, self.dt, self.sigma) and may also maintain a par_ dict.
    This base class provides a Python-level interface for accessing and
    documenting these parameters.

    Attributes
    ----------
    par_ : dict
        Dictionary storing all model parameters.
    valid_params : list
        List of valid parameter names for this model.
    """

    def __init__(self):
        """Initialize the base CuPy model with no parameters registered."""
        self.par_ = {}
        self.valid_params = []

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _import_cupy():
        """Return the ``cupy`` module, or ``None`` when it cannot be imported."""
        try:
            import cupy as cp
        except ImportError:
            return None
        return cp

    @classmethod
    def _to_numpy(cls, value: Any, cp: Any = None) -> Any:
        """
        Convert a CuPy array to a NumPy array; return anything else unchanged.

        Parameters
        ----------
        value : Any
            Object to convert.
        cp : module, optional
            An already-imported ``cupy`` module. When omitted the import is
            attempted here; if CuPy is unavailable ``value`` is returned as is.
        """
        if cp is None:
            cp = cls._import_cupy()
        if cp is not None and isinstance(value, cp.ndarray):
            return cp.asnumpy(value)
        return value

    # ------------------------------------------------------------------
    # Hooks that concrete models must provide
    # ------------------------------------------------------------------
    def get_default_parameters(self) -> Dict[str, Any]:
        """
        Get the default parameters for the model.

        Returns
        -------
        dict
            Dictionary containing default parameter values.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.

        Examples
        --------
        >>> model = GHB_sde({"G": 25.0})
        >>> defaults = model.get_default_parameters()
        >>> print(defaults['G'])
        25.0
        """
        raise NotImplementedError("Subclass must implement get_default_parameters()")

    def get_parameter_descriptions(self) -> Dict[str, tuple]:
        """
        Get descriptions for all model parameters.

        Returns
        -------
        dict
            Dictionary mapping parameter names to (description, type) tuples.
            Type should be one of: 'float', 'int', 'str', 'bool', 'ndarray'.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.

        Examples
        --------
        >>> model = GHB_sde({"G": 25.0})
        >>> descriptions = model.get_parameter_descriptions()
        >>> print(descriptions['G'])
        ('Global coupling strength', 'float')
        """
        raise NotImplementedError("Subclass must implement get_parameter_descriptions()")

    # ------------------------------------------------------------------
    # Parameter access
    # ------------------------------------------------------------------
    def check_parameters(self, par: Dict[str, Any]) -> None:
        """
        Validate that all provided parameters are valid for this model.

        Parameters
        ----------
        par : dict
            Dictionary of parameters to validate.

        Raises
        ------
        AssertionError
            If any parameter name is not in the valid_params list.

        Examples
        --------
        >>> model = GHB_sde({"G": 25.0})
        >>> model.check_parameters({'G': 30.0})  # Valid
        >>> model.check_parameters({'invalid_param': 1.0})  # Raises AssertionError
        """
        for key in par:
            if key not in self.valid_params:
                # Raised explicitly (not via ``assert``) so the check is not
                # stripped when Python runs with -O; the exception type and
                # message are the ones callers already expect.
                raise AssertionError(f"Invalid parameter: {key}")

    def get_parameters(self) -> Dict[str, Any]:
        """
        Get all model parameters as a dictionary.

        If ``par_`` is non-empty a shallow copy of it is returned. Otherwise
        the dictionary is built from the instance attributes named in
        ``valid_params``, with CuPy arrays converted to NumPy arrays.

        Returns
        -------
        dict
            Dictionary containing all current model parameters.

        Examples
        --------
        >>> model = GHB_sde({"G": 25.0, "dt": 0.01})
        >>> params = model.get_parameters()
        >>> print(params['G'])
        25.0
        """
        if hasattr(self, 'par_') and self.par_:
            # Return a copy to prevent external modification
            return dict(self.par_)

        # Fallback: construct from valid_params. CuPy is looked up once here
        # rather than once per parameter.
        cp = self._import_cupy()
        params = {}
        for param_name in self.valid_params:
            if hasattr(self, param_name):
                value = getattr(self, param_name)
                if cp is not None:
                    # Convert cupy arrays to numpy for display/serialization
                    value = self._to_numpy(value, cp)
                params[param_name] = value
        return params

    def get_parameter(self, name: str) -> Any:
        """
        Get the value of a specific parameter.

        ``par_`` is consulted first; otherwise the instance attribute of the
        same name is returned, converted to NumPy if it is a CuPy array.

        Parameters
        ----------
        name : str
            Name of the parameter to retrieve.

        Returns
        -------
        Any
            Value of the requested parameter.

        Raises
        ------
        AttributeError
            If parameter name does not exist.

        Examples
        --------
        >>> model = GHB_sde({"G": 25.0})
        >>> g_value = model.get_parameter('G')
        >>> print(g_value)
        25.0
        """
        if hasattr(self, 'par_') and name in self.par_:
            return self.par_[name]

        if hasattr(self, name):
            return self._to_numpy(getattr(self, name))

        raise AttributeError(f"Parameter '{name}' not found.")

    def list_parameters(self) -> List[str]:
        """
        Get a list of all valid parameter names for this model.

        Returns
        -------
        list
            A new list holding the valid parameter names.

        Examples
        --------
        >>> model = GHB_sde({"G": 25.0})
        >>> params = model.list_parameters()
        >>> print('G' in params)
        True
        """
        return list(self.valid_params)

    # ------------------------------------------------------------------
    # Display
    # ------------------------------------------------------------------
    def _format_value(self, value: Any) -> str:
        """
        Format a parameter value for display in the table.

        Parameters
        ----------
        value : Any
            Parameter value to format.

        Returns
        -------
        str
            Formatted string representation: the scalar for single-element
            arrays, ``shape (...)`` for larger arrays, ``length N`` for
            lists/tuples with more than three items, a quoted string for
            ``str``, and ``str(value)`` for anything else.
        """
        if value is None:
            return "None"

        if isinstance(value, np.ndarray):
            if value.size == 1:
                return f"{value.item()}"
            return f"shape {value.shape}"

        if isinstance(value, (list, tuple)) and len(value) > 3:
            return f"length {len(value)}"

        # bool is a subclass of int, so it is listed here explicitly instead
        # of in a later branch that could never be reached.
        if isinstance(value, (bool, int, float, np.integer, np.floating)):
            return f"{value}"

        if isinstance(value, str):
            return f'"{value}"'

        # Check for cupy arrays last so CuPy is only looked up when needed.
        cp = self._import_cupy()
        if cp is not None and isinstance(value, cp.ndarray):
            if value.size == 1:
                return f"{cp.asnumpy(value).item()}"
            # Only the shape is needed; avoid copying the array to the host.
            return f"shape {value.shape}"

        return str(value)

    def _format_parameters_table(self, model_name: Optional[str] = None) -> str:
        """
        Format model parameters as a table with names, descriptions, values, and types.

        Parameters
        ----------
        model_name : str, optional
            Custom name to display for the model. If None, uses
            self.__class__.__name__.

        Returns
        -------
        str
            Formatted table string with 4 columns:
            - Parameter: parameter name
            - Description: what the parameter does
            - Value: current value or shape
            - Type: float | int | str | bool | ndarray | -
        """
        param_info = self.get_parameter_descriptions()
        current_params = self.get_parameters()

        # Use provided model_name or default to class name
        display_name = model_name if model_name is not None else self.__class__.__name__

        rule = "=" * 110
        thin_rule = "-" * 110
        lines = [
            rule,
            f"{display_name}",
            rule,
            "",
            "Model Parameters:",
            thin_rule,
            f"{'Parameter':<15} | {'Description':<40} | {'Value/Shape':<30} | {'Type':<15}",
            thin_rule,
        ]

        for name in sorted(self.valid_params):
            if name not in current_params:
                continue

            # Get description and type from param_info
            info = param_info.get(name)
            if isinstance(info, tuple) and len(info) == 2:
                description, param_type = info
            else:
                description = "No description" if info is None else info
                param_type = "-"

            # Coerce to str so that entries such as the class ``float`` or a
            # missing description cannot break len() or the width format spec.
            description = str(description)
            param_type = str(param_type)

            # Format current value for display
            current_str = self._format_value(current_params[name])

            # Truncate long strings
            if len(description) > 40:
                description = description[:37] + "..."
            if len(current_str) > 30:
                current_str = current_str[:27] + "..."

            lines.append(
                f"{name:<15} | {description:<40} | {current_str:<30} | {param_type:<15}"
            )

        lines.append(rule)
        return "\n".join(lines)

    def __str__(self) -> str:
        """
        Return string representation of the model with parameter table.

        Returns
        -------
        str
            Formatted string with model information and parameters table.
        """
        return self._format_parameters_table()

    def __repr__(self) -> str:
        """
        Return detailed string representation of the model.

        Returns
        -------
        str
            String representation showing class name and number of parameters.
        """
        return f"{self.__class__.__name__}(n_params={len(self.valid_params)})"