import numpy as np
from typing import Union
from copy import deepcopy
try:
from vbi.models.cpp._src.mpr_sde import MPR_sde as _MPR_sde
from vbi.models.cpp._src.mpr_sde import BoldParams as _BoldParams
except ImportError as e:
print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
[docs]
class MPR_sde:
    """
    MPR model.

    Python front end that collects the model parameters, validates them
    against the defaults table and hands them to the compiled ``_MPR_sde``
    integrator. Every parameter is also exposed as an instance attribute.
    """

    def __init__(
        self,
        par: Union[dict, None] = None,
        parbold: Union[dict, None] = None,
    ) -> None:
        """
        Build the model from user overrides on top of the defaults.

        Parameters
        ----------
        par : dict, optional
            Overrides for the entries of ``get_default_parameters()``.
        parbold : dict, optional
            Overrides forwarded to ``BoldParams``.

        Raises
        ------
        ValueError
            If ``par`` contains a key that is not a known parameter.
        """
        # Deep copy so arrays inside the caller's dict are never shared
        # with (or modified through) this instance.
        par = {} if par is None else deepcopy(par)

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        # Always define the flag: previously it was only created when no
        # initial state was given, so run() crashed with AttributeError
        # whenever the user supplied ``initial_state`` up front.
        self.INITIAL_STATE_SET = self.initial_state is not None

        self.BP = BoldParams({} if parbold is None else parbold)

    def _node_count(self, weights=None) -> int:
        """Return the number of rows of ``weights`` (default: ``self.weights``)."""
        if weights is None:
            weights = self.weights
        if weights is None:
            raise ValueError(
                "weights must be provided before the number of nodes can be determined."
            )
        # asarray lets nested lists work as well as ndarrays.
        return np.asarray(weights).shape[0]

    def set_initial_state(self):
        """Draw a random initial state sized from the connectivity matrix."""
        self.num_nodes = self._node_count()
        # Resolves to the module-level helper of the same name.
        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True

    # -------------------------------------------------------------------------
    def __str__(self) -> str:
        """Return a human-readable listing of the stored parameters."""
        lines = ["MPR sde model.", "----------------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    # -------------------------------------------------------------------------
    def __call__(self):
        """Return the parameter dictionary assembled at construction time."""
        return self._par

    # -------------------------------------------------------------------------
    def check_parameters(self, par: dict):
        """
        Validate that every key of ``par`` is a known parameter name.

        Raises
        ------
        ValueError
            On the first unknown key.
        """
        for key in par:
            if key not in self.valid_parameters:
                # ``!s`` (not ``:s``) so non-string keys still produce this
                # message instead of a format-spec error.
                raise ValueError(f"Invalid parameter {key!s} provided.")

    def get_default_parameters(self):
        """Return a fresh dictionary holding the default value of every parameter."""
        params = {
            "G": 0.733,  # global coupling strength
            "dt": 0.01,  # for mpr model [ms]
            "dt_bold": 0.001,  # for Balloon model [s]
            "J": 14.5,  # model parameter
            "eta": -4.6,  # model parameter
            "tau": 1.0,  # model parameter
            "delta": 0.7,  # model parameter
            "tr": 500.0,  # sampling from mpr time series
            "rv_decimate": 10,  # sampling from activity time series
            "noise_amp": 0.037,  # amplitude of noise
            "noise_seed": 0,  # fix seed for noise
            "iapp": 0.0,  # constant applied current
            "seed": None,
            "initial_state": None,  # initial condition of the system
            "t_cut": 0.0,  # transition time [ms]
            "t_end": 5 * 60 * 1000.0,  # end time [ms]
            "weights": None,  # weighted connection matrix
            "output": "output",  # output directory
            "RECORD_RV": 0,  # true to store large time series in file
            "RECORD_BOLD": 1,
        }
        return params

    @staticmethod
    def _drop_transient(values, times, t_cut):
        """Cast a recorded series to float32 and keep samples with ``times > t_cut``."""
        data = np.asarray(values).astype(np.float32)
        stamps = np.asarray(times)
        if data.ndim == 2:
            keep = stamps > t_cut
            data = data[keep, :]
            # NOTE(review): the kept time stamps are rescaled by 10.0,
            # presumably a unit conversion -- confirm against the C++ side.
            stamps = stamps[keep] * 10.0
        return data, stamps

    def run(
        self,
        par: Union[dict, None] = None,
        x0: np.ndarray = None,
        verbose: bool = False,
    ):
        """
        Integrate the MPR model with the given parameters.

        Parameters
        ----------
        par : dict, optional
            Parameters to override for this and subsequent runs; each entry
            is stored as an attribute on the instance.
        x0 : array_like, optional
            Initial condition of the system (length ``2 * num_nodes``).
            When omitted, the stored initial state is used, or a random one
            is drawn if none has been set yet.
        verbose : bool
            If True, report when a default initial state is generated.

        Returns
        -------
        dict
            ``{"rv_t", "rv_d", "bold_t", "bold_d"}``. The ``rv_*`` entries
            are filled only when ``RECORD_RV`` is truthy and the ``bold_*``
            entries only when ``RECORD_BOLD`` is truthy; otherwise they are
            empty arrays.

        Raises
        ------
        ValueError
            If ``par`` contains an unknown parameter name. Validation runs
            before anything is modified, so a bad key leaves the instance
            untouched.
        """
        par = {} if par is None else par
        self.check_parameters(par)

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            # num_nodes is otherwise only created by set_initial_state(), so
            # derive it here when the caller supplies x0 on the first run.
            if not hasattr(self, "num_nodes"):
                self.num_nodes = self._node_count(par.get("weights"))
            assert len(x0) == self.num_nodes * 2
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        for key, value in par.items():
            setattr(self, key, value)

        # NOTE(review): prepare_input() is not defined in the visible part of
        # this class -- confirm it is provided elsewhere before relying on run().
        self.prepare_input()

        obj = _MPR_sde(
            self.dt,
            self.dt_bold,
            self.rv_decimate,
            self.weights,
            self.initial_state,
            self.delta,
            self.tau,
            self.eta,
            self.J,
            self.iapp,
            self.noise_amp,
            self.G,
            self.t_end,
            self.t_cut,
            self.tr,
            self.RECORD_RV,
            self.RECORD_BOLD,
            self.noise_seed,
            self.BP.get_params(),
        )
        obj.integrate()

        bold_d = np.array([])
        bold_t = np.array([])
        r_d = np.array([])
        r_t = np.array([])

        if self.RECORD_BOLD:
            bold_d, bold_t = self._drop_transient(
                obj.get_bold_d(), obj.get_bold_t(), self.t_cut
            )

        if self.RECORD_RV:
            r_d, r_t = self._drop_transient(
                obj.get_r_d(), obj.get_r_t(), self.t_cut
            )

        return {
            "rv_t": r_t,
            "rv_d": r_d,
            "bold_t": bold_t,
            "bold_d": bold_d,
        }
[docs]
class BoldParams:
    """
    Container for the BOLD parameters passed to the compiled backend.

    Values start from ``get_default_parameters()``, are overridden by the
    user dictionary, and are exposed both through ``_par`` and as instance
    attributes.
    """

    def __init__(self, par: Union[dict, None] = None):
        """
        Parameters
        ----------
        par : dict, optional
            Overrides for the default values; ``None`` means no overrides.

        Raises
        ------
        ValueError
            If ``par`` contains a key that is not a known parameter.
        """
        # A fresh dict per call avoids sharing a mutable default argument.
        par = {} if par is None else par

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

    def check_parameters(self, par):
        """Raise ``ValueError`` on the first key of ``par`` that is not a known parameter."""
        for key in par:
            if key not in self.valid_parameters:
                # ``!s`` (not ``:s``) so non-string keys still produce this
                # message instead of a format-spec error.
                raise ValueError(f"Invalid parameter {key!s} provided.")

    def get_default_parameters(self):
        """Return a fresh dictionary with the default value of every parameter."""
        return {
            "kappa": 0.7,
            "gamma": 0.5,
            "tau": 1.0,
            "alpha": 0.35,
            "epsilon": 0.36,
            "Eo": 0.42,
            "TE": 0.05,
            "vo": 0.09,
            "r0": 26.0,
            "theta0": 41.0,
            "rtol": 1e-6,
            "atol": 1e-9,
        }

    def get_params(self):
        """Return a backend ``_BoldParams`` object populated from this instance."""
        bp = _BoldParams()
        # valid_parameters lists exactly the fields of the defaults table, so
        # copying by name stays in sync with get_default_parameters().
        for name in self.valid_parameters:
            setattr(bp, name, getattr(self, name))
        return bp
[docs]
def check_sequence(x: Union[int, float, np.ndarray], n: int):
    """
    Check if x is a scalar or a sequence of length n.

    Parameters
    ----------
    x : scalar or sequence of length n
        A list, tuple or array with at least one dimension is treated as a
        sequence; anything else (including a 0-d array) is treated as a
        scalar.
    n : int
        Required length.

    Returns
    -------
    x : sequence of length n
        The input itself when it is already a sequence, otherwise an array
        of length ``n`` filled with the scalar value.

    Raises
    ------
    AssertionError
        If ``x`` is a sequence whose length differs from ``n``.
    """
    # A 0-d array has no len(); route it to the scalar branch so it is
    # broadcast like any other scalar instead of crashing.
    is_sequence = isinstance(x, (list, tuple)) or (
        isinstance(x, np.ndarray) and x.ndim > 0
    )
    if is_sequence:
        assert len(x) == n, f" variable must be a sequence of length {n}"
        return x
    return x * np.ones(n)
[docs]
def set_initial_state(nn, seed=None):
    """
    Draw a random initial state vector of length ``2 * nn``.

    The first ``nn`` entries are uniform samples scaled to ``[0, 1.5)`` and
    the remaining ``nn`` entries are uniform samples mapped to ``[-2, 2)``.

    Parameters
    ----------
    nn : int
        Number of entries in each half of the state vector.
    seed : int, optional
        When given, NumPy's global random generator is re-seeded with it
        before sampling; otherwise the current global stream is used.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of length ``2 * nn``.
    """
    if seed is not None:
        np.random.seed(seed)

    # One draw for the whole vector keeps the random stream identical to
    # sampling both halves together.
    draws = np.random.rand(2 * nn)
    head, tail = draws[:nn], draws[nn:]
    return np.concatenate((head * 1.5, tail * 4 - 2))