Source code for vbi.models.numba.jansen_rit

"""Jansen-Rit neural mass model — Numba JIT-compiled backend.

Implements the Jansen-Rit (1995) model for a network of cortical columns
using Numba-accelerated stochastic differential equation integration.

Reference
---------
Jansen, B.H. & Rit, V.G. (1995). Electroencephalogram and visual evoked
potential generation in a mathematical model of coupled cortical columns.
*Biological Cybernetics*, 73(4), 357–366.
"""
import warnings
import numpy as np
from typing import Dict, Any
from numba import njit
from vbi.utils import pretty_dtype
from numba.experimental import jitclass
from numba import float64, boolean, int64
from numba.extending import register_jitable
from numba.core.errors import NumbaPerformanceWarning
from vbi.models.numba.base import BaseNumbaModel

warnings.simplefilter("ignore", category=NumbaPerformanceWarning)

# ---------------------------------------------------------------
# Helper utilities (broadcasting, seeding, initial state)
# ---------------------------------------------------------------

[docs] @register_jitable def set_seed_compat(x): np.random.seed(x)
@register_jitable def _as_1d_array_like(x, nn): """ Broadcast scalar to 1D array of length nn if needed. """ x_arr = np.array(x) if not isinstance(x, np.ndarray) else x if x_arr.ndim == 0: return np.ones(nn) * float(x_arr) if x_arr.ndim == 1 and x_arr.shape[0] == nn: return x_arr.astype(np.float64) raise ValueError("Parameter must be scalar or 1D array of length nn")
[docs] @njit def set_initial_state_jr(nn, seed=-1): """ Initial state for JR: stack 6*n vectors. Mirrors ranges similar to the CuPy reference implementation. """ if seed is not None and seed >= 0: set_seed_compat(seed) y0 = np.random.uniform(-1.0, 1.0, nn) # x y1 = np.random.uniform(-500.0, 500.0, nn) # y y2 = np.random.uniform(-50.0, 50.0, nn) # z y3 = np.random.uniform(-6.0, 6.0, nn) # x' y4 = np.random.uniform(-20.0, 20.0, nn) # y' y5 = np.random.uniform(-500.0, 500.0, nn) # z' y = np.zeros(6 * nn) y[:nn] = y0 y[nn:2*nn] = y1 y[2*nn:3*nn] = y2 y[3*nn:4*nn] = y3 y[4*nn:5*nn] = y4 y[5*nn:6*nn] = y5 return y
# --------------------------------------------------------------- # JR parameters as a jitclass (Numba-friendly container) # --------------------------------------------------------------- jr_spec = [ ("G", float64), ("A", float64), ("B", float64), ("a", float64), ("b", float64), ("v0", float64), ("vmax", float64), ("r", float64), ("mu", float64), ("noise_amp", float64), ("dt", float64), ("t_cut", float64), ("t_end", float64), ("decimate", int64), ("nn", int64), ("seed", int64), ("sigma", float64), # sqrt(dt) * noise_amp # arrays ("weights", float64[:, :]), ("C0", float64[:]), ("C1", float64[:]), ("C2", float64[:]), ("C3", float64[:]), ("initial_state", float64[:]), ] # Default parameter values - single source of truth JR_DEFAULTS = { 'G': 1.0, 'A': 3.25, 'B': 22.0, 'a': 0.1, 'b': 0.05, 'v0': 6.0, 'vmax': 0.005, 'r': 0.56, 'mu': 0.24, 'noise_amp': 0.01, 'dt': 0.01, 't_cut': 0.0, 't_end': 1000.0, 'decimate': 1, 'C0': 135.0, 'C1': 0.8 * 135.0, 'C2': 0.25 * 135.0, 'C3': 0.25 * 135.0, 'seed': -1, 'weights': None, # Required - must be provided by user 'initial_state': None, # Auto-generated if not provided }
[docs] @jitclass(jr_spec) class ParJR: """ Numba jitclass container for Jansen-Rit model parameters. This class holds all parameters needed for the Jansen-Rit neural mass model in a format optimized for Numba compilation. It stores both scalar parameters and array parameters like connectivity weights and coupling constants. Note: This is an internal class used by the JR_sde class. Users should not instantiate this class directly. Use JR_sde instead. Default values are defined in JR_DEFAULTS dictionary. """ def __init__( self, weights, G=JR_DEFAULTS['G'], A=JR_DEFAULTS['A'], B=JR_DEFAULTS['B'], a=JR_DEFAULTS['a'], b=JR_DEFAULTS['b'], v0=JR_DEFAULTS['v0'], vmax=JR_DEFAULTS['vmax'], r=JR_DEFAULTS['r'], mu=JR_DEFAULTS['mu'], noise_amp=JR_DEFAULTS['noise_amp'], dt=JR_DEFAULTS['dt'], t_cut=JR_DEFAULTS['t_cut'], t_end=JR_DEFAULTS['t_end'], decimate=JR_DEFAULTS['decimate'], C0=JR_DEFAULTS['C0'], C1=JR_DEFAULTS['C1'], C2=JR_DEFAULTS['C2'], C3=JR_DEFAULTS['C3'], seed=JR_DEFAULTS['seed'] ): self.weights = weights self.nn = len(weights) self.G = G self.A = A self.B = B self.a = a self.b = b self.v0 = v0 self.vmax = vmax self.r = r self.mu = mu self.noise_amp = noise_amp self.dt = dt self.t_cut = t_cut self.t_end = t_end self.decimate = decimate self.seed = seed # C arrays are now passed pre-processed from outside self.C0 = C0 self.C1 = C1 self.C2 = C2 self.C3 = C3 self.sigma = np.sqrt(dt) * noise_amp self.initial_state = np.zeros(6 * self.nn) # set by caller later
# --------------------------------------------------------------- # JR model equations + integrator (Numba-jitted) # ---------------------------------------------------------------
[docs] @register_jitable def S_sigmoid(x, vmax, r, v0): """Numerically stable sigmoid function to avoid overflow.""" z = r * (v0 - x) # Clip z to avoid overflow: exp(700) is near overflow limit z_clipped = np.clip(z, -700, 700) return vmax / (1.0 + np.exp(z_clipped))
[docs] @njit def f_jr(x, t, P): """ Compute the right-hand side of the Jansen-Rit differential equations. This function implements the deterministic part of the Jansen-Rit neural mass model equations for a network of coupled brain regions. Parameters ---------- x : np.ndarray Current state vector of shape (6*nn,) containing stacked arrays: [x, y, z, x', y', z'] where nn is the number of nodes. t : float Current time (not used in autonomous system). P : ParJR Parameter object containing all model parameters. Returns ------- np.ndarray Derivative vector of shape (6*nn,) containing dx/dt. """ nn = P.nn # Unpack state x0 = x[0*nn:1*nn] # x y0 = x[1*nn:2*nn] # y z0 = x[2*nn:3*nn] # z xp = x[3*nn:4*nn] # x' yp = x[4*nn:5*nn] # y' zp = x[5*nn:6*nn] # z' # Precompute constants Aa = P.A * P.a Bb = P.B * P.b aa = P.a * P.a bb = P.b * P.b # Coupling term: weights @ (y - z) couplings = S_sigmoid(P.weights.dot(y0 - z0), P.vmax, P.r, P.v0) # Allocate derivative dxdt = np.zeros_like(x) # Dynamics dxdt[0*nn:1*nn] = xp dxdt[1*nn:2*nn] = yp dxdt[2*nn:3*nn] = zp dxdt[3*nn:4*nn] = Aa * S_sigmoid(y0 - z0, P.vmax, P.r, P.v0) - 2.0 * P.a * xp - aa * x0 dxdt[4*nn:5*nn] = ( Aa * (P.mu + P.C1 * S_sigmoid(P.C0 * x0, P.vmax, P.r, P.v0) + P.G * couplings) - 2.0 * P.a * yp - aa * y0 ) dxdt[5*nn:6*nn] = Bb * P.C3 * S_sigmoid(P.C2 * x0, P.vmax, P.r, P.v0) - 2.0 * P.b * zp - bb * z0 return dxdt
[docs] @njit def heun_sde(x, t, P): """ Perform one step of Heun's method for stochastic differential equations. This implements the Heun scheme (also called improved Euler method) for integrating stochastic differential equations with additive noise. Parameters ---------- x : np.ndarray Current state vector of shape (6*nn,). t : float Current time. P : ParJR Parameter object containing dt, sigma, and other model parameters. Returns ------- np.ndarray Updated state vector after one integration step. """ nn = P.nn dt = P.dt # Stochastic drive on the y' block, sigma already includes sqrt(dt) dW = P.sigma * np.random.randn(nn) k1 = f_jr(x, t, P) x1 = x + dt * k1 x1[4*nn:5*nn] += dW k2 = f_jr(x1, t + dt, P) x = x + 0.5 * dt * (k1 + k2) x[4*nn:5*nn] += dW return x
# --------------------------------------------------------------- # Top-level integrate and driver class (mirrors mpr.py style) # ---------------------------------------------------------------
[docs] def integrate(P): """ Integrate the Jansen-Rit model over time. This function performs the main time integration loop for the Jansen-Rit model, using the Heun stochastic integration scheme. It includes options for burn-in period and decimation of the output. Parameters ---------- P : ParJR Parameter object containing all simulation parameters including initial_state, t_end, dt, t_cut, and decimate. Returns ------- dict Dictionary with keys: - 't': np.ndarray of time points (decimated) - 'x': np.ndarray of shape (n_steps, nn) containing the output time series (y - z) representing local field potentials """ nn = P.nn dt = P.dt dec = P.decimate # Ensure initial state is defined x = P.initial_state.copy() nt = int(P.t_end / dt) tspan = np.linspace(0.0, (nt - 1) * dt, nt) # Cut & decimate bookkeeping i_cut = int(np.searchsorted(tspan, P.t_cut, side='left')) n_keep = (nt - i_cut + (dec - 1)) // dec # Output: y - z ts = np.zeros(n_keep, dtype=np.float32) ys = np.zeros((n_keep, nn), dtype=np.float32) k = 0 for i in range(nt): t = tspan[i] x = heun_sde(x, t, P) if i >= i_cut and ((i - i_cut) % dec == 0): ts[k] = t y0 = x[1*nn:2*nn] z0 = x[2*nn:3*nn] ys[k, :] = (y0 - z0).astype(np.float32) k += 1 if k >= n_keep: break return {"t": ts, "x": ys}
[docs] class JR_sde(BaseNumbaModel): """ Numba implementation of the Jansen-Rit neural mass model. .. list-table:: Parameters :widths: 25 50 25 :header-rows: 1 * - Name - Explanation - Default Value * - `A` - Excitatory post synaptic potential amplitude. - 3.25 * - `B` - Inhibitory post synaptic potential amplitude. - 22.0 * - `a` - Inverse time constant of the excitatory postsynaptic potential (1/a = time constant). - 0.1 (time constant: 10.0) * - `b` - Inverse time constant of the inhibitory postsynaptic potential (1/b = time constant). - 0.05 (time constant: 20.0) * - `C0` - Average number of synapses between pyramidal cells and excitatory interneurons. If array-like, it should be of length `nn` (number of nodes). - 135.0 * - `C1` - Average number of synapses between excitatory interneurons and pyramidal cells. If array-like, it should be of length `nn`. - 0.8 * 135.0 * - `C2` - Average number of synapses between pyramidal cells and inhibitory interneurons. If array-like, it should be of length `nn`. - 0.25 * 135.0 * - `C3` - Average number of synapses between inhibitory interneurons and pyramidal cells. If array-like, it should be of length `nn`. - 0.25 * 135.0 * - `vmax` - Maximum firing rate of the sigmoid function. - 0.005 * - `v0` - Potential at half of maximum firing rate (inflection point of sigmoid). - 6.0 * - `r` - Slope of sigmoid function at `v0`. - 0.56 * - `G` - Global coupling strength scaling the network connections. - 1.0 * - `mu` - Mean input to the excitatory population (external drive). - 0.24 * - `noise_amp` - Amplitude of the stochastic noise applied to the excitatory population. - 0.01 * - `weights` - Structural connectivity matrix of shape (`nn`, `nn`). Must be provided. - None * - `dt` - Integration time step. - 0.01 * - `t_end` - End time of simulation. - 1000.0 * - `t_cut` - Time from which to start collecting output (burn-in period). - 0.0 * - `decimate` - Decimation factor for the output time series (every `decimate`-th point is saved). - 1 * - `seed` - Random seed for reproducible simulations. If -1 or None, no seeding is applied. - -1 * - `initial_state` - Initial state vector of shape (6*nn,). If None, random initial conditions are generated. - None Usage example (single simulation): >>> import numpy as np >>> from vbi.models.numba.jansen_rit import JR_sde >>> W = np.eye(2) # 2-node demo connectivity >>> jr = JR_sde({"weights": W, "dt": 0.01, "t_end": 200.0, "t_cut": 100.0, "decimate": 1}) >>> out = jr.run() >>> t, x = out["t"], out["x"] # x has shape (n_step, nn) Notes ----- The Jansen-Rit model describes the dynamics of a cortical column with three neural populations: - Pyramidal cells (main excitatory population) - Excitatory interneurons - Inhibitory interneurons The model equations are integrated using the Heun stochastic integration scheme. The output represents the difference between excitatory and inhibitory postsynaptic potentials (y - z), which corresponds to the local field potential that can be measured experimentally. """
[docs] def __init__(self, par_jr: dict): """ Initialize the Jansen-Rit model. Parameters ---------- par_jr : dict Dictionary containing model parameters. See class documentation for available parameters. The 'weights' parameter is required and must be a square connectivity matrix. """ super().__init__() # Define valid parameters from JR_DEFAULTS, excluding derived parameters # 'nn' is derived from weights.shape, 'sigma' is derived from sqrt(dt) * noise_amp self.valid_params = [k for k in JR_DEFAULTS.keys() if k not in ['nn', 'sigma']] self.check_parameters(par_jr) # Validate weights early and create parameter jitclass if "weights" not in par_jr or par_jr["weights"] is None: raise ValueError("'weights' must be provided (square connectivity matrix)") W = np.array(par_jr["weights"], dtype=np.float64) if W.ndim != 2 or W.shape[0] != W.shape[1]: raise ValueError("'weights' must be a square 2D array") nn = len(W) # Pre-process C parameters before passing to jitclass params = dict(par_jr) params["weights"] = W # Handle C parameters - broadcast them here outside jitclass for c_name in ["C0", "C1", "C2", "C3"]: if c_name in params: c_val = params[c_name] params[c_name] = _as_1d_array_like(c_val, nn) else: # Set defaults if c_name == "C0": params[c_name] = _as_1d_array_like(135.0, nn) elif c_name == "C1": params[c_name] = _as_1d_array_like(0.8*135.0, nn) elif c_name == "C2": params[c_name] = _as_1d_array_like(0.25*135.0, nn) elif c_name == "C3": params[c_name] = _as_1d_array_like(0.25*135.0, nn) # Create jitclass instance self.P = ParJR(**params) # Seed handling self.seed = int(self.P.seed) if self.seed >= 0: np.random.seed(self.seed) # Ensure initial state if "initial_state" in par_jr and par_jr["initial_state"] is not None: x0 = np.array(par_jr["initial_state"], dtype=np.float64) if x0.shape[0] != 6 * self.P.nn: raise ValueError("initial_state must have length 6*nn") self.P.initial_state = x0 else: self.P.initial_state = set_initial_state_jr(self.P.nn, self.seed) self._checked = False
[docs] def get_default_parameters(self) -> Dict[str, Any]: """ Get the default parameters for the Jansen-Rit model. Returns the default values from JR_DEFAULTS dictionary, which is the single source of truth for all default parameter values. Returns ------- dict Dictionary containing default parameter values. Examples -------- >>> model = JR_sde({"weights": np.eye(2)}) >>> defaults = model.get_default_parameters() >>> print(defaults['G']) 1.0 Notes ----- 'nn' and 'sigma' are derived parameters: - 'nn': Number of nodes, determined from weights.shape[0] - 'sigma': Noise standard deviation, computed as sqrt(dt) * noise_amp """ # Return a copy with derived parameters added defaults = JR_DEFAULTS.copy() defaults['nn'] = None # Derived from weights defaults['sigma'] = None # Derived from dt and noise_amp return defaults
[docs] def get_parameter_descriptions(self) -> Dict[str, str]: """ Get descriptions for all Jansen-Rit model parameters. Returns ------- dict Dictionary mapping parameter names to tuples of (description, type). Examples -------- >>> model = JR_sde({"weights": np.eye(2)}) >>> descriptions = model.get_parameter_descriptions() >>> print(descriptions['G']) ('Global coupling strength', 'scalar') """ return { 'G': ('Global coupling strength', 'scalar'), 'A': ('Excitatory EPSP amplitude', 'scalar'), 'B': ('Inhibitory IPSP amplitude', 'scalar'), 'a': ('Inverse time constant of EPSP', 'scalar'), 'b': ('Inverse time constant of IPSP', 'scalar'), 'v0': ('Potential at half max firing rate', 'scalar'), 'vmax': ('Maximum firing rate', 'scalar'), 'r': ('Slope of sigmoid at v0', 'scalar'), 'mu': ('Mean external input', 'scalar'), 'noise_amp': ('Noise amplitude', 'scalar'), 'weights': ('Structural connectivity matrix', 'matrix'), 'dt': ('Integration time step', 'scalar'), 't_end': ('End time of simulation', 'scalar'), 't_cut': ('Cut-off time for output', 'scalar'), 'decimate': ('Decimation factor for output', 'int'), 'C0': ('Synapses: pyramidal to excitatory', 'vector'), 'C1': ('Synapses: excitatory to pyramidal', 'vector'), 'C2': ('Synapses: pyramidal to inhibitory', 'vector'), 'C3': ('Synapses: inhibitory to pyramidal', 'vector'), 'seed': ('Random seed for reproducibility', 'int'), 'initial_state': ('Initial state vector (6*nn)', 'vector'), 'nn': ('Number of nodes', 'int'), }
def __str__(self) -> str: """ Return a string representation of the model parameters. Uses the base class implementation to provide a formatted parameter table. Returns ------- str Formatted string showing all model parameters and their values. """ return self._format_parameters_table()
[docs] def check_input(self): """ Validate model parameters. Raises ------ ValueError If any parameter values are invalid (e.g., t_cut >= t_end, decimate < 1, or dimension mismatches). """ if self.P.t_cut >= self.P.t_end: raise ValueError("t_cut must be less than t_end") if self.P.decimate < 1: raise ValueError("decimate must be >= 1") if self.P.nn != self.P.weights.shape[0]: raise ValueError("nn != weights.shape[0]") self._checked = True
[docs] def set_initial_state(self, seed: int = None): """ Set random initial state for the simulation. Parameters ---------- seed : int, optional Random seed for reproducible initial conditions. If None, uses the seed specified during initialization. """ seed_ = self.seed if seed is None else seed self.P.initial_state = set_initial_state_jr(self.P.nn, seed_)
[docs] def run(self, par: dict = None, x0: np.ndarray = None): """ Run the Jansen-Rit simulation. Parameters ---------- par : dict, optional Dictionary of parameters to update for this simulation run. Any parameter from the class documentation can be updated. x0 : np.ndarray, optional Initial state vector of shape (6*nn,). If None, uses the initial state set during initialization or by set_initial_state(). Returns ------- dict Dictionary containing simulation results: - 't': np.ndarray of shape (n_steps,) - time points - 'x': np.ndarray of shape (n_steps, nn) - simulated time series (y - z) representing local field potentials """ # Optionally update parameters on the jitclass (Numba allows setattr) if par: for k, v in par.items(): if k == "weights": W = np.array(v, dtype=np.float64) if W.ndim != 2 or W.shape[0] != W.shape[1]: raise ValueError("'weights' must be a square 2D array") setattr(self.P, "weights", W) setattr(self.P, "nn", len(W)) elif k in ("C0", "C1", "C2", "C3"): arr = _as_1d_array_like(v, self.P.nn) setattr(self.P, k, arr) elif hasattr(self.P, k): setattr(self.P, k, v) else: raise ValueError(f"Invalid parameter: {k}") # Optionally replace initial state if x0 is not None: x0 = np.array(x0, dtype=np.float64) if x0.shape[0] != 6 * self.P.nn: raise ValueError("initial_state must have length 6*nn") self.P.initial_state = x0 if not self._checked: self.check_input() return integrate(self.P)
# Alias for consistency with naming convention JR_sde_numba = JR_sde