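Virtual Epileptic Patient

This example simulates the Virtual Epileptic Patient (VEP) model with the numba backend on a structural connectome. It then uses simulation-based inference (SNPE) to estimate the global coupling G and the excitability eta of the epileptogenic-zone regions from summary features of the simulated activity.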

Open In Colab

[1]:
import os
import pickle
from multiprocessing import Pool

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sbi.utils as utils
import torch
import tqdm
from sbi.analysis import pairplot
from sklearn.preprocessing import StandardScaler

from vbi import report_cfg
from vbi.models.numba.vep import VEP_sde
from vbi.sbi_inference import Inference
from vbi.utils import LoadSample
[2]:
# Fix the global RNG state of torch and numpy so that prior sampling is reproducible.
seed = 2
torch.manual_seed(seed)
np.random.seed(seed)
[3]:
# Directory for simulation outputs, training data and the trained posterior.
path = "output/vep_numba"
if not os.path.isdir(path):
    os.makedirs(path)
[4]:
# Structural connectivity matrix: one row/column per brain region, so nn is the region count.
# ndmin=2 keeps a degenerate file 2-D; the assert stops a non-square file from silently
# producing a wrong region count downstream.
weights = np.loadtxt("data/weights1.txt", ndmin=2)
assert weights.shape[0] == weights.shape[1], f"weights must be square, got {weights.shape}"
nn = weights.shape[0]

Excitability values (eta) assigned to the healthy zone (HZ), the propagation zone (PZ), and the epileptogenic zone (EZ).

[5]:
# eta per zone, ordered HZ < PZ < EZ.
hz_val, pz_val, ez_val = -3.65, -2.4, -1.6
[6]:
# 0-based region indices (rows of `weights`) belonging to each zone.
ez_idx = np.array([6, 34], dtype=np.int32)
pz_kplng_idx = np.array([27], dtype=np.int32)
pz_wplng_idx = np.array([5, 11], dtype=np.int32)
# The propagation zone is the union of both PZ groups, kplng group first.
pz_idx = np.concatenate((pz_kplng_idx, pz_wplng_idx))
[7]:
# Ground-truth excitability of every region: start from the healthy-zone value,
# then overwrite PZ entries and finally EZ entries (EZ wins if an index is in both).
eta_true = np.full(nn, hz_val)
for zone_idx, zone_val in ((pz_idx, pz_val), (ez_idx, ez_val)):
    eta_true[zone_idx] = zone_val
[8]:
# Initial condition of length 2*nn: the first nn entries start at -2.5, the last nn at 3.5
# (presumably x followed by the second state variable of each region — confirm with VEP_sde).
initial_state = np.concatenate((np.full(nn, -2.5), np.full(nn, 3.5)))
# --------------------------------------------------------------------------- #
[9]:
# Simulator configuration for VEP_sde. "eta" and "G" act as defaults: the control dict
# passed to run() supplies per-simulation values (after the ground-truth run, obj.P.eta
# shows the control values rather than -3.5).
params = dict(
    G=1.0,
    seed=seed,
    initial_state=initial_state,
    weights=weights,
    tau=10.0,
    eta=-3.5,
    sigma=0.0,  # zero noise amplitude — presumably makes the run deterministic; confirm
    iext=3.1,
    dt=0.1,  # integration step; wrapper converts 1/dt*1000 to Hz, suggesting ms
    t_end=14.0,
    t_cut=1.0,
    record_step=1,
    method="heun",
    output=path,
)
[10]:
obj = VEP_sde(params)
g_true = 1.0
# Ground-truth values of the inferred parameters: one eta per EZ region, derived from
# ez_val / ez_idx so it stays consistent if the EZ definition changes.
eta_true = [ez_val] * len(ez_idx)
# Full per-region eta vector used to simulate the "observed" patient.
eta_true_ = np.ones(nn) * hz_val
eta_true_[pz_idx] = pz_val
eta_true_[ez_idx] = ez_val
control_true = {"eta": eta_true_, "G": g_true}
[11]:
# Simulate the ground-truth patient; `control_true` supplies the per-region eta and G for this run.
data = obj.run(par=control_true)
[12]:
# Sanity check on the first five (healthy) regions: eta should show hz_val and iext the configured value.
for arr in (obj.P.eta, obj.P.iext):
    print(arr[:5])
[-3.65 -3.65 -3.65 -3.65 -3.65]
[3.1 3.1 3.1 3.1 3.1]
[13]:
# Time axis and the recorded "x" trace of every region from the ground-truth run.
t, ts = data["t"], data["x"]
[14]:

# Stack each region's trace vertically (offset by its index) and colour it by zone:
# EZ in thick red, PZ in thick orange, every other region in green.
plt.figure(figsize=(10, 16))
for region in range(nn):
    if region in ez_idx:
        line_style = {"color": "r", "lw": 3}
    elif region in pz_idx:
        line_style = {"color": "orange", "lw": 3}
    else:
        line_style = {"color": "g"}
    plt.plot(t, ts[region, :] + region, **line_style)
# Tick positions are shifted by -2 relative to the trace offsets (original layout kept).
plt.yticks(np.r_[0:nn] - 2, np.r_[0:nn], fontsize=10)
plt.xticks(fontsize=16)
plt.title("Source brain activity", fontsize=18)
plt.xlabel("Time", fontsize=22)
plt.ylabel("Brain Regions#", fontsize=22)
plt.tight_layout()
# plt.savefig("output/vep_sde.png", dpi=300)
../_images/examples_vep_sde_numba_15_0.png
[15]:
from vbi.feature_extraction.features_settings import *
from vbi.feature_extraction.calc_features import *
[16]:
# Sampling frequency in Hz, using the same conversion as `wrapper` (dt presumably in ms).
# Previously this divided by 1000, giving 0.01 instead of 10000 and making these demo
# features inconsistent with the ones used for training.
fs = 1.0 / params["dt"] * 1000
cfg = get_features_by_domain(domain="statistical")
# cfg = get_features_by_given_names(cfg, names=["calc_moments"])
cfg = get_features_by_given_names(cfg, names=["auc"])
# report_cfg(cfg)
[17]:
# Extract features from the ground-truth trace under a new name: reusing `data` would
# overwrite the simulation output and break a re-run of the cell that defines ts / t.
features_df = extract_features_df([ts], fs, cfg=cfg, n_workers=1)
print(features_df.values.shape)

  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 1530.77it/s]
(1, 84)

[18]:
def wrapper(params, control, x0, cfg, verbose=False):
    """Simulate one VEP parameter set and return its summary-feature vector.

    Parameters
    ----------
    params : dict
        Simulator configuration passed to ``VEP_sde``; must contain ``"dt"``.
    control : dict
        Per-simulation overrides forwarded to ``VEP_sde.run`` (e.g. ``"eta"``, ``"G"``).
    x0 : array-like
        Initial state forwarded to ``run``.
    cfg : dict
        Feature-extraction configuration.
    verbose : bool, optional
        Forwarded to ``extract_features``.

    Returns
    -------
    The first row of the extracted feature table (``.values[0]``).
    """
    simulator = VEP_sde(params)
    solution = simulator.run(control, x0=x0)
    sampling_rate = 1.0 / params["dt"] * 1000  # [Hz]
    feature_table = extract_features(
        ts=[solution["x"]],
        cfg=cfg,
        fs=sampling_rate,
        n_workers=1,
        verbose=verbose,
    )
    return feature_table.values[0]
[19]:
def batch_run(params, control_list, x0, cfg, n_workers=1):
    """Run ``wrapper`` for every control dict and collect the feature vectors.

    Parameters
    ----------
    params : dict
        Simulator configuration shared by all runs.
    control_list : sequence of dict
        One control dict per simulation.
    x0 : array-like
        Initial state shared by all runs.
    cfg : dict
        Feature-extraction configuration.
    n_workers : int or None, optional
        Number of worker processes. With ``n_workers <= 1`` the simulations run
        serially in this process. That path needs no pickling of ``wrapper``, which
        fails under the "spawn" start method for functions defined in a notebook.
        ``None`` lets ``Pool`` choose the process count.

    Returns
    -------
    list
        One feature vector per entry of ``control_list``, in the same order.
    """
    n = len(control_list)
    with tqdm.tqdm(total=n) as pbar:
        if n_workers is not None and n_workers <= 1:
            stat_vec = []
            for control in control_list:
                stat_vec.append(wrapper(params, control, x0, cfg, False))
                pbar.update()
            return stat_vec

        def update_bar(_):
            # Invoked in the parent process each time one async result completes.
            pbar.update()

        with Pool(processes=n_workers) as pool:
            async_results = [
                pool.apply_async(
                    wrapper,
                    args=(params, control, x0, cfg, False),
                    callback=update_bar,
                )
                for control in control_list
            ]
            # .get() keeps submission order and re-raises any exception from a worker.
            stat_vec = [res.get() for res in async_results]
    return stat_vec
[20]:
# Simulation budget and number of worker processes.
num_workers = 10
num_sim = 1000
# Uniform prior range for the global coupling G.
gmin = 0.0
gmax = 2.0
# Uniform prior range for each EZ excitability eta.
eta_min = -5.0
eta_max = -1.0
[21]:
# Parameter vector layout: theta = [G, eta_ez_1, ..., eta_ez_k], one eta per EZ region.
# Derive k from ez_idx so the prior stays consistent with how theta is unpacked later.
n_ez = len(ez_idx)
prior_min = [gmin] + [eta_min] * n_ez
prior_max = [gmax] + [eta_max] * n_ez
[22]:
# Independent uniform prior over every entry of theta, bounded by prior_min / prior_max.
prior = utils.BoxUniform(
    low=torch.tensor(prior_min),
    high=torch.tensor(prior_max),
)
[23]:
obj = Inference()
theta = obj.sample_prior(prior, num_sim)
# float64 copy for building simulator inputs; `theta` itself stays a torch tensor for training.
theta_np = np.array(theta.numpy(), dtype=float)
[24]:
# Expect (num_sim, 1 + number of EZ regions).
print(np.shape(theta_np))
(1000, 3)
[25]:
def build_control(g, ez_eta):
    """Return a control dict: HZ baseline, PZ value, sampled EZ etas, and coupling ``g``."""
    eta = np.full(nn, hz_val)
    eta[pz_idx] = pz_val
    eta[ez_idx] = ez_eta
    return {"eta": eta, "G": g}


# One control per prior sample; column 0 of theta is G, the remaining columns are the EZ etas.
control_list = [build_control(theta_np[k, 0], theta_np[k, 1:]) for k in range(num_sim)]
[ ]:
# Simulate every prior sample and extract its feature vector (the expensive step).
stat_vec = batch_run(
    params, control_list, initial_state, cfg, n_workers=num_workers
)
[27]:
# Standardise each feature column to zero mean / unit variance before training.
scalar = StandardScaler()
stat_vec_st = scalar.fit_transform(np.array(stat_vec))
stat_vec_st = torch.tensor(stat_vec_st, dtype=torch.float32)
torch.save(theta, path + "/theta.pt")
torch.save(stat_vec, path + "/stat_vec.pt")
# Persist the fitted scaler as well: new observations must be transformed with the same
# statistics before a reloaded posterior can be queried.
with open(path + "/scaler.pkl", "wb") as f:
    pickle.dump(scalar, f)
[28]:
# Parameters and standardised features must share the same number of rows.
print(*(tensor.shape for tensor in (theta, stat_vec_st)))
torch.Size([1000, 3]) torch.Size([1000, 84])
[ ]:
# Train a neural posterior estimator (method "SNPE", density estimator "maf") on the
# (theta, standardised features) pairs, with `prior` as the prior over theta.
posterior = obj.train(
    theta, stat_vec_st, prior, method="SNPE", density_estimator="maf", num_threads=8
)
[30]:
# Serialise the trained posterior so it can be reloaded later without retraining.
with open(path + "/posterior.pkl", "wb") as fh:
    fh.write(pickle.dumps(posterior))
[31]:
# Optional: reload a previously saved posterior instead of retraining.
# Only unpickle files you created yourself — pickle.load can execute arbitrary code.
# with open(path + "/posterior.pkl", "rb") as f:
#     posterior = pickle.load(f)
[32]:
# Features of the ground-truth simulation, standardised with the scaler fitted on the training set.
xo = wrapper(params, control_true, initial_state, cfg)
xo_st = scalar.transform(np.reshape(xo, (1, -1)))
[33]:
# Sample the posterior conditioned on the standardised observation
# (10000 is presumably the number of samples drawn — confirm with Inference.sample_posterior).
samples = obj.sample_posterior(xo_st, 10000, posterior)
# torch.save(samples, path + "/samples.pt")
[34]:
limits = [[lo, hi] for lo, hi in zip(prior_min, prior_max)]
# Ground-truth theta = [G, eta_ez_1, ..., eta_ez_k]. list() guards against eta_true being a
# numpy array, where `+` would add element-wise instead of concatenating.
points = [[g_true] + list(eta_true)]
# One label per inferred parameter, so the plot follows the number of EZ regions.
labels = ["G"] + [f"eta{k + 1}" for k in range(len(eta_true))]
fig, ax = pairplot(
    samples,
    limits=limits,
    figsize=(5, 5),
    points=points,
    labels=labels,
    diag="kde",
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["g"]),
    diag_kwargs={"mpl_kwargs": {"color": "r"}},
)
ax[0, 0].tick_params(labelsize=14)
ax[0, 0].margins(y=0)
plt.tight_layout();
../_images/examples_vep_sde_numba_35_0.png
[ ]: