Source code for vbi.models.numba.base

"""
Base class for VBI Numba models providing a unified interface for parameter management.

This base class provides a consistent API across all Numba-based neural mass models,
similar to the BaseModel class used for C++ models. However, it accounts for the
specific architecture of Numba models which use jitclass for efficient parameter storage.
"""
from typing import Dict, List, Any
import numpy as np


class BaseNumbaModel:
    """
    Abstract base class for all VBI Numba models.

    This class provides a unified interface for model parameter management,
    ensuring consistency across different Numba model implementations.

    Unlike the C++ BaseModel, Numba models typically use a jitclass
    (e.g., ParVEP, ParJR) for storing parameters that will be used in
    JIT-compiled functions. This base class provides a Python-level interface
    while the actual parameters may be stored in the jitclass instance
    (typically self.P or self.par).

    Subclasses are expected to override ``get_default_parameters``,
    ``get_parameter_descriptions`` and ``run``; the base implementations
    raise ``NotImplementedError``.

    Attributes
    ----------
    P : jitclass instance
        Internal jitclass instance storing all model parameters for Numba
        compilation. ``None`` until a subclass assigns it.
    valid_params : list
        List of valid (user-settable) parameter names for this model.
    """

    # Layout of the table produced by _format_parameters_table().
    _TABLE_WIDTH = 110
    _NAME_WIDTH = 15
    _DESC_WIDTH = 40
    _VALUE_WIDTH = 30
    _TYPE_WIDTH = 15

    def __init__(self):
        """Initialize the base Numba model."""
        self.P = None  # Will be set by subclass to jitclass instance
        self.valid_params = []
        # Initialised here for subclasses; not read anywhere in this base class.
        self._checked = False

    def get_default_parameters(self) -> Dict[str, Any]:
        """
        Get the default parameters for the model.

        Returns
        -------
        dict
            Dictionary containing default parameter values.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> defaults = model.get_default_parameters()
        >>> print(defaults['G'])
        1.0
        """
        raise NotImplementedError("Subclass must implement get_default_parameters()")

    def get_parameter_descriptions(self) -> Dict[str, str]:
        """
        Get descriptions for all model parameters.

        Returns
        -------
        dict
            Dictionary mapping parameter names to their descriptions.
            Can also return tuples of (description, type) where type is one of:
            'scalar', 'vector', 'matrix', 'string', 'int', 'bool'.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> descriptions = model.get_parameter_descriptions()
        >>> print(descriptions['G'])
        ('Global coupling strength', 'scalar')
        """
        raise NotImplementedError("Subclass must implement get_parameter_descriptions()")

    def check_parameters(self, par: Dict[str, Any]) -> None:
        """
        Validate that all provided parameters are valid for this model.

        On failure the offending names and the list of valid names are
        printed before the exception is raised.

        Parameters
        ----------
        par : dict
            Dictionary of parameters to validate.

        Raises
        ------
        ValueError
            If any parameter name is not in the valid_params list.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> model.check_parameters({'G': 2.0})  # Valid
        >>> model.check_parameters({'invalid_param': 1.0})  # Raises ValueError
        """
        invalid_params = [key for key in par if key not in self.valid_params]
        if not invalid_params:
            return

        message = f"Invalid parameter(s): {', '.join(invalid_params)}"
        print(message)
        self.print_valid_parameters()
        raise ValueError(message)

    def get_parameters(self) -> Dict[str, Any]:
        """
        Get all model parameters as a dictionary.

        This method can be overridden by subclasses that have a
        _par_to_dict() helper method. By default, it reads parameters
        directly from self.P.

        Returns
        -------
        dict
            Dictionary containing all current model parameters from the
            jitclass. Empty if self.P has not been set.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2), "G": 2.0})
        >>> params = model.get_parameters()
        >>> print(params['G'])
        2.0
        """
        if self.P is None:
            return {}

        # If subclass has _par_to_dict, use it
        if hasattr(self, '_par_to_dict'):
            return self._par_to_dict()

        # Default fallback: read directly from self.P
        params = {}
        for param_name in self.valid_params:
            if not hasattr(self.P, param_name):
                continue
            value = getattr(self.P, param_name)
            # np.array() copies, so the returned dict does not alias the
            # array object held by self.P.
            params[param_name] = np.array(value) if isinstance(value, np.ndarray) else value
        return params

    def get_parameter(self, name: str) -> Any:
        """
        Get the value of a specific parameter.

        Parameters
        ----------
        name : str
            Name of the parameter to retrieve.

        Returns
        -------
        Any
            Value of the requested parameter.

        Raises
        ------
        AttributeError
            If parameter name does not exist.
        ValueError
            If the parameter object (jitclass) has not been initialized.

        Notes
        -----
        This method allows reading any parameter that exists in the parameter
        object (self.P), including derived parameters like 'nn' (number of
        nodes) even though they are not in valid_params (user-settable
        parameters).

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> g_value = model.get_parameter('G')
        >>> print(g_value)
        1.0
        >>> nn_value = model.get_parameter('nn')  # Derived parameter
        >>> print(nn_value)
        2
        """
        if self.P is None:
            raise ValueError("Parameter object not initialized. Run __init__ first.")

        if not hasattr(self.P, name):
            raise AttributeError(f"Parameter '{name}' not found in parameter object.")

        return getattr(self.P, name)

    def list_parameters(self) -> List[str]:
        """
        Get a list of all valid parameter names for this model.

        Returns
        -------
        list
            A new list of valid parameter names.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> params = model.list_parameters()
        >>> print('G' in params)
        True
        """
        return list(self.valid_params)

    def print_valid_parameters(self):
        """
        Print all valid parameter names for this model.

        This is a helper method for debugging and documentation purposes.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> model.print_valid_parameters()
        Valid parameters:
        G, A, B, a, b, v0, vmax, r, mu, noise_amp, ...
        """
        print("Valid parameters:")
        print(", ".join(self.valid_params))

    def run(self, par: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """
        Run the model simulation.

        Parameters
        ----------
        par : dict, optional
            Dictionary of parameters to update before running.
        **kwargs
            Additional keyword arguments specific to the model.

        Returns
        -------
        dict
            Dictionary containing simulation results (typically 't' for time
            and state variables like 'x', 'E', 'I', etc.).

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2), "dt": 0.01, "t_end": 100.0})
        >>> result = model.run()
        >>> t, x = result['t'], result['x']
        """
        raise NotImplementedError("Subclass must implement run()")

    def __call__(self) -> Dict[str, Any]:
        """
        Return model parameters when instance is called.

        Returns
        -------
        dict
            Dictionary of all model parameters.

        Examples
        --------
        >>> model = JR_sde({"weights": np.eye(2)})
        >>> params = model()
        >>> print(params['G'])
        1.0
        """
        return self.get_parameters()

    def _format_value(self, value):
        """
        Format a parameter value for display.

        Parameters
        ----------
        value : Any
            The parameter value to format.

        Returns
        -------
        str
            Formatted string representation of the value: the shape for
            array-likes, an item count for long lists/tuples, ``dict(n)`` for
            dictionaries, a quoted form for strings, and ``str(value)``
            otherwise.
        """
        if hasattr(value, 'shape'):
            # Array-like object
            try:
                shape_tuple = tuple(value.shape)
            except Exception:
                # Best effort: an object with an unusable ``shape`` is shown
                # by type name. Unlike a bare ``except``, this lets
                # KeyboardInterrupt/SystemExit propagate.
                return type(value).__name__
            if not shape_tuple:
                # 0-d array (scalar stored as array)
                return f"{value}"
            return f"shape {shape_tuple}"

        if isinstance(value, (list, tuple)):
            # Short (or empty) sequences are shown verbatim.
            return f"[{len(value)} items]" if len(value) > 3 else str(value)

        if isinstance(value, dict):
            return f"dict({len(value)})"

        # bool is a subclass of int, so it must be tested before the numeric
        # branch to be reachable at all.
        if isinstance(value, bool):
            return f"{value}"

        if isinstance(value, (int, float, np.integer, np.floating)):
            return f"{value}"

        if isinstance(value, str):
            return f'"{value}"'

        # Other types
        return str(value)

    @staticmethod
    def _truncate(text: str, width: int) -> str:
        """Shorten *text* to at most *width* characters, ending in '...'."""
        if len(text) <= width:
            return text
        return text[:width - 3] + "..."

    @staticmethod
    def _split_description(entry):
        """
        Normalise one parameter-description entry to (description, type).

        Accepts a plain description, a ``(description, type)`` tuple, a
        shorter/longer tuple, or ``None``; missing pieces fall back to
        "No description" and "-". Both results are always strings so they
        can be padded with a width format spec.
        """
        if entry is None:
            return "No description", "-"
        if isinstance(entry, tuple):
            description = str(entry[0]) if len(entry) > 0 else "No description"
            param_type = str(entry[1]) if len(entry) > 1 else "-"
            return description, param_type
        return str(entry), "-"

    def _format_parameters_table(self, model_name: str = None) -> str:
        """
        Format model parameters as a table with names, descriptions, values, and types.

        Parameters
        ----------
        model_name : str, optional
            Custom name to display for the model. If None, uses
            self.__class__.__name__.

        Returns
        -------
        str
            Formatted table string with 4 columns:
            - Parameter: parameter name
            - Description: what the parameter does
            - Value: current value or shape
            - Type: scalar | vector | matrix | string | bool | int | -

            Rows are sorted by parameter name; names in valid_params that
            are absent from get_parameters() are skipped.
        """
        param_info = self.get_parameter_descriptions()
        current_params = self.get_parameters()

        # Use provided model_name or default to class name
        display_name = model_name if model_name is not None else self.__class__.__name__

        heavy_rule = "=" * self._TABLE_WIDTH
        light_rule = "-" * self._TABLE_WIDTH
        nw, dw, vw, tw = (
            self._NAME_WIDTH,
            self._DESC_WIDTH,
            self._VALUE_WIDTH,
            self._TYPE_WIDTH,
        )

        lines = [
            heavy_rule,
            f"{display_name}",
            heavy_rule,
            "",
            "Model Parameters:",
            light_rule,
            f"{'Parameter':<{nw}} | {'Description':<{dw}} | {'Value/Shape':<{vw}} | {'Type':<{tw}}",
            light_rule,
        ]

        for name in sorted(self.valid_params):
            if name not in current_params:
                continue

            description, param_type = self._split_description(param_info.get(name))

            # Truncate long strings so the columns stay aligned
            description = self._truncate(description, dw)
            current_str = self._truncate(self._format_value(current_params[name]), vw)

            lines.append(
                f"{name:<{nw}} | {description:<{dw}} | {current_str:<{vw}} | {param_type:<{tw}}"
            )

        lines.append(heavy_rule)
        return "\n".join(lines)

    def __str__(self) -> str:
        """
        Return string representation of the model with parameter table.

        Returns
        -------
        str
            Formatted string with model information and parameters table.
        """
        return self._format_parameters_table()

    def __repr__(self) -> str:
        """
        Return detailed string representation of the model.

        Returns
        -------
        str
            String representation showing class name and number of parameters.
        """
        return f"{self.__class__.__name__}(n_params={len(self.valid_params)})"