Montbrió SDE model using Numba

[11]:
import os
import warnings
import numpy as np
import networkx as nx
from copy import deepcopy
import matplotlib.pyplot as plt
from vbi.models.numba.mpr import MPR_sde
from vbi.utils import timer

# Silence every warning category for the rest of the session (keeps notebook
# output clean). NOTE(review): this also hides numerical RuntimeWarnings —
# consider narrowing the category if debugging the simulator.
warnings.filterwarnings("ignore")
[12]:
seed = 42
np.random.seed(seed)  # seed NumPy's legacy global RNG for this session

# Font size used for axis labels and tick labels in every figure below.
LABELSIZE = 14
LABESSIZE = LABELSIZE  # original (misspelled) name, kept so existing references still work
plt.rcParams.update({
    "axes.labelsize": LABELSIZE,
    "xtick.labelsize": LABELSIZE,
    "ytick.labelsize": LABELSIZE,
})
[23]:
# @timer  -- decorator left disabled; re-enable to print per-call wall time.
def wrapper(g, par):
    """Run one MPR SDE simulation at global coupling ``g``.

    Parameters
    ----------
    g : float
        Global coupling, handed to the model's ``run`` as ``{"G": g}``.
    par : dict
        Model parameters. A deep copy is given to ``MPR_sde`` so that any
        in-place modification by the model cannot leak back to the caller.
        Must provide ``"weights"``, whose first dimension is the node count.

    Returns
    -------
    tuple
        ``(rv_t, r, v, bold_t, bold_d)``: ``r`` is the first ``n_nodes``
        columns of ``rv_d`` and ``v`` the remaining columns. Note that the
        BOLD pair comes back time-first, while ``plot`` takes ``bold_d``
        before ``bold_t``.
    """
    local_par = deepcopy(par)
    out = MPR_sde(local_par).run({"G": g})

    n_nodes = local_par["weights"].shape[0]
    rv_samples = out["rv_d"]
    r_part, v_part = rv_samples[:, :n_nodes], rv_samples[:, n_nodes:]

    return out["rv_t"], r_part, v_part, out["bold_t"], out["bold_d"]
[14]:
def plot(rv_t, r, v, bold_d, bold_t, step=10):
    """Plot r, v and BOLD time series in three vertically stacked panels.

    Parameters
    ----------
    rv_t : array
        Time stamps of the r/v samples.
    r, v : array
        Samples indexed as ``[time, node]``; every ``step``-th row is drawn.
    bold_d, bold_t : array
        BOLD signal and its time stamps. NOTE: this order (data, then time)
        is the reverse of what ``wrapper`` returns — pass them by keyword to
        avoid mix-ups.
    step : int, optional
        Decimation factor for the r/v panels (default 10, the previously
        hard-coded value). BOLD is plotted in full.

    Returns
    -------
    fig, ax
        The Figure and its three Axes, so callers can customise or save it.
    """
    # sharex keeps the three panels aligned on the same time axis.
    fig, ax = plt.subplots(3, 1, figsize=(12, 6), sharex=True)
    ax[0].plot(rv_t[::step], r[::step, :], lw=0.1)
    ax[1].plot(rv_t[::step], v[::step, :], lw=0.1)
    ax[2].plot(bold_t, bold_d, lw=0.1)
    ax[0].set_ylabel("r")
    ax[1].set_ylabel("v")
    ax[2].set_ylabel("BOLD")
    ax[2].set_xlabel("time")
    return fig, ax
[15]:
nn = 6  # number of nodes
# All-to-all structural connectivity: complete graph, unit edge weights, no self-loops.
weights = nx.to_numpy_array(nx.complete_graph(nn))
params = {
    "G": 0.01,  # presumably superseded by the G that wrapper() passes via its control dict
    "weights": weights,
    "t_end": 10000,
    "dt": 0.01,
    "tau": 1.0,
    "eta": np.array([-4.6]),  # single value; presumably broadcast to all nodes — confirm
    "rv_decimate": 10,  # in time steps
    "noise_amp": 0.037,
    "tr": 300.0,  # in [ms]
    "seed": seed,  # reuse the notebook-wide seed rather than a second hard-coded 42
    "RECORD_BOLD": True,
}

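Warm-up: run a short simulation first so that any one-off Numba compilation happens before the longer runs below.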

[16]:
# Short warm-up simulation; one-off start-up cost (e.g. JIT compilation) is paid here.
rv_t, r, v, bold_t, bold_d = wrapper(g=0.33, par=params)
wrapper Done in 0 hours 0 minutes 00.859522 seconds
[17]:
# Count NaNs in both simulated activities (r and v); 0 means no NaNs appeared.
int(np.isnan(r).sum() + np.isnan(v).sum())
[17]:
0
[19]:
# Longer run at the same coupling. NOTE: this updates the shared `params`
# dict in place, so later cells that reuse `params` also get t_end = 30_000.
params.update(t_end=30_000)
g = 0.33
rv_t, r, v, bold_t, bold_d = wrapper(g, params)
plot(rv_t, r, v, bold_d=bold_d, bold_t=bold_t)
wrapper Done in 0 hours 0 minutes 02.785722 seconds
../_images/examples_mpr_sde_numba_9_1.png
[20]:
# Sampling sanity check: first two r/v and BOLD sampling intervals, then the first, second and last r/v time stamps.
np.diff(rv_t)[:2], np.diff(bold_t)[:2], rv_t[0], rv_t[1], rv_t[-1]
[20]:
(array([1., 1.], dtype=float32),
 array([300.], dtype=float32),
 0.0,
 1.0,
 29999.0)
[31]:
import multiprocessing as mp

# Sweep the global coupling G; each value is simulated in its own worker.
# NOTE(review): every run receives the same params["seed"], so the noise
# realisation is presumably identical across couplings — confirm intended.
# NOTE(review): shipping the notebook-defined `wrapper` to workers relies on
# the "fork" start method; under "spawn" (macOS/Windows default) functions
# defined interactively in __main__ cannot be loaded by the children.
g = np.linspace(0.3, 0.35, 4, endpoint=True)
n_workers = min(len(g), mp.cpu_count())  # never more workers than jobs or cores
with mp.Pool(processes=n_workers) as p:
    results = p.starmap(wrapper, [(g_, params) for g_ in g])

[34]:
# (number of couplings swept, number of arrays returned per simulation)
(len(results), len(results[0]))
[34]:
(4, 5)
[ ]:
# One figure per coupling value. Each result is (rv_t, r, v, bold_t, bold_d);
# plot() expects bold_d before bold_t, hence the keyword arguments.
for g_, (t_k, r_k, v_k, bt_k, bd_k) in zip(g, results):
    plot(t_k, r_k, v_k, bold_d=bd_k, bold_t=bt_k)
    plt.gcf().suptitle(f"G = {g_:.4f}")  # label which coupling this figure shows