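# Virtual Epileptic Patient

Simulate the Virtual Epileptic Patient (VEP) model on a structural connectome, summarise the simulated source activity with statistical features, and infer the global coupling $G$ and the excitability $\eta$ of the epileptogenic-zone regions with a masked autoregressive flow (MAF) density estimator.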
[ ]:
# Install VBI package in Google Colab (lightweight, CPU-only version)
print("Setting up VBI for Google Colab...")
# Skip C++ compilation for faster installation in Colab
%env SKIP_CPP=1
# Install the package
# !pip install vbi
print("VBI package installed successfully! Ready to proceed.")
[2]:
import os
import pickle
from multiprocessing import Pool

import numpy as np
import networkx as nx
import tqdm
import autograd.numpy as anp
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

from vbi import report_cfg
from vbi.utils import LoadSample
from vbi.utils import BoxUniform
from vbi.utils import posterior_shrinkage_numpy, posterior_zscore_numpy, posterior_peaks_numpy
from vbi.plot import pairplot_numpy
from vbi.models.numba.vep import VEP_sde
from vbi.cde import MAFEstimator, MDNEstimator
[3]:
# Seed NumPy's global RNG once, up front; `seed` is reused later for the MAF sampling RNG.
seed = 2
np.random.seed(seed)
[4]:
# Output directory: passed to the simulator as params["output"] and used for the pickled estimator.
path = "output/tt_"
os.makedirs(path, exist_ok=True)  # no-op if it already exists, so the cell is re-runnable
[ ]:
# Download the weights file
# !mkdir -p data
# !wget https://raw.githubusercontent.com/ins-amu/vbi/main/docs/examples/data/weights1.txt -O data/weights1.txt
[5]:
# Connectivity weights passed to VEP_sde; the number of regions `nn` is derived from it.
weights = np.loadtxt("data/weights1.txt")
# Fail early with a clear message if the file is not an (nn x nn) coupling matrix.
assert weights.ndim == 2 and weights.shape[0] == weights.shape[1], (
    "expected a square connectivity matrix, got shape %s" % (weights.shape,)
)
nn = weights.shape[0]
Excitability values $\eta$ assigned to the healthy zone (HZ), the propagation zone (PZ) and the epileptogenic zone (EZ).
[6]:
# Excitability (eta) level per zone: healthy, propagation, epileptogenic.
hz_val, pz_val, ez_val = -3.65, -2.4, -1.6
[7]:
# Region indices of the epileptogenic zone.
ez_idx = np.array([6, 34], dtype=np.int32)
# Propagation-zone regions are given as two index groups and then merged.
pz_wplng_idx, pz_kplng_idx = (
    np.array([5, 11], dtype=np.int32),
    np.array([27], dtype=np.int32),
)
pz_idx = np.concatenate((pz_kplng_idx, pz_wplng_idx))
[8]:
# Full per-region excitability vector: healthy value everywhere, then PZ and EZ overrides.
# NOTE: a later cell rebinds `eta_true` to a list holding only the EZ values.
eta_true = np.full(nn, hz_val)
eta_true[pz_idx] = pz_val
eta_true[ez_idx] = ez_val
[9]:
# Initial state of length 2 * nn: first nn entries set to -2.5, last nn entries to 3.5.
# Presumably one state variable per half (x then z) — TODO confirm against VEP_sde.
initial_state = np.concatenate((np.full(nn, -2.5), np.full(nn, 3.5)))
[10]:
# Base simulation settings shared by every run. "G" and "eta" act as defaults here:
# the control dicts passed to `run` supply the per-run values.
params = dict(
    G=1.0,
    initial_state=initial_state,
    weights=weights,
    tau=10.0,
    eta=-3.5,
    sigma=0.0,
    iext=3.1,
    dt=0.1,  # integration step; `wrapper` converts it to Hz via 1 / dt * 1000 (i.e. treats it as ms)
    t_end=14.0,
    t_cut=1.0,
    record_step=1,
    method="heun",
    output=path,
)
[11]:
# Simulator instance and the ground-truth parameters used to generate the "observed" data.
obj = VEP_sde(params)
g_true = 1.0
# True values of the inferred eta parameters (one per EZ region). Kept as a list so it
# can be concatenated with [g_true]; derived from ez_val / ez_idx rather than hard-coded
# so it stays in sync if the zone settings change.
eta_true = [ez_val] * len(ez_idx)
# Full per-region excitability vector that is actually simulated.
eta_true_ = np.ones(nn) * hz_val
eta_true_[pz_idx] = pz_val
eta_true_[ez_idx] = ez_val
control_true = {"eta": eta_true_, "G": g_true}
[12]:
# Simulate at the ground-truth controls; the result is indexed below as data["x"] and data["t"].
data = obj.run(par=control_true)
[13]:
# Spot-check the parameters stored on the simulator after the run (first five regions).
for values in (obj.P.eta, obj.P.iext):
    print(values[:5])
[-3.65 -3.65 -3.65 -3.65 -3.65]
[3.1 3.1 3.1 3.1 3.1]
[14]:
# Simulated x traces (indexed as ts[region, time] below) and the matching time axis.
ts, t = data["x"], data["t"]
[15]:
# Stacked source traces: region i is drawn offset by +i and coloured by zone
# (EZ red, PZ orange, all other regions green).
ez_regions = set(ez_idx.tolist())
pz_regions = set(pz_idx.tolist())
trace_fig, trace_ax = plt.subplots(figsize=(8, 10))
for region in range(nn):
    if region in ez_regions:
        style = {"color": "r", "lw": 3}
    elif region in pz_regions:
        style = {"color": "orange", "lw": 3}
    else:
        style = {"color": "g"}
    trace_ax.plot(t, ts[region, :] + region, **style)
# Ticks are placed at i - 2, presumably to line up with each trace's resting level — TODO confirm.
trace_ax.set_yticks(np.r_[0:nn] - 2)
trace_ax.set_yticklabels(np.r_[0:nn], fontsize=10)
trace_ax.tick_params(axis="x", labelsize=16)
trace_ax.set_title("Source brain activity", fontsize=18)
trace_ax.set_xlabel("Time", fontsize=22)
trace_ax.set_ylabel("Brain Regions#", fontsize=22)
trace_fig.tight_layout()
# trace_fig.savefig("output/vep_sde.png", dpi=300)
[16]:
from vbi.feature_extraction.features_settings import *
from vbi.feature_extraction.calc_features import *
[17]:
# Sampling frequency in Hz. Must match `wrapper`: 1 / dt * 1000, i.e. dt treated as ms.
# (Previously divided by 1000, giving 0.01 instead of 10000 for dt = 0.1.)
fs = 1.0 / params["dt"] * 1000  # [Hz]
# Statistical-domain features, restricted to the area under the curve and seizure onset;
# seizure onset is evaluated only on the EZ regions.
cfg = get_features_by_domain(domain="statistical")
cfg = get_features_by_given_names(cfg, names=["auc", "seizure_onset"])
cfg = update_cfg(cfg, "seizure_onset", {"indices": ez_idx})
report_cfg(cfg)
Selected features:
------------------
■ Domain: statistical
▢ Function: auc
▫ description: Computes the area under the curve of the signal computed with trapezoid rule.
▫ function : vbi.feature_extraction.features.auc
▫ parameters : {'dx': None, 'x': None, 'indices': None, 'verbose': False}
▫ tag : all
▫ use : yes
▢ Function: seizure_onset
▫ description: Detects the onset of seizures.
▫ function : vbi.feature_extraction.features.seizure_onset
▫ parameters : {'threshold': 0.02, 'verbose': False, 'indices': array([ 6, 34], dtype=int32)}
▫ tag : ['eeg', 'seeg']
▫ use : yes
[18]:
# Feature table for the ground-truth simulation. Stored under its own name so the
# simulation output `data` is not clobbered (re-running the ts/t cell would otherwise
# index a DataFrame with "x" and fail).
features_df = extract_features_df([ts], fs, cfg=cfg, n_workers=1)
print(features_df.shape)
100%|██████████| 1/1 [00:00<00:00, 1129.63it/s]
(1, 86)
[19]:
def wrapper(params, control, x0, cfg, verbose=False):
    """Simulate one VEP run and return its feature vector.

    Builds a fresh ``VEP_sde`` from ``params``, runs it with the ``control``
    overrides starting from ``x0``, and extracts the features selected in
    ``cfg`` from the simulated ``"x"`` output.

    Returns the first (and only) row of the feature table.

    NOTE(review): this is submitted as a multiprocessing task by ``batch_run``,
    so worker processes must be able to pickle/locate it.
    """
    simulator = VEP_sde(params)
    simulation = simulator.run(control, x0=x0)
    # Convert the integration step to a sampling rate (the * 1000 treats dt as ms).
    fs = 1.0 / params["dt"] * 1000  # [Hz]
    feature_table = extract_features(
        ts=[simulation["x"]],
        cfg=cfg,
        fs=fs,
        n_workers=1,
        verbose=verbose,
    )
    return feature_table.values[0]
[20]:
def batch_run(params, control_list, x0, cfg, n_workers=1):
    """Run ``wrapper`` for every control dict in a process pool and collect the features.

    Parameters
    ----------
    params : dict
        Base simulation parameters shared by all runs.
    control_list : sequence of dict
        One control dict per simulation (here ``{"eta": ..., "G": ...}``).
    x0 : array-like
        Initial state passed to every run.
    cfg :
        Feature-extraction configuration forwarded to ``wrapper``.
    n_workers : int
        Number of worker processes.

    Returns
    -------
    list
        One feature vector per control dict, in the same order as ``control_list``.

    Notes
    -----
    Uses ``Pool`` from ``multiprocessing`` (imported in the notebook's import cell).
    NOTE(review): with the "spawn" start method (macOS/Windows default) functions
    defined in a notebook may not be importable by workers — confirm on those platforms.
    """
    n = len(control_list)
    with Pool(processes=n_workers) as pool, tqdm.tqdm(total=n) as pbar:
        # Runs in the pool's result-handler thread; advances the bar once per finished task.
        def update_bar(_):
            pbar.update()

        async_results = [
            pool.apply_async(
                wrapper,
                args=(params, control, x0, cfg, False),
                callback=update_bar,
                # Count failed tasks too, so the bar does not stall; get() re-raises below.
                error_callback=update_bar,
            )
            for control in control_list
        ]
        # get() preserves submission order and re-raises any worker exception.
        stat_vec = [res.get() for res in async_results]
    return stat_vec
[21]:
# Simulation budget and parallelism.
num_sim, num_workers = 1000, 10
# Prior bounds: EZ excitability and global coupling G.
eta_min, eta_max = -5.0, -1.0
gmin, gmax = 0.0, 2.0
[22]:
# Uniform box prior over [G, eta at each EZ region]. The eta dimension follows
# ez_idx instead of being hard-coded, so it matches the control construction below.
prior_min = [gmin] + [eta_min] * len(ez_idx)
prior_max = [gmax] + [eta_max] * len(ez_idx)
prior = BoxUniform(low=prior_min, high=prior_max)
# num_sim draws, one row per simulation: shape (num_sim, 1 + len(ez_idx)).
theta = prior.sample(num_sim)
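Select $G$ and the $\eta$ values at the `ez_idx` regions as the target parameters; all other regions keep their fixed healthy-zone or propagation-zone excitability.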
[23]:
# One control dict per prior draw: fixed HZ/PZ excitability, sampled EZ eta values and G.
control_list = []
for sample in theta:
    eta_ = np.full(nn, hz_val)
    eta_[pz_idx] = pz_val
    eta_[ez_idx] = sample[1:]
    g_ = sample[0]
    control_list.append({"eta": eta_, "G": g_})
[24]:
# Simulate every prior draw and extract its features, using num_workers processes.
stat_vec = batch_run(params, control_list, initial_state, cfg, num_workers)
0%| | 0/1000 [00:00<?, ?it/s]100%|██████████| 1000/1000 [00:00<00:00, 1118.39it/s]
[25]:
# Standardise each feature column; the fitted scaler is reused to transform the observation.
# CAUTION: this overwrites `stat_vec`. Re-run the simulation cell before re-running this one,
# otherwise the scaler is refitted on already-scaled data and the observation is mis-scaled.
features_raw = np.asarray(stat_vec)
scalar = StandardScaler()
stat_vec = scalar.fit_transform(features_raw)
[26]:
# Observed features: simulate at the ground truth, then apply the scaler fitted on the training set.
xo_raw = wrapper(params, control_true, initial_state, cfg)
xo = scalar.transform(xo_raw.reshape(1, -1))
print(theta.shape, stat_vec.shape, xo.shape)
(1000, 3) (1000, 86) (1, 86)
[27]:
# Fit the MAF estimator on (theta, scaled features), then draw 5000 samples conditioned on xo.
rng = anp.random.RandomState(seed)
maf_config = dict(n_flows=8, hidden_units=128)
maf_estimator = MAFEstimator(**maf_config)
maf_estimator.train(theta, stat_vec, n_iter=500, learning_rate=2e-4)
print("best epoch:", maf_estimator.best_epoch, "best val:", maf_estimator.best_val_loss)
posterior_draws = maf_estimator.sample(xo, n_samples=5000, rng=rng)
# Presumably the samples for the single observation row — TODO confirm the return layout.
samples = posterior_draws[0]
Inferred dimensions: param_dim=3, feature_dim=86
Training: 100%|██████████| 500/500 [00:49<00:00, 10.13it/s, patience=2/20, train=-5.6401, val=-5.5616]
best epoch: 497 best val: -5.746348190269239
[28]:
# Persist the trained estimator so it can be reloaded without retraining.
# NOTE: only unpickle this file from a trusted source — pickle.load can execute arbitrary code.
posterior_file = os.path.join(path, "posterior.pkl")
with open(posterior_file, "wb") as fh:
    pickle.dump(maf_estimator, fh)
[36]:
# Training-loss curve of the MAF estimator.
loss_fig, loss_ax = plt.subplots(figsize=(5, 3))
loss_ax.plot(
    maf_estimator.loss_history,
    color="royalblue",
    linewidth=2,
    marker="o",
    markersize=3,
    alpha=0.5,
)
loss_ax.set_title("MAF Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.grid(True, linestyle="--", alpha=0.5)
loss_fig.tight_layout()
plt.show()
[29]:
# Axis limits for each parameter, taken from the prior box.
limits = [[lo, hi] for lo, hi in zip(prior_min, prior_max)]
# Ground-truth point: G followed by eta at the EZ regions, read from the per-region
# vector that was actually simulated. This avoids depending on `eta_true` being a list:
# if it were the full array, `[g_true] + eta_true` would silently broadcast.
points = np.concatenate([[g_true], np.asarray(eta_true_)[ez_idx]]).reshape(1, -1)
labels = ["G"] + ["eta{}".format(k + 1) for k in range(len(ez_idx))]
fig, ax = pairplot_numpy(
    samples,
    limits=limits,
    figsize=(5, 5),
    points=points,
    labels=labels,
    diag="kde",
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["g"]),
    diag_kwargs={"mpl_kwargs": {"color": "r"}},
)
ax[0, 0].tick_params(labelsize=14)
ax[0, 0].margins(y=0)
plt.tight_layout();
[ ]: