Source code for vbi.models.numba.ww

# gpt5

import warnings
import numpy as np
from copy import copy
from numba import njit, jit
from numba.experimental import jitclass
from numba.extending import register_jitable
from vbi.utils import print_valid_parameters
from numba import float64, boolean, int64, types
from numba.core.errors import NumbaPerformanceWarning

warnings.simplefilter("ignore", category=NumbaPerformanceWarning)


# -----------------------------
# Helper utilities
# -----------------------------

# @register_jitable
# def set_seed_compat(x):
#     np.random.seed(x)

[docs] @jit(nopython=True) def initialize_random_state(seed): """ Initialize the random number generator with a specific seed. This function sets the random seed in the Numba context to ensure reproducible stochastic simulations. Parameters ---------- seed : int Random seed value for reproducibility """ np.random.seed(seed)
[docs] def check_vec_size_1d(x, nn): """ Return a 1D vector of size nn, broadcasting scalar if needed. This utility function ensures that parameter inputs are properly formatted as arrays of the correct size for multi-region simulations. Parameters ---------- x : scalar or array-like Input value(s) to be broadcast or validated nn : int Required array size (number of brain regions) Returns ------- np.ndarray Array of shape (nn,) with input values properly broadcast """ x = np.array(x, dtype=np.float64) if np.ndim(x) > 0 else np.array([x], dtype=np.float64) return np.ones(nn, dtype=np.float64) * x if x.size != nn else x.astype(np.float64)
# ----------------------------- # BOLD model parameters (same structure as in mpr.py) # ----------------------------- bold_spec = [ ("kappa", float64), ("gamma", float64), ("tau", float64), ("alpha", float64), ("epsilon", float64), ("Eo", float64), ("TE", float64), ("vo", float64), ("r0", float64), ("theta0", float64), ("t_min", float64), ("rtol", float64), ("atol", float64), ]
[docs] @jitclass(bold_spec) class ParBold: """ Parameter class for BOLD signal generation in the Wong-Wang model. This Numba jitclass holds parameters for the hemodynamic response model that converts neural activity to BOLD signal. Based on the Balloon-Windkessel model for simulating fMRI BOLD responses. Parameters ---------- kappa : float, default 0.65 Signal decay parameter gamma : float, default 0.41 Feedback regulation parameter tau : float, default 0.98 Hemodynamic transit time (s) alpha : float, default 0.32 Grubb's vessel stiffness exponent epsilon : float, default 0.34 Efficacy of oxygen extraction Eo : float, default 0.4 Oxygen extraction fraction at rest TE : float, default 0.04 Echo time (s) vo : float, default 0.08 Resting venous volume fraction r0 : float, default 25.0 Slope parameter for intravascular signal theta0 : float, default 40.3 Frequency offset at the outer surface of magnetized vessels (Hz) t_min : float, default 0.0 Minimum integration time rtol : float, default 1e-5 Relative tolerance for integration atol : float, default 1e-8 Absolute tolerance for integration """ def __init__( self, kappa=0.65, gamma=0.41, tau=0.98, alpha=0.32, epsilon=0.34, Eo=0.4, TE=0.04, vo=0.08, r0=25.0, theta0=40.3, t_min=0.0, rtol=1e-5, atol=1e-8, ): self.kappa = kappa self.gamma = gamma self.tau = tau self.alpha = alpha self.epsilon = epsilon self.Eo = Eo self.TE = TE self.vo = vo self.r0 = r0 self.theta0 = theta0 self.t_min = t_min self.rtol = rtol self.atol = atol
[docs] @jit(nopython=True) def do_bold_step(r_in, s, f, ftilde, vtilde, qtilde, v, q, dtt, P): """ Perform one integration step of the BOLD hemodynamic response model. Implements the Balloon-Windkessel model to convert neural activity into BOLD signal by simulating the hemodynamic cascade: neural activity → vasodilatory signal → blood flow → blood volume → deoxyhemoglobin → BOLD signal. Parameters ---------- r_in : np.ndarray Neural activity input (excitatory synaptic gating variables) s : np.ndarray, shape (2, nn) Vasodilatory signal state variables f : np.ndarray, shape (2, nn) Blood flow state variables ftilde : np.ndarray, shape (2, nn) Log-transformed blood flow variables vtilde : np.ndarray, shape (2, nn) Log-transformed blood volume variables qtilde : np.ndarray, shape (2, nn) Log-transformed deoxyhemoglobin content variables v : np.ndarray, shape (2, nn) Blood volume state variables q : np.ndarray, shape (2, nn) Deoxyhemoglobin content state variables dtt : float Integration time step (seconds) P : ParBold BOLD parameter object Notes ----- This function modifies the state arrays in-place and uses a log-transform approach to ensure numerical stability of the hemodynamic state variables. """ kappa = P.kappa gamma = P.gamma ialpha = 1.0 / P.alpha tau = P.tau Eo = P.Eo s[1] = s[0] + dtt * (r_in - kappa * s[0] - gamma * (f[0] - 1.0)) # keep f[0] >= 1 to avoid log issues f[0] = np.maximum(f[0], 1.0) ftilde[1] = ftilde[0] + dtt * (s[0] / f[0]) fv = v[0] ** ialpha # outflow vtilde[1] = vtilde[0] + dtt * ((f[0] - fv) / (tau * v[0])) q[0] = np.maximum(q[0], 0.01) ff = (1.0 - (1.0 - Eo) ** (1.0 / f[0])) / Eo # oxygen extraction qtilde[1] = qtilde[0] + dtt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0])) # exponentiate back f[1] = np.exp(ftilde[1]) v[1] = np.exp(vtilde[1]) q[1] = np.exp(qtilde[1]) # roll state f[0] = f[1] s[0] = s[1] ftilde[0] = ftilde[1] vtilde[0] = vtilde[1] qtilde[0] = qtilde[1] v[0] = v[1] q[0] = q[1]
# ----------------------------- # Wong–Wang model params (Numba jitclass) # ----------------------------- ww_spec = [ # local population parameters ("a_exc", float64), ("a_inh", float64), ("b_exc", float64), ("b_inh", float64), ("d_exc", float64), ("d_inh", float64), ("tau_exc", float64), ("tau_inh", float64), ("gamma_exc", float64), ("gamma_inh", float64), ("W_exc", float64), ("W_inh", float64), ("ext_current", float64[:]), ("J_NMDA", float64), ("J_I", float64), ("w_plus", float64), ("lambda_inh_exc", float64), # global / simulation parameters ("t_end", float64), ("t_cut", float64), ("dt", float64), ("G_exc", float64), ("G_inh", float64), ("weights", float64[:, :]), ("tr", float64), ("s_decimate", int64), ("sigma", float64), ("nn", int64), ("seed", int64), ("output", types.string), ("dtype", types.string), ("initial_state", float64[:]), ("RECORD_S", boolean), ("RECORD_BOLD", boolean), ]
[docs] @jitclass(ww_spec) class ParWW: """ Parameter class for the Wong-Wang full neural mass model. This Numba jitclass holds all parameters required for Wong-Wang simulation including biophysical parameters for excitatory and inhibitory populations, coupling strengths, noise levels, and simulation settings. Uses Numba for high-performance compilation. Parameters ---------- a_exc : float, default 310.0 Excitatory population gain parameter (n/C) a_inh : float, default 0.615 Inhibitory population gain parameter (nC⁻¹) b_exc : float, default 125.0 Excitatory population threshold parameter (Hz) b_inh : float, default 177.0 Inhibitory population threshold parameter (Hz) d_exc : float, default 0.16 Excitatory population saturation parameter (s) d_inh : float, default 0.087 Inhibitory population saturation parameter (s) tau_exc : float, default 100.0 Excitatory synaptic time constant (ms) tau_inh : float, default 10.0 Inhibitory synaptic time constant (ms) gamma_exc : float, default 0.641/1000 Excitatory kinetic parameter (ms⁻¹) gamma_inh : float, default 1.0/1000 Inhibitory kinetic parameter (ms⁻¹) W_exc : float, default 1.0 Excitatory population local weight W_inh : float, default 0.7 Inhibitory population local weight ext_current : array-like, default [0.382] External current input per region (nA) J_NMDA : float, default 0.15 NMDA synaptic coupling strength (nA) J_I : float, default 1.0 Inhibitory synaptic coupling strength (nA) w_plus : float, default 1.4 Local excitatory recurrence strength lambda_inh_exc : float, default 0.0 Long-range feedforward inhibition switch """ def __init__( self, # exc/inh params (Wong & Wang 2006 / Deco et al.) a_exc=310.0, a_inh=0.615, b_exc=125.0, b_inh=177.0, d_exc=0.16, d_inh=0.087, tau_exc=100.0, # ms tau_inh=10.0, # ms gamma_exc=0.641 / 1000.0, gamma_inh=1.0 / 1000.0, W_exc=1.0, W_inh=0.7, ext_current=np.array([0.382]), # nA J_NMDA=0.15, J_I=1.0, w_plus=1.4, lambda_inh_exc=0.0, # simulation t_end=1000.0, t_cut=0.0, dt=0.1, G_exc=0.0, G_inh=0.0, weights=np.array([[], []]), tr=300.0, # ms s_decimate=1, sigma=0.0, nn=1, seed=-1, output="output", dtype="f", initial_state=np.array([0.0]), RECORD_S=False, RECORD_BOLD=True, ): # assign self.a_exc = a_exc self.a_inh = a_inh self.b_exc = b_exc self.b_inh = b_inh self.d_exc = d_exc self.d_inh = d_inh self.tau_exc = tau_exc self.tau_inh = tau_inh self.gamma_exc = gamma_exc self.gamma_inh = gamma_inh self.W_exc = W_exc self.W_inh = W_inh self.ext_current = ext_current self.J_NMDA = J_NMDA self.J_I = J_I self.w_plus = w_plus self.lambda_inh_exc = lambda_inh_exc self.t_end = t_end self.t_cut = t_cut self.dt = dt self.G_exc = G_exc self.G_inh = G_inh self.weights = weights self.tr = tr self.s_decimate = s_decimate self.sigma = sigma self.nn = nn self.seed = seed self.output = output self.dtype = dtype self.initial_state = initial_state self.RECORD_S = RECORD_S self.RECORD_BOLD = RECORD_BOLD
# ----------------------------- # Wong–Wang dynamics (Numba) # -----------------------------
[docs] @njit def firing_rate(current, a, b, d): """ Compute the firing rate using the Wong-Wang transfer function. Implements the input-output relationship for neural populations: r(I) = (a*I - b) / (1 - exp(-d*(a*I - b))) This function captures the nonlinear relationship between synaptic current and population firing rate, including saturation effects. Parameters ---------- current : np.ndarray Synaptic current input to the population. a : float Gain parameter controlling the slope of the transfer function. b : float Threshold parameter determining the firing threshold. d : float Saturation parameter controlling the saturation behavior. Returns ------- np.ndarray Population firing rate (Hz). """ u = a * current - b den = 1.0 - np.exp(-d * u) # avoid division by ~0; if u ~ 0 => limit is a/d out = np.zeros_like(current) for i in range(current.shape[0]): if np.abs(den[i]) < 1e-12: out[i] = a * u[i] * 0.5 # very small; fallback (won't really occur) else: out[i] = u[i] / den[i] return out
[docs] @njit def f_ww(S, t, P): """ Compute the right-hand side of the Wong-Wang full model equations. This function implements the deterministic part of the Wong-Wang model, computing the time derivatives of synaptic gating variables for both excitatory and inhibitory populations across all brain regions. The implemented equations are: dS_exc/dt = -S_exc/τ_exc + (1-S_exc)*γ_exc*r_exc(I_exc) dS_inh/dt = -S_inh/τ_inh + γ_inh*r_inh(I_inh) Parameters ---------- S : np.ndarray Current state vector of shape (2*nn,) containing stacked arrays: [S_exc_0, ..., S_exc_n, S_inh_0, ..., S_inh_n] where nn is the number of brain regions. t : float Current time (not used in autonomous system). P : ParWW Parameter object containing all model parameters. Returns ------- np.ndarray Derivative vector dS/dt of shape (2*nn,). """ nn = P.nn S_exc = S[:nn] S_inh = S[nn:] # network couplings network_exc_exc = P.weights.dot(S_exc) if P.lambda_inh_exc > 0.0: network_inh_exc = P.weights.dot(S_inh) else: network_inh_exc = np.zeros_like(S_exc) # currents current_exc = ( P.W_exc * P.ext_current + P.w_plus * P.J_NMDA * S_exc + P.G_exc * P.J_NMDA * network_exc_exc - P.J_I * S_inh ) current_inh = ( P.W_inh * P.ext_current + P.J_NMDA * S_inh - S_inh + P.G_inh * P.J_NMDA * network_inh_exc ) # firing rates r_exc = firing_rate(current_exc, P.a_exc, P.b_exc, P.d_exc) r_inh = firing_rate(current_inh, P.a_inh, P.b_inh, P.d_inh) dSdt = np.zeros(2 * nn) # exc dSdt[:nn] = (-S_exc / P.tau_exc) + (1.0 - S_exc) * P.gamma_exc * r_exc # inh dSdt[nn:] = (-S_inh / P.tau_inh) + P.gamma_inh * r_inh return dSdt
[docs] @jit(nopython=True) def heun_sde(S, t, P): """ Perform one heun integration step for the stochastic Wong-Wang model. This function implements the Heun method (predictor-corrector) for numerical integration of stochastic differential equations, providing second-order accuracy for the deterministic part while properly handling additive Gaussian noise. Parameters ---------- S : np.ndarray Current state vector of shape (2*nn,) containing synaptic gating variables. t : float Current time (ms). P : ParWW Parameter object containing model parameters including dt and sigma. Returns ------- np.ndarray Updated state vector after one integration step. """ dt = P.dt nn = P.nn dW = P.sigma * np.sqrt(dt) * np.random.randn(2 * nn) k1 = f_ww(S, t, P) y_ = S + dt * k1 + dW k2 = f_ww(y_, t + dt, P) S = S + 0.5 * dt * (k1 + k2) + dW return S
# ----------------------------- # Public-facing class (mirror mpr.py style) # -----------------------------
[docs] class WW_sde: """ Numba implementation of the Wong-Wang full neural mass model with stochastic dynamics. The Wong-Wang full model is a biophysically realistic neural mass model that explicitly captures the dynamics of both excitatory and inhibitory neural populations. This model is based on the original work of Wong and Wang [Wong2006]_ and has been extended for whole-brain network simulations [Deco2013]_, [Deco2014]_. The model provides a detailed representation of recurrent network mechanisms underlying decision-making processes and has been widely used to study brain dynamics in health and disease. The model describes the temporal evolution of synaptic gating variables for excitatory (S_exc) and inhibitory (S_inh) populations at each brain region. The dynamics are governed by the balance between synaptic decay, activity-dependent facilitation, and network coupling. The firing rates of each population are determined by input-output transfer functions that capture the relationship between synaptic currents and population firing rates. The Wong-Wang full model equations are: .. math:: \\frac{dS_{exc,i}}{dt} &= -\\frac{S_{exc,i}}{\\tau_{exc}} + (1 - S_{exc,i}) \\gamma_{exc} r_{exc,i}(t) + \\sigma \\xi_i(t) \\\\ \\frac{dS_{inh,i}}{dt} &= -\\frac{S_{inh,i}}{\\tau_{inh}} + \\gamma_{inh} r_{inh,i}(t) + \\sigma \\xi_i(t) where the firing rates are computed using: .. math:: r_{exc,i}(t) &= \\frac{a_{exc} I_{exc,i} - b_{exc}}{1 - \\exp(-d_{exc}(a_{exc} I_{exc,i} - b_{exc}))} \\\\ r_{inh,i}(t) &= \\frac{a_{inh} I_{inh,i} - b_{inh}}{1 - \\exp(-d_{inh}(a_{inh} I_{inh,i} - b_{inh}))} The total synaptic currents for each population are: .. math:: I_{exc,i} &= W_{exc} I_{ext} + w_{plus} J_{NMDA} S_{exc,i} + G_{exc} J_{NMDA} \\sum_{j=1}^{N} SC_{ij} S_{exc,j} - J_I S_{inh,i} \\\\ I_{inh,i} &= W_{inh} I_{ext} + J_{NMDA} S_{exc,i} - S_{inh,i} + G_{inh} J_{NMDA} \\lambda_{inh,exc} \\sum_{j=1}^{N} SC_{ij} S_{inh,j} .. list-table:: Parameters :widths: 25 50 25 :header-rows: 1 * - Name - Explanation - Default Value * - `a_exc` - Excitatory population gain parameter (n/C) - 310.0 * - `a_inh` - Inhibitory population gain parameter (nC⁻¹) - 0.615 * - `b_exc` - Excitatory population threshold parameter (Hz) - 125.0 * - `b_inh` - Inhibitory population threshold parameter (Hz) - 177.0 * - `d_exc` - Excitatory population saturation parameter (s) - 0.16 * - `d_inh` - Inhibitory population saturation parameter (s) - 0.087 * - `tau_exc` - Excitatory synaptic time constant (ms) - 100.0 * - `tau_inh` - Inhibitory synaptic time constant (ms) - 10.0 * - `gamma_exc` - Excitatory kinetic parameter (ms⁻¹) - 0.641/1000 * - `gamma_inh` - Inhibitory kinetic parameter (ms⁻¹) - 1.0/1000 * - `W_exc` - Excitatory population local weight - 1.0 * - `W_inh` - Inhibitory population local weight - 0.7 * - `ext_current` - External current input (nA) - 0.382 * - `J_NMDA` - NMDA synaptic coupling strength (nA) - 0.15 * - `J_I` - Inhibitory synaptic coupling strength (nA) - 1.0 * - `w_plus` - Local excitatory recurrence strength - 1.4 * - `lambda_inh_exc` - Long-range feedforward inhibition switch - 0.0 * - `G_exc` - Global excitatory coupling strength - 0.0 * - `G_inh` - Global inhibitory coupling strength - 0.0 * - `sigma` - Noise amplitude - 0.0 * - `weights` - Structural connectivity matrix (nn x nn) - zeros * - `dt` - Integration time step (ms) - 0.1 * - `t_end` - Simulation end time (ms) - 1000.0 * - `t_cut` - Initial time to discard (ms) - 0.0 Usage example: >>> import numpy as np >>> from vbi.models.numba.ww import WW_sde >>> W = np.eye(2) * 0.1 # 2-node connectivity >>> I_ext = np.array([0.4, 0.5]) # External currents >>> ww = WW_sde({ ... "weights": W, ... "ext_current": I_ext, ... "G_exc": 0.5, ... "G_inh": 0.2, ... "dt": 0.1, ... "t_end": 1000.0, ... "t_cut": 100.0, ... "sigma": 0.01 ... }) >>> data = ww.run() >>> print(data['bold_d'].shape) # BOLD data shape (2, 2) >>> print(ww.P.tr) # Repetition time 300.0 Notes ----- The Wong-Wang model is particularly useful for studying: - Decision-making processes and attractor dynamics - Excitatory-inhibitory balance in cortical circuits - Biophysically realistic neural mass dynamics - Working memory mechanisms - Brain state transitions and criticality The model's strength lies in its biophysical realism while maintaining computational efficiency through mean-field approximations. References ---------- .. [Wong2006] Wong, K. F., & Wang, X. J. (2006). A recurrent network mechanism of time integration in perceptual decisions. Journal of Neuroscience, 26(4), 1314-1328. .. [Deco2013] Deco, G., Ponce-Alvarez, A., Mantini, D., Romani, G. L., Hagmann, P., & Corbetta, M. (2013). Resting-state functional connectivity emerges from structurally and dynamically shaped slow linear fluctuations. Journal of Neuroscience, 33(27), 11239-11252. .. [Deco2014] Deco, G., Ponce-Alvarez, A., Hagmann, P., Romani, G. L., Mantini, D., & Corbetta, M. (2014). How local excitation–inhibition ratio impacts the whole brain dynamics. Journal of Neuroscience, 34(23), 7886-7898. """
[docs] def __init__(self, par: dict = None, Bpar: dict = None) -> None: if par is None: par = {} if Bpar is None: Bpar = {} # sanity & defaults nn = par.get("nn", None) weights = par.get("weights", None) if weights is None: # default single node weights = np.zeros((1, 1), dtype=np.float64) weights = np.array(weights, dtype=np.float64) if nn is None: nn = weights.shape[0] par.setdefault("nn", nn) # broadcast scalars to vectors where necessary par.setdefault("ext_current", 0.382) par["ext_current"] = check_vec_size_1d(par["ext_current"], nn) # dt-based noise scalars are computed inside heun_sde (uses sigma directly) # initial state if "initial_state" in par: par["initial_state"] = np.array(par["initial_state"], dtype=np.float64) else: par["initial_state"] = set_initial_state(nn, par.get("seed", -1)) par.setdefault("dtype", "f") # kept for compatibility par.setdefault("output", "output") # create numba jitclass param holders self.P = ParWW( a_exc=par.get("a_exc", 310.0), a_inh=par.get("a_inh", 0.615), b_exc=par.get("b_exc", 125.0), b_inh=par.get("b_inh", 177.0), d_exc=par.get("d_exc", 0.16), d_inh=par.get("d_inh", 0.087), tau_exc=par.get("tau_exc", 100.0), tau_inh=par.get("tau_inh", 10.0), gamma_exc=par.get("gamma_exc", 0.641 / 1000.0), gamma_inh=par.get("gamma_inh", 1.0 / 1000.0), W_exc=par.get("W_exc", 1.0), W_inh=par.get("W_inh", 0.7), ext_current=par["ext_current"].astype(np.float64), J_NMDA=par.get("J_NMDA", 0.15), J_I=par.get("J_I", 1.0), w_plus=par.get("w_plus", 1.4), lambda_inh_exc=par.get("lambda_inh_exc", 0.0), t_end=par.get("t_end", 1000.0), t_cut=par.get("t_cut", 0.0), dt=par.get("dt", 0.1), G_exc=par.get("G_exc", 0.0), G_inh=par.get("G_inh", 0.0), weights=weights, tr=par.get("tr", 300.0), s_decimate=int(par.get("s_decimate", 1)), sigma=par.get("sigma", 0.0), nn=nn, seed=int(par.get("seed", -1)), output=par.get("output", "output"), dtype=par.get("dtype", "f"), initial_state=par["initial_state"], RECORD_S=bool(par.get("RECORD_S", False)), RECORD_BOLD=bool(par.get("RECORD_BOLD", True)), ) # Bold parameters self.B = ParBold( kappa=Bpar.get("kappa", 0.65), gamma=Bpar.get("gamma", 0.41), tau=Bpar.get("tau", 0.98), alpha=Bpar.get("alpha", 0.32), epsilon=Bpar.get("epsilon", 0.34), Eo=Bpar.get("Eo", 0.4), TE=Bpar.get("TE", 0.04), vo=Bpar.get("vo", 0.08), r0=Bpar.get("r0", 25.0), theta0=Bpar.get("theta0", 40.3), t_min=Bpar.get("t_min", 0.0), rtol=Bpar.get("rtol", 1e-5), atol=Bpar.get("atol", 1e-8), ) # seeding if self.P.seed >= 0: # set_seed_compat(self.P.seed) initialize_random_state(self.P.seed)
# print(f"WW_sde: setting random seed to {self.P.seed}") def __str__(self) -> str: lines = [ "Wong-Wang (Numba) model", "Parameters: --------------------------------", ] for name in [ "nn", "dt", "t_end", "t_cut", "G_exc", "G_inh", "sigma", "tr", "a_exc", "b_exc", "d_exc", "tau_exc", "gamma_exc", "a_inh", "b_inh", "d_inh", "tau_inh", "gamma_inh", "W_exc", "W_inh", "w_plus", "J_NMDA", "J_I", ]: lines.append(f"{name} = {getattr(self.P, name)}") lines.append("--------------------------------------------") return "\n".join(lines)
[docs] def check_parameters(self, par: dict) -> None: """ Validate that all provided parameters are recognized. Parameters ---------- par : dict Dictionary of parameters to validate. Raises ------ ValueError If any parameter name is not recognized. """ for key in par.keys(): if key not in self.valid_par: print(f"Invalid parameter: {key}") print_valid_parameters(ww_spec) raise ValueError(f"Invalid parameter: {key}")
# ----------------------------- # Simulation # -----------------------------
[docs] def run(self, par: dict = None, x0=None, verbose=True): """ Run the Wong-Wang full model simulation. Executes the Euler-Maruyama numerical integration scheme to simulate the stochastic Wong-Wang dynamics. Includes optional BOLD signal generation and data downsampling. Parameters ---------- par : dict, optional Dictionary of parameters to update before simulation. Can include any parameter from the ParWW class. x0 : array-like, optional Initial state vector of length 2*nn (S_exc and S_inh for each region). If None, uses the stored initial_state from the parameter object. verbose : bool, default True Whether to print simulation progress and parameter information. Returns ------- dict Dictionary containing simulation results with keys: - 'S' : ndarray, shape (T, nn) Recorded excitatory synaptic gating variables if RECORD_S=True - 't' : ndarray, shape (T,) Time points for neural data (ms) - 'bold_t' : ndarray, shape (T_bold,) Time points for BOLD signal (ms) - 'bold_d' : ndarray, shape (T_bold, nn) BOLD signal time series if RECORD_BOLD=True Notes ----- The simulation uses the Heun-Euler method for stochastic integration, which provides second-order accuracy for the deterministic part while handling the stochastic noise appropriately. """ self.valid_par = [ww_spec[i][0] for i in range(len(ww_spec))] self.check_parameters(par) # update runtime parameters if provided if par: for key, val in par.items(): if key == "ext_current": val = check_vec_size_1d(val, self.P.nn).astype(np.float64) setattr(self.P, key, val) # initial state if x0 is None: S = copy(self.P.initial_state) else: S = np.array(x0, dtype=np.float64) # sanity assert self.P.weights is not None assert self.P.weights.shape[0] == self.P.weights.shape[1] assert len(S) == 2 * self.P.nn, "x0 must be length 2*nn" assert self.P.t_cut < self.P.t_end # time grid nt = int(np.floor(self.P.t_end / self.P.dt)) t = np.arange(nt) * self.P.dt valid_mask = t > self.P.t_cut s_buffer_len = int(np.sum(valid_mask) // max(1, self.P.s_decimate)) # buffers t_buf = np.zeros((s_buffer_len,), dtype=np.float32) S_rec = np.zeros((s_buffer_len, self.P.nn), dtype=np.float32) if self.P.RECORD_S else np.array([]) # BOLD buffers tr = self.P.tr bold_decimate = int(np.round(tr / self.P.dt)) dtt = self.P.dt / 1000.0 # seconds s = np.zeros((2, self.P.nn)) f = np.zeros((2, self.P.nn)) ftilde = np.zeros((2, self.P.nn)) vtilde = np.zeros((2, self.P.nn)) qtilde = np.zeros((2, self.P.nn)) v = np.zeros((2, self.P.nn)) q = np.zeros((2, self.P.nn)) vv = np.zeros((nt // max(1, bold_decimate), self.P.nn), dtype=np.float64) qq = np.zeros_like(vv) # init BOLD states s[0] = 1.0 f[0] = 1.0 v[0] = 1.0 q[0] = 1.0 ftilde[0] = 0.0 vtilde[0] = 0.0 qtilde[0] = 0.0 # main loop s_idx = 0 for i in range(nt): t_curr = i * self.P.dt S = heun_sde(S, t_curr, self.P) if (t_curr > self.P.t_cut) and (i % max(1, self.P.s_decimate) == 0): if s_idx < s_buffer_len: t_buf[s_idx] = t_curr if self.P.RECORD_S: S_rec[s_idx] = S[: self.P.nn].astype(np.float32) s_idx += 1 if self.P.RECORD_BOLD: do_bold_step(S[: self.P.nn], s, f, ftilde, vtilde, qtilde, v, q, dtt, self.B) if (i % max(1, bold_decimate) == 0) and ((i // max(1, bold_decimate)) < vv.shape[0]): vv[i // max(1, bold_decimate)] = v[1] qq[i // max(1, bold_decimate)] = q[1] # finalize BOLD bold_t = np.linspace(0.0, self.P.t_end - self.P.dt * max(1, bold_decimate), vv.shape[0]) if self.P.RECORD_BOLD: # cut off t <= t_cut valid = bold_t > self.P.t_cut bold_t = bold_t[valid] if bold_t.size > 0: vv = vv[valid] qq = qq[valid] k1 = 4.3 * self.B.theta0 * self.B.Eo * self.B.TE k2 = self.B.epsilon * self.B.r0 * self.B.Eo * self.B.TE k3 = 1.0 - self.B.epsilon bold_d = self.B.vo * (k1 * (1.0 - qq) + k2 * (1.0 - qq / vv) + k3 * (1.0 - vv)) else: bold_d = np.array([]) else: bold_d = np.array([]) bold_t = np.array([]) return { "S": S_rec, "t": t_buf, "bold_t": bold_t.astype(np.float32), "bold_d": bold_d.astype(np.float32), }
# ----------------------------- # API helpers # -----------------------------
[docs] def set_initial_state(nn, seed=-1): """ Generate random initial conditions for the Wong-Wang model. Creates small positive random values for both excitatory and inhibitory synaptic gating variables, ensuring the system starts in a biologically plausible state. Parameters ---------- nn : int Number of brain regions/nodes. seed : int, optional Random seed for reproducibility. If -1 or None, no seeding is applied. Returns ------- np.ndarray Initial state vector of shape (2*nn,) containing small positive random values for [S_exc_0, ..., S_exc_n, S_inh_0, ..., S_inh_n]. """ if seed is not None and seed >= 0: np.random.seed(seed) # initialize_random_state(seed) y0 = np.random.rand(2 * nn) * 0.1 # small positive return y0.astype(np.float64)