import tqdm
import numpy as np
from vbi.models.cupy.utils import *
[docs]
class JR_sde:
"""
Jansen-Rit model cupy implementation.
.. list-table:: Parameters
:widths: 25 50 25
:header-rows: 1
* - Name
- Explanation
- Default Value
* - `A`
- Excitatory post synaptic potential amplitude.
- 3.25
* - `B`
- Inhibitory post synaptic potential amplitude.
- 22.0
* - `1/a`
- Time constant of the excitatory postsynaptic potential.
- a: 0.1 (1/a: 10.0)
* - `1/b`
- Time constant of the inhibitory postsynaptic potential.
- b: 0.05 (1/b: 20.0)
* - `C0`, `C1`
- Average numbers of synapses between excitatory populations. If array-like, it should be the shape of (`num_nodes`, `num_sim`).
- C0: 1.0 * 135.0, C1: 0.8 * 135.0
* - `C2`, `C3`
- Average numbers of synapses between inhibitory populations. If array-like, it should be the shape of (`num_nodes`, `num_sim`).
- C2: 0.25 * 135.0, C3: 0.25 * 135.0
* - `vmax`
- Maximum firing rate
- 0.005
* - `v0`
- Potential at half of maximum firing rate. The defaults also define `v` with the same value, but the sigmoid reads `v0`.
- 6.0
* - `r`
- Slope of sigmoid function at `v0`
- 0.56
* - `G`
- Scaling the strength of network connections. If array-like, it should be of length `num_sim`.
- 1.0
* - `mu`
- Mean of the noise
- 0.24
* - `noise_amp`
- Amplitude of the noise
- 0.01
* - `weights`
- Weight matrix of shape (`num_nodes`, `num_nodes`)
- None
* - `num_sim`
- Number of simulations
- 1
* - `dt`
- Time step
- 0.01
* - `t_end`
- End time of simulation
- 1000.0
* - `t_cut`
- Cut time
- 500.0
* - `engine`
- "cpu" or "gpu"
- "cpu"
* - `method`
- "heun" or "euler" method for integration
- "heun"
* - `seed`
- Random seed
- None
* - `initial_state`
- Initial state of the system of shape (`6 * num_nodes`, `num_sim`)
- None
* - `same_initial_state`
- If True, all simulations have the same initial state
- False
* - `same_noise_per_sim`
- If True, all simulations have the same noise
- False
* - `decimate`
- Decimation factor for the output time series
- 1
* - `dtype`
- Data type to use for the simulation, `float` for `float64` or `f` for `float32`.
- "float"
"""
[docs]
def __init__(self, par: dict = None):
    """
    Build a model from the default parameters overridden by `par`.

    Every resulting parameter is also exposed as an instance attribute of
    the same name.

    Parameters
    ----------
    par : dict, optional
        Parameter overrides. Keys must be among those returned by
        `get_default_parameters`. Omitting it (or passing None) keeps all
        defaults.

    Raises
    ------
    ValueError
        If `par` contains a key that is not a known parameter
        (raised by `check_parameters`).
    """
    # A `None` sentinel avoids sharing one mutable `{}` between all calls.
    par = {} if par is None else par

    self.valid_parameters = list(self.get_default_parameters().keys())
    self.check_parameters(par)

    self._par = self.get_default_parameters()
    self._par.update(par)
    for name, value in self._par.items():
        setattr(self, name, value)

    # NOTE(review): `get_module` comes from vbi.models.cupy.utils; presumably
    # it maps the engine name to the numpy or cupy module - confirm.
    self.xp = get_module(self.engine)
    if self.seed is not None:
        self.xp.random.seed(self.seed)
def __str__(self) -> str:
    """
    Return a human-readable listing of the model parameters.

    The text starts with a two-line banner followed by one ``name = value``
    line per entry of `self._par`, in dictionary order. Nothing is printed;
    the caller decides what to do with the string.

    Returns
    -------
    str
        The formatted, newline-separated description.
    """
    lines = ["Jansen-Rit Model", "----------------"]
    lines.extend(f"{name} = {value}" for name, value in self._par.items())
    return "\n".join(lines)
def __call__(self):
    """
    Print the model name and return the parameter dictionary.

    Returns
    -------
    dict
        The instance's own `_par` dictionary (not a copy).
    """
    banner = "Jansen-Rit Model"
    print(banner)
    parameters = self._par
    return parameters
[docs]
def check_parameters(self, par):
    """
    Validate the keys of a user-supplied parameter dictionary.

    Parameters
    ----------
    par : dict
        Parameters to validate; only the keys are inspected.

    Raises
    ------
    ValueError
        On the first key that is not listed in `self.valid_parameters`.
    """
    allowed = self.valid_parameters
    for name in par:
        if name in allowed:
            continue
        raise ValueError("Invalid parameter: " + name)
[docs]
def set_initial_state(self):
    """
    Draw a random initial state and store it in `self.initial_state`.

    Delegates to the module-level `set_initial_state` function, forwarding
    the instance's node count, number of simulations, engine, seed,
    `same_initial_state` flag and dtype.
    """
    # NOTE(review): `self.nn` is not assigned anywhere in the code visible
    # here; presumably it is set before this method is called - confirm.
    state = set_initial_state(
        self.nn,
        self.num_sim,
        self.engine,
        seed=self.seed,
        same_initial_state=self.same_initial_state,
        dtype=self.dtype,
    )
    self.initial_state = state
[docs]
def get_default_parameters(self) -> dict:
    """
    Return the default parameters of the Jansen-Rit model.

    The method takes no arguments besides `self` and builds a new
    dictionary on every call, so the caller may modify the result freely.

    Returns
    -------
    params : dict
        Mapping from parameter name to its default value.
    """
    n_syn = 135.0  # common factor of the four connectivity constants
    return dict(
        G=1.0,
        A=3.25,
        B=22.0,
        v=6.0,
        r=0.56,
        v0=6.0,
        vmax=0.005,
        C0=1.0 * n_syn,
        C1=0.8 * n_syn,
        C2=0.25 * n_syn,
        C3=0.25 * n_syn,
        a=0.1,
        b=0.05,
        mu=0.24,
        noise_amp=0.01,
        decimate=1,
        dt=0.01,
        t_end=1000.0,
        t_cut=500.0,
        engine="cpu",
        method="heun",
        num_sim=1,
        weights=None,
        dtype="float",
        seed=None,
        initial_state=None,
        same_initial_state=False,
        same_noise_per_sim=False,
    )
[docs]
def S_(self, x):
    """
    Evaluate the sigmoid ``vmax / (1 + exp(r * (v0 - x)))``.

    The exponential is taken from the array module stored in `self.xp`.

    Parameters
    ----------
    x : float or array-like
        Input value(s).

    Returns
    -------
    float or array-like
        Sigmoid of `x`, using the instance's `vmax`, `r` and `v0`.
    """
    exponent = self.r * (self.v0 - x)
    denominator = 1.0 + self.xp.exp(exponent)
    return self.vmax / denominator
[docs]
def f_sys(self, x0, t):
    """
    Evaluate the deterministic right-hand side of the Jansen-Rit equations.

    The state is read as six consecutive bands of `nn` rows each: three
    state variables (called x, y, z below) followed by their first
    derivatives.

    Parameters
    ----------
    x0 : array_like
        Current state, sliced as (6 * nn, ns) with nn = `self.nn` and
        ns = `self.num_sim`.
    t : float
        Current time. Unused; kept for a solver-style signature.

    Returns
    -------
    dx : array_like
        Array of shape (6 * nn, ns) holding the time derivative of each
        row of `x0`.
    """
    nn, ns = self.nn, self.num_sim
    a, b = self.a, self.b
    gain_e = self.A * a
    gain_i = self.B * b
    a_sq = a * a
    b_sq = b * b
    sigm = self.S_

    # One slice per band, in storage order: x, y, z, x', y', z'.
    band = [slice(k * nn, (k + 1) * nn) for k in range(6)]
    x = x0[band[0], :]
    y = x0[band[1], :]
    z = x0[band[2], :]
    x_dot = x0[band[3], :]
    y_dot = x0[band[4], :]
    z_dot = x0[band[5], :]

    # Network input: sigmoid of the weighted sum of (y - z) over nodes.
    coupling = sigm(self.weights.dot(y - z))

    # NOTE(review): allocated with the module's default float dtype, so the
    # result does not follow the dtype of `x0`.
    dx = self.xp.zeros((6 * nn, ns))
    dx[band[0], :] = x_dot
    dx[band[1], :] = y_dot
    dx[band[2], :] = z_dot
    dx[band[3], :] = gain_e * sigm(y - z) - 2 * a * x_dot - a_sq * x
    dx[band[4], :] = (
        gain_e * (self.mu + self.C1 * sigm(self.C0 * x) + self.G * coupling)
        - 2 * a * y_dot
        - a_sq * y
    )
    dx[band[5], :] = (
        gain_i * self.C3 * sigm(self.C2 * x) - 2 * b * z_dot - b_sq * z
    )
    return dx
[docs]
def euler(self, x0, t):
    """
    Advance the state by one Euler-Maruyama step.

    Gaussian noise scaled by ``sqrt(dt) * noise_amp`` is added to the fifth
    band of the state (rows ``4 * nn`` to ``5 * nn``) only. With
    `same_noise_per_sim` a single noise column is drawn and broadcast over
    all simulations.

    Parameters
    ----------
    x0 : array_like
        Current state of the system.
    t : float
        Current time, forwarded to `f_sys`.

    Returns
    -------
    array_like
        New state after one step; `x0` itself is not modified.
    """
    nn = self.nn
    dt = self.dt
    n_cols = 1 if self.same_noise_per_sim else self.num_sim

    noise = self.xp.random.randn(nn, n_cols)
    noise = np.sqrt(dt) * self.noise_amp * noise

    stepped = x0 + dt * self.f_sys(x0, t)
    stepped[4 * nn : 5 * nn, :] += noise
    return stepped
[docs]
def heun(self, x0, t):
    """
    Advance the state by one stochastic Heun (predictor-corrector) step.

    The same noise increment, scaled by ``sqrt(dt) * noise_amp``, is added
    to the fifth band of the state (rows ``4 * nn`` to ``5 * nn``) in both
    the predictor and the corrector. With `same_noise_per_sim` a single
    noise column is drawn and shared by all simulations.

    Parameters
    ----------
    x0 : ndarray
        Current state of the system.
    t : float
        Current time.

    Returns
    -------
    ndarray
        New state after one step; `x0` itself is not modified.
    """
    nn = self.nn
    dt = self.dt
    noisy_rows = slice(4 * nn, 5 * nn)

    # An (nn, 1) draw broadcasts over the simulation axis in the in-place
    # additions below, so no numpy-specific tiling is needed and the array
    # stays in whichever module `self.xp` is.
    n_cols = 1 if self.same_noise_per_sim else self.num_sim
    dW = self.xp.random.randn(nn, n_cols)
    dW = np.sqrt(dt) * self.noise_amp * dW

    # Predictor: plain Euler step plus noise.
    k1 = self.f_sys(x0, t) * dt
    x1 = x0 + k1
    x1[noisy_rows, :] += dW

    # Corrector: average the drift at both ends of the step.
    k2 = self.f_sys(x1, t + dt) * dt
    x_new = x0 + (k1 + k2) / 2.0
    x_new[noisy_rows, :] += dW
    return x_new
[docs]
def run(self, x0=None):
    """
    Simulate the Jansen-Rit model.

    Integrates over ``arange(0, t_end, dt)`` with the configured method and
    records, from `t_cut` on, every `decimate`-th step counted from the
    first recorded one.

    Parameters
    ----------
    x0 : array [6 * num_nodes, num_sim], optional
        Initial state. When omitted, `self.initial_state` is used.

    Returns
    -------
    dict
        t : array [n_step]
            Time points of the recorded samples.
        x : array [n_step, num_nodes, num_sim]
            Difference between the second and third state bands
            (``y - z`` in `f_sys`), stored on the host as float32.

    Raises
    ------
    ValueError
        If no time point of the simulation satisfies ``t >= t_cut``.
    """
    # NOTE(review): `prepare_input` is defined outside the code visible
    # here; presumably it sets `self.nn`, the weights and the initial
    # state - confirm.
    self.prepare_input()
    x = self.initial_state if x0 is None else x0
    self.integrator = self.euler if self.method == "euler" else self.heun

    _xp = self.xp
    nn = self.nn
    ns = self.num_sim
    decimate = self.decimate
    t_cut = self.t_cut

    tspan = _xp.arange(0, self.t_end, self.dt)
    n_total = len(tspan)

    kept = _xp.where(tspan >= t_cut)[0]
    if kept.size == 0:
        raise ValueError(
            f"t_cut ({t_cut}) leaves nothing to record: "
            f"it must be smaller than t_end ({self.t_end})"
        )
    i_cut = int(kept[0])

    # Recorded steps are i_cut, i_cut + decimate, ...; sizing the buffer and
    # the time axis from that same sequence keeps "t" and "x" aligned and
    # rules out writing past the end of the buffer.
    n_step = len(range(i_cut, n_total, decimate))
    y = np.zeros((n_step, nn, ns), dtype="f")  # store in host

    ii = 0
    for i in tqdm.trange(n_total):
        x = self.integrator(x, tspan[i])
        if i >= i_cut and (i - i_cut) % decimate == 0:
            # Copy to the host only for the steps that are actually kept.
            x_ = get_(x, self.engine, "f")
            y[ii, :, :] = x_[nn : 2 * nn, :] - x_[2 * nn : 3 * nn, :]
            ii += 1

    t = get_(tspan[i_cut::decimate], self.engine, "f")
    return {"t": t, "x": y}
[docs]
def set_initial_state(
    nn, ns, engine, seed=None, same_initial_state=False, dtype="float"
):
    """
    Draw a random initial state for the Jansen-Rit model.

    Six bands of `nn` rows are sampled uniformly from symmetric intervals
    (half-widths 1, 500, 50, 6, 20 and 500, in that order) and stacked
    vertically.

    Parameters
    ----------
    nn : int
        Number of nodes.
    ns : int
        Number of simulations.
    engine : str
        cpu or gpu; forwarded to `move_data`.
    seed : int, optional
        If not None, numpy's global random generator is re-seeded with it.
    same_initial_state : bool
        If True, a single column is drawn and repeated for all simulations.
    dtype : str
        Data type of the returned array.

    Returns
    -------
    y : array [6 * nn, ns]
        Initial state, cast to `dtype`.
    """
    if seed is not None:
        np.random.seed(seed)

    # Half-width of the sampling interval of each band, in stacking order.
    half_widths = (1, 500, 50, 6, 20, 500)
    n_cols = 1 if same_initial_state else ns
    bands = [np.random.uniform(-w, w, (nn, n_cols)) for w in half_widths]

    state = np.vstack(bands)
    if same_initial_state:
        # Repeat the single column so every simulation starts identically.
        state = np.tile(state, (1, ns))

    # NOTE(review): `move_data` comes from vbi.models.cupy.utils; presumably
    # it places the array on the device selected by `engine` - confirm.
    state = move_data(state, engine)
    return state.astype(dtype)