{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Montbrio SDE model using CuPy](https://github.com/Ziaeemehr/vbi_paper/blob/main/docs/examples/mpr_sde_cupy.ipynb)\n",
    "\n",
    "Estimation of the global coupling $G$.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Ziaeemehr/vbi_paper/blob/main/docs/examples/mpr_sde_cupy.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import vbi\n",
    "import torch\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "from copy import deepcopy\n",
    "import sbi.utils as utils\n",
    "import matplotlib.pyplot as plt\n",
    "from sbi.analysis import pairplot\n",
    "from vbi.sbi_inference import Inference\n",
    "from vbi.models.cupy.mpr import MPR_sde\n",
    "from vbi import (\n",
    "    get_features_by_domain,\n",
    "    get_features_by_given_names,\n",
    "    report_cfg,\n",
    "    extract_features,\n",
    ")\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "np.random.seed(seed)\n",
    "# sbi training and posterior sampling draw from torch's RNG, so seed it too\n",
    "torch.manual_seed(seed)\n",
    "path = \"output\"\n",
    "os.makedirs(path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELSIZE = 10\n",
    "plt.rcParams['axes.labelsize'] = LABELSIZE\n",
    "plt.rcParams['xtick.labelsize'] = LABELSIZE\n",
    "plt.rcParams['ytick.labelsize'] = LABELSIZE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading the connectivity matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "D = vbi.LoadSample(nn=88)\n",
    "weights = D.get_weights()\n",
    "nn = weights.shape[0]\n",
    "print(f\"number of nodes: {nn}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simulating the BOLD signal for a sample value of $G$. This simulation is saved and later used as the observation for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TR = 300.0  # BOLD repetition time [ms]\n",
    "fs = 1 / (TR / 1000)  # BOLD sampling frequency [Hz]\n",
    "par = {\n",
    "    \"G\": 0.506,  # global coupling strength\n",
    "    \"weights\": weights,  # connection matrix\n",
    "    \"method\": \"heun\",  # integration method\n",
    "    \"dt\": 0.01,\n",
    "    \"t_cut\": 20_000,  # [ms]\n",
    "    \"t_end\": 100_000,  # [ms]\n",
    "    \"num_sim\": 1,  # number of simulations\n",
    "    \"tr\": TR,\n",
    "    \"rv_decimate\": 10,\n",
    "    \"engine\": \"cpu\",  # cpu or gpu\n",
    "    \"seed\": seed,  # seed for random number generator\n",
    "    \"RECORD_RV\": True,\n",
    "    \"RECORD_BOLD\": True,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj = MPR_sde(par)\n",
    "sol = obj.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rv_d = sol[\"rv_d\"]\n",
    "rv_t = sol[\"rv_t\"] / 1000  # ms -> s\n",
    "fmri_d = sol[\"fmri_d\"]\n",
    "fmri_t = sol[\"fmri_t\"] / 1000  # ms -> s\n",
    "print(np.isnan(fmri_d).sum(), np.isnan(rv_d).sum())\n",
    "\n",
    "print(f\"rv_t.shape = {rv_t.shape}\")\n",
    "print(f\"rv_d.shape = {rv_d.shape}\")\n",
    "print(f\"fmri_d.shape = {fmri_d.shape}\")\n",
    "print(f\"fmri_t.shape = {fmri_t.shape}\")\n",
    "\n",
    "np.savez(\n",
    "    os.path.join(path, \"bold_obs.npz\"),\n",
    "    t=fmri_t,\n",
    "    bold=np.transpose(fmri_d, (2, 1, 0)),\n",
    "    theta=par[\"G\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if fmri_d.ndim == 3:\n",
    "    fig, ax = plt.subplots(3, figsize=(10, 5), sharex=True)\n",
    "    ax[0].set_ylabel(\"BOLD\")\n",
    "    ax[0].plot(fmri_t, fmri_d[:, :, 0], lw=0.1)\n",
    "    ax[0].margins(0, 0.1)\n",
    "    # rv_d: the first nn columns are plotted as r, the remaining ones as v\n",
    "    ax[1].plot(rv_t, rv_d[:, :nn, 0], lw=0.1, alpha=0.1)\n",
    "    ax[2].plot(rv_t, rv_d[:, nn:, 0], lw=0.1, alpha=0.1)\n",
    "    ax[1].set_ylabel(\"r\")\n",
    "    ax[2].set_ylabel(\"v\")\n",
    "    ax[2].set_xlabel(\"Time [s]\")\n",
    "    ax[1].margins(0, 0.01)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Training data**\n",
    "\n",
    "- Define a uniform prior for $G$ and sample from it;\n",
    "- Select the GPU as the simulation engine;\n",
    "- Store the simulated training BOLD signals;\n",
    "- Extract features from the simulated BOLD signals;\n",
    "- Visualize some of the features;\n",
    "- Train the neural network and estimate $G$ for the observed signal;\n",
    "- Visualize the posterior distribution.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_sim = 512\n",
    "G_min, G_max = 0.0, 1.0\n",
    "\n",
    "prior_min = [G_min]\n",
    "prior_max = [G_max]\n",
    "prior = utils.torchutils.BoxUniform(\n",
    "    low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max)\n",
    ")\n",
    "\n",
    "sampler = Inference()\n",
    "theta = sampler.sample_prior(prior, num_sim, seed=seed)\n",
    "\n",
    "par_batch = deepcopy(par)\n",
    "par_batch['G'] = theta.numpy().astype(np.float64).squeeze()\n",
    "par_batch['num_sim'] = num_sim\n",
    "par_batch['engine'] = 'gpu'\n",
    "par_batch['RECORD_RV'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_batch = MPR_sde(par_batch)\n",
    "sol_batch = sim_batch.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fmri_d_batch = sol_batch[\"fmri_d\"]\n",
    "fmri_t_batch = sol_batch[\"fmri_t\"] / 1000  # ms -> s, same unit as in bold_obs.npz\n",
    "bolds = np.transpose(fmri_d_batch, (2, 1, 0))\n",
    "\n",
    "np.savez(\n",
    "    os.path.join(path, \"bolds.npz\"),\n",
    "    bolds=bolds,\n",
    "    fmri_t=fmri_t_batch,\n",
    "    theta=theta.numpy().squeeze(),\n",
    ")\n",
    "print(f\"fmri_d.shape = {fmri_d_batch.shape}\")\n",
    "print(f\"fmri_t.shape = {fmri_t_batch.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the training set from disk; once bolds.npz exists the batch simulation can be skipped\n",
    "with np.load(os.path.join(path, \"bolds.npz\")) as data:\n",
    "    bolds = data[\"bolds\"]\n",
    "    theta = torch.tensor(data[\"theta\"]).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = get_features_by_domain(\"connectivity\")\n",
    "cfg = get_features_by_given_names(cfg, [\"fcd_stat\"])\n",
    "report_cfg(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# features used for training; the observation below uses the same list, in the same order\n",
    "FEATURE_NAMES = [\"fcd_full_sum\", \"fcd_full_ut_std\"]\n",
    "\n",
    "df_all = extract_features(bolds, fs, cfg, n_workers=10, output_type=\"dataframe\")\n",
    "df = df_all[FEATURE_NAMES].assign(G=theta.numpy().squeeze())\n",
    "df.to_csv(os.path.join(path, \"g_cupy_features.csv\"), index=False)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELSIZE = 14\n",
    "plt.rc('axes', labelsize=LABELSIZE)\n",
    "plt.rc('axes', titlesize=LABELSIZE)\n",
    "plt.rc('figure', titlesize=LABELSIZE)\n",
    "plt.rc('legend', fontsize=LABELSIZE)\n",
    "plt.rc('xtick', labelsize=LABELSIZE)\n",
    "plt.rc('ytick', labelsize=LABELSIZE)\n",
    "\n",
    "f_kwargs = {\n",
    "    \"lw\": 1,\n",
    "    \"alpha\": 0.5,\n",
    "    \"marker\": \"o\",\n",
    "    \"linestyle\": \"\",\n",
    "    \"markerfacecolor\": \"none\",\n",
    "}\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(8, 3))\n",
    "ax[0].plot(df[\"G\"], df[\"fcd_full_sum\"], **f_kwargs)\n",
    "# fluidity is plotted as the variance, i.e. the square of the upper-triangle FCD std\n",
    "ax[1].plot(df[\"G\"], df[\"fcd_full_ut_std\"]**2, **f_kwargs)\n",
    "\n",
    "titles = [\"FCD sum\", \"Fluidity\"]\n",
    "for i in range(2):\n",
    "    ax[i].set_xlabel(\"G\")\n",
    "    ax[i].set_title(titles[i])\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# feature matrix in FEATURE_NAMES order (the target column G is excluded)\n",
    "X = torch.tensor(df[FEATURE_NAMES].values, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj_inf = Inference()\n",
    "posterior = obj_inf.train(theta, X, prior=prior, num_threads=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Observation point**\n",
    "\n",
    "Features of the BOLD signal simulated at the beginning of the notebook (saved in `bold_obs.npz`) are extracted with the same sampling frequency and feature set as the training data, then used to sample the posterior of $G$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with np.load(os.path.join(path, \"bold_obs.npz\")) as data:\n",
    "    bold_obs = data[\"bold\"]\n",
    "\n",
    "# must match the training features: same sampling frequency `fs` and same FEATURE_NAMES order\n",
    "x_obs_df = extract_features(bold_obs, fs, cfg, output_type=\"dataframe\")\n",
    "x_obs = torch.tensor(x_obs_df[FEATURE_NAMES].values, dtype=torch.float32)\n",
    "samples = obj_inf.sample_posterior(x_obs, 10000, posterior)\n",
    "\n",
    "limits = [[i, j] for i, j in zip(prior_min, prior_max)]\n",
    "fig, ax = pairplot(\n",
    "    samples,\n",
    "    points=[par['G']],\n",
    "    figsize=(5, 5),\n",
    "    limits=limits,\n",
    "    labels=[\"G\"],\n",
    "    diag=\"kde\",\n",
    "    fig_kwargs=dict(\n",
    "        points_offdiag=dict(marker=\"*\", markersize=10),\n",
    "        points_colors=[\"g\"],\n",
    "    ),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}