{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# [Jansen-Rit whole brain CuPy implementation](https://github.com/Ziaeemehr/vbi_paper/blob/main/docs/examples/jansen_rit_cupy.ipynb)\n", "\n", "Simulate a fully connected network of Jansen-Rit nodes with `JR_sde`, summarise each simulation by the per-node standard deviation and mean of its time series, and train an SNPE posterior over `G` and `C1`. Figures, tensors and the trained posterior are written to `output/`.\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pickle\n", "import warnings\n", "\n", "import torch\n", "import numpy as np\n", "import networkx as nx\n", "import sbi.utils as utils\n", "import matplotlib.pyplot as plt\n", "from sbi.analysis import pairplot\n", "from helpers import plot_ts_pxx_jr\n", "from vbi.inference import Inference\n", "from vbi.models.cupy.jansen_rit import JR_sde\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from vbi import report_cfg\n", "from vbi import extract_features_list\n", "from vbi import get_features_by_domain, get_features_by_given_names" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "seed = 2\n", "np.random.seed(seed)\n", "torch.manual_seed(seed);" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "LABELSIZE = 12\n", "plt.rcParams['axes.labelsize'] = LABELSIZE\n", "plt.rcParams['xtick.labelsize'] = LABELSIZE\n", "plt.rcParams['ytick.labelsize'] = LABELSIZE" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "nn = 6  # number of nodes in the complete graph\n", "num_sim = 100  # number of prior draws, i.e. simulations\n", "num_workers = 10  # parallel workers used for feature extraction\n", "weights = nx.to_numpy_array(nx.complete_graph(nn))\n", "\n", "# Later cells write figures, tensors and the posterior into output/;\n", "# create it up front so Restart & Run All works on a fresh checkout.\n", "os.makedirs(\"output\", exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "par = {\n", " \"weights\": weights,\n", " \"t_cut\": 500,\n", " \"t_end\": 2000,\n", " \"noise_amp\": 0.05,\n", " \"dt\": 0.02,\n", " \"num_sim\": num_sim,\n", " \"engine\": \"cpu\",\n", " \"seed\": seed,\n", " \"same_initial_state\": True,\n", "}" ] }, { "cell_type": "code", "execution_count": 8, 
"metadata": {}, "outputs": [], "source": [ "obj = Inference()\n", "G_min, G_max = 0.0, 5.0\n", "C1_min, C1_max = 135, 300\n", "prior_min = [G_min, C1_min]\n", "prior_max = [G_max, C1_max]\n", "prior = utils.BoxUniform(low=torch.tensor(prior_min),\n", "                         high=torch.tensor(prior_max))\n", "theta = obj.sample_prior(prior, num_sim)  # one (G, C1) row per simulation\n", "theta_np = theta.numpy().astype(float)\n", "G = theta_np[:, 0]\n", "# One C1 value per simulation, repeated for every node: shape (nn, num_sim).\n", "C1 = np.tile(theta_np[:, 1], (nn, 1))\n", "par['G'] = G\n", "par['C1'] = C1" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "jr = JR_sde(par)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/100000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Time series and power spectrum of the first simulation (index 0 along the last, simulation, axis of data['x']).\n", "ts0 = data['x'][:, :, 0].T\n", "data0 = {\"t\": t, \"x\": ts0}\n", "print(t.shape, ts0.shape)\n", "fig, ax = plt.subplots(1, 2, figsize=(10, 3))\n", "plot_ts_pxx_jr(data0, par, ax, alpha=0.5)\n", "plt.tight_layout()\n", "plt.savefig(\"output/jr_ts_psd.png\", dpi=300)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Selected features:\n", "------------------\n", "■ Domain: statistical\n", " ▢ Function: calc_std\n", " ▫ description: Computes the standard deviation of the signal.\n", " ▫ function : vbi.feature_extraction.features.calc_std\n", " ▫ parameters : {'indices': None, 'verbose': False}\n", " ▫ tag : all\n", " ▫ use : yes\n", " ▢ Function: calc_mean\n", " ▫ description: Computes the mean of the signal.\n", " ▫ function : vbi.feature_extraction.features.calc_mean\n", " ▫ parameters : {'indices': None, 'verbose': False}\n", " ▫ tag : all\n", " ▫ use : yes\n" ] } ], "source": [ "cfg = get_features_by_domain(domain=\"statistical\")\n", "cfg = get_features_by_given_names(cfg, 
names=['calc_std', 'calc_mean'])\n", "report_cfg(cfg)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(100, 6, 75000)\n" ] } ], "source": [ "# Reorder the simulator output from [nt, nn, ns] to [ns, nn, nt]: one leading entry per simulation.\n", "ts = data['x'].transpose(2, 1, 0)\n", "print(ts.shape)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(100, 12)\n" ] } ], "source": [ "from vbi import extract_features\n", "\n", "# fs is presumably in Hz, assuming t_cut/t_end/dt are in ms: (2000 - 500) / 0.02 = 75000 samples, as printed above.\n", "stat_vec = np.array(extract_features(ts=ts,\n", "                                     cfg=cfg,\n", "                                     fs=1/par['dt']*1000,\n", "                                     n_workers=num_workers,\n", "                                     verbose=False).values)\n", "print(stat_vec.shape)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 2]) torch.Size([100, 12])\n" ] } ], "source": [ "# Standardise each feature column, then store parameters and features for reuse.\n", "scaler = StandardScaler()\n", "stat_vec_st = torch.tensor(scaler.fit_transform(stat_vec), dtype=torch.float32)\n", "\n", "torch.save(theta, 'output/theta.pt')\n", "torch.save(stat_vec_st, 'output/stat_vec_st.pt')\n", "print(theta.shape, stat_vec_st.shape)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Neural network successfully converged after 110 epochs.train Done in 0 hours 0 minutes 03.701637 seconds\n" ] } ], "source": [ "posterior = obj.train(theta, stat_vec_st, prior,\n", "                      method='SNPE', density_estimator='maf')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# pickle can run arbitrary code on load: only reload posterior files this notebook wrote.\n", "with open('output/posterior.pkl', 'wb') as fh:\n", "    pickle.dump(posterior, fh)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "with open('output/posterior.pkl', 'rb') as fh:\n", "    posterior = pickle.load(fh)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": 
[], "source": [ "# Use the first simulation as a synthetic observation, so its generating parameters are known.\n", "index = 0\n", "theta_true = theta[index, :]\n", "xo_st = stat_vec_st[index, :]" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cb99cec649924ade974642326c5a4c46", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Drawing 10000 posterior samples: 0%| | 0/10000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "limits = [[lo, hi] for lo, hi in zip(prior_min, prior_max)]\n", "# Posterior marginals (diagonal) and pairwise density; the green star marks theta_true.\n", "fig, ax = pairplot(\n", "    samples,\n", "    limits=limits,\n", "    figsize=(5, 5),\n", "    points=theta_true,\n", "    labels=[\"G\", \"C1\"],\n", "    offdiag='kde',\n", "    diag='kde',\n", "    fig_kwargs=dict(\n", "        points_offdiag=dict(marker=\"*\", markersize=10),\n", "        points_colors=[\"g\"]),\n", "    diag_kwargs={\"mpl_kwargs\": {\"color\": \"r\"}},\n", "    upper_kwargs={\"mpl_kwargs\": {\"cmap\": \"Blues\"}},\n", ")\n", "ax[0, 0].tick_params(labelsize=14)\n", "ax[0, 0].margins(y=0)\n", "plt.tight_layout()\n", "fig.savefig(\"output/tri_jr_cupy.jpeg\", dpi=300)" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }