Montbrió–Pazó–Roxin (MPR) SDE model using Numba

[ ]:
import os
import warnings
import multiprocessing as mp
from copy import deepcopy
from multiprocessing import Pool

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import sbi.utils as utils
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sbi.analysis import pairplot
from tqdm import tqdm

import vbi
from vbi import (
    extract_features,
    get_features_by_domain,
    get_features_by_given_names,
    report_cfg,
)
from vbi.feature_extraction.features_utils import get_fc, get_fcd
from vbi.models.numba.mpr import MPR_sde
from vbi.sbi_inference import Inference
from vbi.utils import timer
# %matplotlib inline

warnings.simplefilter("ignore")
[2]:
# Global random seed, reused below for the simulations and prior sampling.
seed = 42
np.random.seed(seed)
# Seed torch's global RNG as well; torch-based steps that draw from it
# (presumably e.g. network initialisation during training) become reproducible.
torch.manual_seed(seed)

# Font size for axis labels and tick labels in all figures.
LABELSIZE = 10
plt.rcParams['axes.labelsize'] = LABELSIZE
plt.rcParams['xtick.labelsize'] = LABELSIZE
plt.rcParams['ytick.labelsize'] = LABELSIZE

# Output directory for simulations, features, posterior and figures.
path = "output/mpr_numba"
os.makedirs(path, exist_ok=True)
[3]:
# @timer
def wrapper(g, par):
    """Run one MPR SDE simulation with global coupling ``g``.

    Parameters
    ----------
    g : float
        Global coupling, passed to ``MPR_sde.run`` as the control ``{"G": g}``.
    par : dict
        Parameters for ``MPR_sde``. The dict is deep-copied first, so the
        caller's dict is not modified. Must contain ``"weights"`` (its
        ``shape[0]`` is the number of nodes) and ``"RECORD_RV"``.

    Returns
    -------
    tuple
        ``(rv_t, r, v, bold_t, bold_d)`` when ``par["RECORD_RV"]`` is true,
        otherwise ``(bold_t, bold_d)``.
    """
    par = deepcopy(par)
    sde = MPR_sde(par)
    data = sde.run({"G": g})

    bold_t = data["bold_t"]
    bold_d = data["bold_d"]

    if not par["RECORD_RV"]:
        # r/v were not recorded: do not touch the rv_* outputs at all.
        return bold_t, bold_d

    # rv_d holds r for all nodes in its first nn columns, then v in the rest.
    nn = par["weights"].shape[0]
    rv_d = data["rv_d"]
    return data["rv_t"], rv_d[:, :nn], rv_d[:, nn:], bold_t, bold_d
[4]:
def batch_run(par, control_list, n_workers=1):
    """
    Call ``wrapper`` once per control value in a process pool, with a progress bar.

    Args:
        par: Parameters dictionary forwarded unchanged to every ``wrapper`` call.
        control_list: Sequence of control values (e.g., G values); one task each.
        n_workers: Number of worker processes in the pool.

    Returns:
        List of ``wrapper`` results, ordered like ``control_list``.
    """
    n_tasks = len(control_list)

    with Pool(processes=n_workers) as pool, \
            tqdm(total=n_tasks, desc="Running simulations") as progress:
        # The callback is invoked in the parent process each time a task finishes.
        pending = [
            pool.apply_async(wrapper,
                             args=(control_list[k], par),
                             callback=lambda _result: progress.update())
            for k in range(n_tasks)
        ]
        # Collect in submission order, not completion order.
        return [job.get() for job in pending]
[5]:
def plot(rv_t, r, v, bold_d, bold_t, step=10, xlabel="Time [ms]"):
    """Plot r, v and BOLD of one simulation in three stacked, x-shared panels.

    Note the argument order: ``bold_d`` comes before ``bold_t`` here, the
    reverse of the order in which ``wrapper`` returns them.

    Parameters
    ----------
    rv_t : array
        Time points of the r/v recordings.
    r, v : array
        Recorded r and v, time along axis 0 (one column per node).
    bold_d : array
        BOLD signal, time along axis 0.
    bold_t : array
        Time points of the BOLD signal.
    step : int, optional
        Draw only every ``step``-th r/v sample to keep the figure light
        (default 10). The BOLD panel is not thinned.
    xlabel : str, optional
        Label of the shared time axis (default ``"Time [ms]"``).
    """
    fig, ax = plt.subplots(3, 1, figsize=(12, 6), sharex=True)
    ax[0].plot(rv_t[::step], r[::step, :], lw=0.1)
    ax[1].plot(rv_t[::step], v[::step, :], lw=0.1)
    ax[2].plot(bold_t, bold_d, lw=0.1)
    ax[0].set_ylabel("r")
    ax[1].set_ylabel("v")
    ax[2].set_ylabel("BOLD")
    ax[2].set_xlabel(xlabel)
    ax[0].margins(x=0.01)
[6]:

# Toy network: 6 nodes with all-to-all coupling.
nn = 6
weights = nx.to_numpy_array(nx.complete_graph(nn))
params = {
    "G": 0.01,  # default coupling; wrapper passes G as a control per run
    "weights": weights,
    "t_end": 10000,
    "t_cut": 1000,
    "dt": 0.01,
    "tau": 1.0,
    "eta": np.array([-4.6]),
    "rv_decimate": 10,  # in time steps
    "noise_amp": 0.037,
    "tr": 300.0,  # in [ms]
    "seed": seed,
    "RECORD_BOLD": True,
    "RECORD_RV": True,
}

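Warm-up run (the first call presumably also pays the one-off Numba compilation cost, so it is slower than later calls).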

[7]:
# First call of the simulator at G = 0.33 with the toy-network parameters.
rv_t, r, v, bold_t, bold_d = wrapper(0.33, params)
[8]:
# to check if there are any nans in the activities
# Number of NaNs in each recorded activity (r, v) and in the BOLD signal;
# all three should be 0.
np.isnan(r).sum(), np.isnan(v).sum(), np.isnan(bold_d).sum()
[8]:
0
[9]:
# Longer run (t_end = 30_000 ms) at a single coupling value, then plot it.
# Note: this modifies `params` in place, so the sweep below also uses it.
params.update(t_end=30_000)
g = 0.33
rv_t, r, v, bold_t, bold_d = wrapper(g, params)
plot(rv_t, r, v, bold_d, bold_t)
../_images/examples_mpr_sde_numba_10_0.png
[10]:
# First two sampling steps of the r/v and BOLD time axes, then the first,
# second and last r/v time points.
np.diff(rv_t)[:2], np.diff(bold_t)[:2], rv_t[0], rv_t[1], rv_t[-1]
[10]:
(array([1., 1.], dtype=float32),
 array([300.], dtype=float32),
 1000.0,
 1001.0,
 29999.0)
[11]:
# Sweep 4 evenly spaced coupling values in [0.3, 0.35], one worker per value.
g_values = np.linspace(0.3, 0.35, 4, endpoint=True)
results = batch_run(params, g_values, n_workers=4)

Running simulations:   0%|          | 0/4 [00:00<?, ?it/s]
Running simulations:  25%|██▌       | 1/4 [00:02<00:07,  2.39s/it]
Running simulations: 100%|██████████| 4/4 [00:02<00:00,  1.65it/s]
[12]:
# (number of runs, number of outputs returned per run)
(len(results), len(results[0]))
[12]:
(4, 5)
[13]:
# Set to True to draw one figure per run of the sweep above.
PLOT_SWEEP = False
if PLOT_SWEEP:
    for res in results:
        # wrapper returns (rv_t, r, v, bold_t, bold_d); plot expects bold_d first.
        res_rv_t, res_r, res_v, res_bold_t, res_bold_d = res
        plot(res_rv_t, res_r, res_v, res_bold_d, res_bold_t)

Whole connectome

[14]:
# Load the 84-node sample connectome shipped with vbi.
connectome = vbi.LoadSample(nn=84)
weights = connectome.get_weights()
nn = weights.shape[0]
print(f"number of nodes: {nn}")

# Connectivity matrix on a fixed [0, 1] grey scale.
fig, ax = plt.subplots(figsize=(4, 4.5))
ax.imshow(weights, cmap="gray", vmin=0, vmax=1);
number of nodes: 84
../_images/examples_mpr_sde_numba_17_1.png
[15]:
TR = 300.0  # BOLD repetition time [ms]
fs = 1 / (TR / 1000)  # BOLD sampling frequency [Hz]
t_cut = 20  # simulated time before recording starts [s]
par = {
    "G": 0.7,  # global coupling strength
    "weights": weights,  # connection matrix
    "dt": 0.01,
    "t_cut": t_cut * 1000,  # [ms]
    "t_end": 100_000,  # [ms]
    "tr": TR,
    "rv_decimate": 10,
    "seed": seed,
    "RECORD_RV": True,
    "RECORD_BOLD": True,
}
[16]:
# Single simulation; run() gets no control dict, so presumably par["G"] is used.
sde_model = MPR_sde(par)
sol = sde_model.run()
[17]:
# Unpack the simulation outputs; time axes are converted from ms to s.
rv_d = sol["rv_d"]
rv_t = sol["rv_t"] / 1000
fmri_d = sol["bold_d"]
fmri_t = sol["bold_t"] / 1000

# Number of NaNs in BOLD and in r/v (both should be 0).
print(np.isnan(fmri_d).sum(), np.isnan(rv_d).sum())

print(f"rv_t.shape = {rv_t.shape}")
print(f"rv_d.shape = {rv_d.shape}")
print(f"fmri_d.shape = {fmri_d.shape}")
print(f"fmri_t.shape = {fmri_t.shape}")
0 0
rv_t.shape = (80000,)
rv_d.shape = (80000, 168)
fmri_d.shape = (266, 84)
fmri_t.shape = (266,)
[18]:
# Top: r of every node (first nn columns of rv_d); bottom: BOLD. Time in seconds.
fig, ax = plt.subplots(2, figsize=(15, 3.5), sharex=True)
ax_r, ax_bold = ax

ax_r.plot(rv_t, rv_d[:, :nn], lw=0.1, alpha=0.1)
ax_r.set_ylabel("r")
ax_r.margins(0, 0.01)

ax_bold.plot(fmri_t, fmri_d, lw=0.1)
ax_bold.set_ylabel("BOLD")
ax_bold.set_xlabel("Time [s]")
ax_bold.margins(0, 0.1)

plt.tight_layout()
plt.show()
../_images/examples_mpr_sde_numba_21_0.png

Feature extraction

[19]:
# Feature configuration: keep only the FCD summary statistics ("fcd_stat")
# from the connectivity domain.
# NOTE(review): report_cfg lists fcd_stat with TR=1.0 while the BOLD TR here is
# 0.3 s; confirm that the `fs` passed to extract_features takes precedence.
cfg = get_features_by_given_names(
    get_features_by_domain("connectivity"), ["fcd_stat"]
)
report_cfg(cfg)
Selected features:
------------------
■ Domain: connectivity
 ▢ Function:  fcd_stat
   ▫ description:  Extracts features from dynamic functional connectivity (FCD)
   ▫ function   :  vbi.feature_extraction.features.fcd_stat
   ▫ parameters :  {'TR': 1.0, 'win_len': 30, 'positive': False, 'eigenvalues': True, 'masks': None, 'verbose': False, 'pca_num_components': 3, 'quantiles': [0.05, 0.25, 0.5, 0.75, 0.95], 'k': None, 'features': ['sum', 'max', 'min', 'mean', 'std', 'skew', 'kurtosis']}
   ▫ tag        :  ['fmri', 'eeg', 'meg']
   ▫ use        :  yes
[20]:
# FCD features of the single whole-connectome run. fmri_d is (time, nodes), so it
# is transposed to (nodes, time), the layout used for `bolds` further below.
df = extract_features(
    [fmri_d.T], fs, cfg, n_workers=10, output_type="dataframe", verbose=False
)[["fcd_full_sum", "fcd_full_ut_std"]]
df
[20]:
fcd_full_sum fcd_full_ut_std
0 8933.09375 0.054202
[21]:
num_sim = 200
G_min, G_max = 0.0, 1.0

# Uniform prior over the single free parameter, the global coupling G.
prior_min, prior_max = [G_min], [G_max]
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min),
    high=torch.as_tensor(prior_max),
)

# Draw num_sim values of G from the prior, seeded for reproducibility.
obj = Inference()
theta = obj.sample_prior(prior, num_sim, seed=seed)
theta_np = theta.numpy().squeeze()  # flat numpy version, used as simulation controls
[22]:
np.shape(theta_np)
[22]:
(200,)
[23]:
TR = 300.0  # BOLD repetition time [ms]
fs = 1 / (TR / 1000)  # BOLD sampling frequency [Hz]
t_cut = 20  # simulated time before recording starts [s]
# Only BOLD is recorded: the sweep below keeps just the BOLD output of each run.
par = {
    "G": 0.506,  # default; wrapper passes G per run as a control (presumably overriding this)
    "weights": weights,  # connection matrix
    "dt": 0.01,
    "t_cut": t_cut * 1000,  # [ms]
    "t_end": 100_000,  # [ms]
    "tr": TR,
    "rv_decimate": 10,
    "seed": seed,
    "RECORD_RV": False,
    "RECORD_BOLD": True,
}
[ ]:
# One whole-connectome simulation per prior sample of G.
results = batch_run(par, control_list=theta_np, n_workers=10)
[25]:
# Each result is (bold_t, bold_d); stack the transposed BOLD as (sim, nodes, time).
bolds = np.array([sim_out[1].T for sim_out in results])
bolds.shape
[25]:
(200, 84, 266)
[26]:
np.savez(os.path.join(path, "bolds.npz"), bolds=bolds, theta=theta_np)
# To reload instead of re-simulating:
# with np.load(os.path.join(path, "bolds.npz")) as data:
#     bolds = data["bolds"]
#     theta_np = data["theta"]
[ ]:
# FCD features for every simulation of the sweep (one row per simulation).
df = extract_features(
    bolds,
    fs,
    cfg,
    n_workers=10,
    output_type="dataframe",
    verbose=True,
)
df.head(2)
[28]:
# Training features as a float32 tensor, one row per simulation.
X = torch.as_tensor(
    df[["fcd_full_sum", "fcd_full_ut_std"]].values, dtype=torch.float32
)

Training the neural network and building the posterior

[ ]:
# Fit the posterior on the simulated (theta, X) pairs.
obj_inf = Inference()
posterior = obj_inf.train(
    theta,
    X,
    prior=prior,
    num_threads=4,
)
[30]:
# Persist the trained posterior, the training pairs and the full feature table.
posterior_file = os.path.join(path, "posterior.pt")
data_file = os.path.join(path, "data.npz")
features_file = os.path.join(path, "features.csv")

torch.save(posterior, posterior_file)
np.savez(data_file, theta=theta_np, X=X.numpy())
df.to_csv(features_file, index=False)
[31]:
# loading data
posterior = torch.load(os.path.join(path, "posterior.pt"), weights_only=False)
data = np.load(os.path.join(path, "data.npz"))
theta_np = data["theta"]
X = torch.as_tensor(data["X"], dtype=torch.float32)
df = pd.read_csv(os.path.join(path, "features.csv"))

Plotting each feature as a function of G

[32]:
# One panel per training feature, plotted against the sampled coupling G.
fig, ax = plt.subplots(1, 2, figsize=(8, 3))
for col, feature_name in enumerate(["fcd_full_sum", "fcd_full_ut_std"]):
    ax[col].scatter(theta_np, X[:, col], s=10, alpha=0.5)
    ax[col].set_xlabel("G")
    ax[col].set_ylabel(feature_name)
plt.tight_layout()
../_images/examples_mpr_sde_numba_38_0.png
  1. Choose a true value of G,

  2. simulate the model with this configuration,

  3. extract features from the simulated observation,

  4. sample from the posterior given the observed features,

  5. store the data.

[33]:
G_true = 0.72
# Single simulation for the true value; RECORD_RV is off in `par`, so wrapper
# returns (bold_t, bold_d) and [1] is the BOLD signal.
bold_obs = wrapper(G_true, par)[1].T
# NOTE(review): par["seed"] is identical for every training simulation and for
# this observation; if MPR_sde uses it for the noise, they all share one noise
# realisation. Confirm this is intended.

# with np.load(os.path.join(path, "bold_obs.npz")) as data:
#     bold_obs = data["bold"]
#     G_true = data["G"]

# The observed BOLD must not contain NaNs.
assert not np.isnan(bold_obs).any(), "observed BOLD contains NaNs"

x_obs_df = extract_features(
    [bold_obs], fs, cfg, n_workers=10, output_type="dataframe", verbose=False
)
x_obs = x_obs_df[["fcd_full_sum", "fcd_full_ut_std"]].values
samples = obj_inf.sample_posterior(x_obs, 5000, posterior)
torch.save(samples, os.path.join(path, "samples.pt"))
[34]:
obs_file = os.path.join(path, "bold_obs.npz")
np.savez(obs_file, bold=bold_obs, G=G_true)

# To restore the observation and its posterior samples from disk:
# data = np.load(obs_file)
# bold_obs = data["bold"]
# G_true = data["G"]
# samples = torch.load(os.path.join(path, "samples.pt"))

Computing FC and FCD for visualisation

[35]:
# FC and FCD of the observed BOLD, used only for the figure below.
# NOTE(review): the fcd_stat features were configured with win_len=30
# (see the report_cfg output); confirm that win_len=25 here is intentional.
fc = get_fc(bold_obs)['full']
fcd = get_fcd(bold_obs, win_len=25)['full']
[36]:
# Observed BOLD (top), its FCD (bottom left) and FC (bottom right).
mosaic = """
AA
BC
"""
fig = plt.figure(constrained_layout=True, figsize=(8, 5))
ax_dict = fig.subplot_mosaic(mosaic)

# Plot data; bold_obs is (nodes, time), transposed to draw one line per node.
ax_dict['A'].plot(bold_obs.T, lw=0.5, alpha=0.1)
im0 = ax_dict['B'].imshow(fcd, cmap="jet", vmin=-0.2, vmax=1)
im1 = ax_dict['C'].imshow(fc, cmap="jet", vmin=0, vmax=1)

# Colorbars in thin axes appended to the right of each matrix panel.
divider0 = make_axes_locatable(ax_dict['B'])
cax0 = divider0.append_axes("right", size="5%", pad=0.05)
cbar0 = fig.colorbar(im0, cax=cax0)

divider1 = make_axes_locatable(ax_dict['C'])
cax1 = divider1.append_axes("right", size="5%", pad=0.05)
cbar1 = fig.colorbar(im1, cax=cax1)

# Set titles and labels
ax_dict['A'].set_title("BOLD")
ax_dict['A'].set_xlabel("Time [samples]")
ax_dict['B'].set_title("FCD")
ax_dict['B'].set_xlabel("Time window index")
ax_dict['B'].set_ylabel("Time window index")
ax_dict['C'].set_title("FC")
ax_dict['C'].set_xlabel("Node index")
ax_dict['C'].set_ylabel("Node index")

plt.show()

../_images/examples_mpr_sde_numba_44_0.png

Plotting the posterior samples and comparing them with the true value (shown in green).

[37]:
# Posterior marginal of G (KDE, red) with the true value marked in green.
limits = [list(bounds) for bounds in zip(prior_min, prior_max)]
fig, ax = pairplot(
    samples,
    limits=limits,
    figsize=(5, 5),
    points=[G_true],
    labels=["G"],
    offdiag='kde',
    diag='kde',
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["g"]),
    diag_kwargs={"mpl_kwargs": {"color": "r"}},
)
plt.legend(["posterior", "True"], loc="upper left", fontsize=12, frameon=False)
fig.savefig(os.path.join(path, "G.png"), dpi=150)
../_images/examples_mpr_sde_numba_46_0.png
[ ]: