{ "cells": [ { "cell_type": "markdown", "id": "ba11a91a", "metadata": {}, "source": [ "# Reduced Wong-Wang model - PyTorch\n", "Equivalent to the Numba notebook, adapted for PyTorch implementation" ] }, { "cell_type": "code", "execution_count": 1, "id": "52d5e97c", "metadata": {}, "outputs": [], "source": [ "import multiprocessing as mp\n", "import os\n", "import warnings\n", "from copy import deepcopy\n", "from multiprocessing import Pool\n", "\n", "import matplotlib.pyplot as plt\n", "import networkx as nx\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import vbi\n", "from cmaes import CMA\n", "from tqdm import tqdm\n", "from vbi.feature_extraction.features_utils import get_fcd\n", "from vbi.models.pytorch.rww import RWW_sde\n", "from vbi.models.pytorch.rww_sde_kong import WW_SDE_KONG\n", "from vbi.utils import timer" ] }, { "cell_type": "code", "execution_count": 2, "id": "bf0847ae", "metadata": {}, "outputs": [], "source": [ "seed = 42\n", "warnings.simplefilter(\"ignore\")\n", "np.random.seed(seed)" ] }, { "cell_type": "code", "execution_count": 3, "id": "93f96a64", "metadata": {}, "outputs": [], "source": [ "LABESSIZE = 10\n", "plt.rcParams[\"axes.labelsize\"] = LABESSIZE\n", "plt.rcParams[\"xtick.labelsize\"] = LABESSIZE\n", "plt.rcParams[\"ytick.labelsize\"] = LABESSIZE" ] }, { "cell_type": "code", "execution_count": 4, "id": "622b935c", "metadata": {}, "outputs": [], "source": [ "path = \"output/r_ww_sde_pytorch_\"\n", "os.makedirs(path, exist_ok=True)" ] }, { "cell_type": "markdown", "id": "8a86bdff", "metadata": {}, "source": [ "- Check compatibility with `rww_sde_kong`:" ] }, { "cell_type": "code", "execution_count": 5, "id": "5270ecea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "G: 6.270336254908696\n", "J: 0.2609\n", "w: (68,), type: float32\n", "s: (68,), type: float32\n", "I0: (68,), type: float32\n", "a: 270.0\n", "b: 108.0\n", "d: 0.154\n", "tau_s: 0.1\n", "gamma_s: 0.641\n", "t_end: 300.0\n", "t_cut: 
120.0\n",
      "tr: 0.72\n",
      "dt: 0.01\n",
      "n_sim: 1\n",
      "weights: (68, 68), type: float32\n",
      "device: cpu\n",
      "dtype: torch.float64\n",
      "{'a', 'weights', 'J', 'dt', 'G', 'n_sim', 't_end', 'tau_s', 'device', 't_cut', 'b', 'tr', 'd', 'w'}\n",
      "{'integration_method', 'warmup_steps', 'I0', 'sigma', 'seed', 'gamma', 'I_ext', 'gamma_s', 'dtype', 'RECORD_BOLD', 's'}\n"
     ]
    }
   ],
   "source": [
    "obj = WW_SDE_KONG()\n",
    "params = obj.get_default_params()\n",
    "\n",
    "# Convert every sized, non-string default to a float32 NumPy array and\n",
    "# report its shape/dtype; every other value is printed unchanged.\n",
    "for key, value in params.items():\n",
    "    if hasattr(value, \"__len__\") and not isinstance(value, str):\n",
    "        params[key] = np.array(value).astype(np.float32)\n",
    "        print(f\"{key}: {params[key].shape}, type: {params[key].dtype}\")\n",
    "    else:\n",
    "        print(f\"{key}: {value}\")\n",
    "\n",
    "# --- update default parameters ---\n",
    "par = RWW_sde.get_default_parameters()\n",
    "common_keys = set(par.keys()).intersection(set(params.keys()))\n",
    "different_keys = set(par.keys()).symmetric_difference(set(params.keys()))\n",
    "\n",
    "print(common_keys)\n",
    "print(different_keys)\n",
    "\n",
    "# Keys shared by both implementations are copied over verbatim ...\n",
    "for key in common_keys:\n",
    "    par[key] = params[key]\n",
    "# ... while these three are spelled differently on the two sides.\n",
    "par['sigma'] = params['s']\n",
    "par['I_ext'] = params['I0']\n",
    "par['gamma'] = params['gamma_s']\n",
    "par['integration_method'] = 'euler'\n",
    "# Prefer the GPU, but fall back to the CPU so that the notebook still\n",
    "# runs end-to-end on machines without CUDA.\n",
    "par['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "par['RECORD_BOLD'] = True\n",
    "# NOTE(review): `par` also exposes a 'seed' key (see the printed key\n",
    "# sets); presumably it seeds the simulation noise -- confirm and set it\n",
    "# if reproducible runs are needed.\n",
    "par['warmup_steps'] = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "415d3081",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrapper(par, **kwargs):\n",
    "    \"\"\"Run one RWW_sde simulation on a private copy of `par`.\n",
    "\n",
    "    Keyword options `verbose` (default False), `record_neural`\n",
    "    (default False) and `neural_subsample` (default 5) are forwarded\n",
    "    to `RWW_sde.run`.\n",
    "\n",
    "    Returns `(bold_t, bold_d, neural_t, neural_d)` as NumPy arrays; the\n",
    "    two neural entries are None unless `record_neural` is true.\n",
    "    \"\"\"\n",
    "    record_neural = kwargs.get(\"record_neural\", False)\n",
    "    # Deep-copy so the model never mutates the caller's dictionary.\n",
    "    sde = RWW_sde(deepcopy(par))\n",
    "    data = sde.run(\n",
    "        verbose=kwargs.get(\"verbose\", False),\n",
    "        record_neural=record_neural,\n",
    "        neural_subsample=kwargs.get(\"neural_subsample\", 5),\n",
    "    )\n",
    "    # Move the results off the compute device before converting to NumPy.\n",
    "    bold_d = data[\"bold_d\"].cpu().numpy()\n",
    "    bold_t = data[\"bold_t\"].cpu().numpy()\n",
    "\n",
    "    if record_neural:\n",
    "        neural_d = data[\"S\"].cpu().numpy()\n",
    "        neural_t = data[\"t\"].cpu().numpy()\n",
    "    else:\n",
    "        neural_d = neural_t = None\n",
    "\n",
    "    return bold_t, bold_d, neural_t, neural_d\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6c27af3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========================================================================================\n",
      "Reduced Wong-Wang Neural Mass Model (PyTorch)\n",
      "==========================================================================================\n",
      "Network: 68 nodes, 1 parameter sets\n",
      "Device: cuda\n",
      "Integration method: euler\n",
      "\n",
      "Neural Model Parameters:\n",
      "------------------------------------------------------------------------------------------\n",
      "Parameter | Definition | Value/Shape \n",
      "------------------------------------------------------------------------------------------\n",
      "w | Recurrent strength | shape [68, 1] \n",
      "I_ext | External input | shape [68, 1] \n",
      "G | Global coupling | 6.27034 \n",
      "sigma | Noise amplitude | shape [68, 1] \n",
      "J | Synaptic coupling | 0.2609 \n",
      "a | Firing rate slope | 270 \n",
      "b | Firing rate threshold | 108 \n",
      "d | Firing rate scaling | 0.154 \n",
      "tau_s | Synaptic time constant | 0.1 \n",
      "gamma | Synaptic gating | 0.641 \n",
      "------------------------------------------------------------------------------------------\n",
      "weights | Structural connectivity | shape [68, 68] \n",
      "\n",
      "Simulation Settings:\n",
      "------------------------------------------------------------------------------------------\n",
      "Parameter | Definition | Value \n",
"------------------------------------------------------------------------------------------\n", "t_cut | Warm-up time (dropped) | 120.0 s\n", "t_end | Total simulation time | 300.0 s\n", "t_record | Recording time (after drop) | 180.0 s\n", "dt | Integration time step | 0.01 s\n", "tr | BOLD sampling time (TR) | 0.72 s\n", "warmup_steps | Number of warm-up steps | 1000\n", "RECORD_BOLD | Record BOLD signals | True\n", "\n", "Hemodynamic Parameters (Balloon-Windkessel):\n", "------------------------------------------------------------------------------------------\n", "Parameter | Definition | Value \n", "------------------------------------------------------------------------------------------\n", "beta | Rate constant (flow/volume) | 0.65\n", "gamma | Rate constant (elimination) | 0.41\n", "tau | Hemodynamic transit time | 0.98\n", "alpha | Grubb exponent | 0.33\n", "p_constant | Resting oxygen extraction | 0.34\n", "v_0 | Resting blood volume fraction | 0.02\n", "k_1 | Signal coefficient | 4.10342\n", "k_2 | Signal coefficient | 0.581832\n", "k_3 | Signal coefficient | 0.53\n", "==========================================================================================\n" ] } ], "source": [ "ww = RWW_sde(par)\n", "print(ww)" ] }, { "cell_type": "code", "execution_count": 8, "id": "f2429529", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warm-up ...: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(bold_d.shape, bold_t.shape, neural_d.shape, neural_t.shape, np.isnan(bold_d).any()) # (nnodes, nsim, nsamples)\n", "fig, ax = plt.subplots(2, figsize=(12, 5))\n", "ax[0].plot(bold_t, bold_d[:,0,:].T, lw=0.5, alpha=1.0);\n", "ax[0].margins(x=0)\n", "ax[1].plot(neural_t, neural_d[:5,0,:].T, lw=0.3, alpha=1.0);\n", "ax[1].margins(x=0)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "id": "3d01e7b8", "metadata": {}, "outputs": [], "source": [ "def compute_cost(\n", " 
bold_d,\n",
    "    fc_target=None,\n",
    "    fcd_target=None,\n",
    "    add_cost_fc=True,\n",
    "    add_cost_fcd=True,\n",
    "    zscore_fc=False,\n",
    "    zscore_fcd=False,\n",
    "):\n",
    "    \"\"\"Compute the cost between simulated and target BOLD signals.\n",
    "\n",
    "    NOTE: this is still a placeholder -- no cost term is implemented, so\n",
    "    every argument is currently ignored and the result is always 0.0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    bold_d : np.ndarray\n",
    "        Simulated BOLD data as returned by `wrapper`. The plotting cell\n",
    "        earlier in this notebook treats it as (n_nodes, n_sim, n_samples).\n",
    "    fc_target, fcd_target : optional\n",
    "        Targets for the FC and FCD terms (presumably empirical FC/FCD\n",
    "        matrices -- TODO confirm once the cost is implemented).\n",
    "    add_cost_fc, add_cost_fcd : bool\n",
    "        Intended switches for including the FC and FCD terms.\n",
    "    zscore_fc, zscore_fcd : bool\n",
    "        Intended switches for z-scoring the respective terms.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        The cost value (currently the constant 0.0).\n",
    "    \"\"\"\n",
    "    # TODO: compute the FC/FCD costs here.\n",
    "    cost = 0.0\n",
    "    return cost\n",
    "\n",
    "\n",
    "def wrapper(par, **kwargs):\n",
    "    \"\"\"Run the RWW_sde model and return its signals or a scalar cost.\n",
    "\n",
    "    This redefines the simpler `wrapper` from earlier in the notebook;\n",
    "    with the default `get_cost=False` it returns the same 4-tuple\n",
    "    `(bold_t, bold_d, neural_t, neural_d)`.\n",
    "\n",
    "    Keyword options\n",
    "    ---------------\n",
    "    verbose, record_neural, neural_subsample :\n",
    "        Forwarded to `RWW_sde.run` (defaults: False, False, 5).\n",
    "    get_cost : bool, default False\n",
    "        If true, return `compute_cost(bold_d, ...)` instead of the signals.\n",
    "    fc_target, fcd_target : optional, default None\n",
    "        Targets handed on to `compute_cost`.\n",
    "    add_cost_fc, add_cost_fcd, zscore_fc, zscore_fcd : bool\n",
    "        Handed on to `compute_cost` (defaults: True, True, False, False).\n",
    "    \"\"\"\n",
    "    # Deep-copy so the model never mutates the caller's dictionary.\n",
    "    par = deepcopy(par)\n",
    "    sde = RWW_sde(par)\n",
    "    record_neural = kwargs.get(\"record_neural\", False)\n",
    "    data = sde.run(\n",
    "        verbose=kwargs.get(\"verbose\", False),\n",
    "        record_neural=record_neural,\n",
    "        neural_subsample=kwargs.get(\"neural_subsample\", 5),\n",
    "    )\n",
    "    bold_d = data[\"bold_d\"].cpu().numpy()  # Convert to numpy\n",
    "    bold_t = data[\"bold_t\"].cpu().numpy()\n",
    "\n",
    "    neural_d = None\n",
    "    neural_t = None\n",
    "    if record_neural:\n",
    "        neural_d = data[\"S\"].cpu().numpy()\n",
    "        neural_t = data[\"t\"].cpu().numpy()\n",
    "\n",
    "    if not kwargs.get(\"get_cost\", False):\n",
    "        return bold_t, bold_d, neural_t, neural_d\n",
    "\n",
    "    # The targets used to be hard-coded to None here, so a caller could\n",
    "    # never supply them; they are now read from kwargs (default None).\n",
    "    return compute_cost(\n",
    "        bold_d,\n",
    "        fc_target=kwargs.get(\"fc_target\"),\n",
    "        fcd_target=kwargs.get(\"fcd_target\"),\n",
    "        add_cost_fc=kwargs.get(\"add_cost_fc\", True),\n",
    "        add_cost_fcd=kwargs.get(\"add_cost_fcd\", True),\n",
    "        zscore_fc=kwargs.get(\"zscore_fc\", False),\n",
    "        zscore_fcd=kwargs.get(\"zscore_fcd\", False),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "97e4b1a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Gs = np.arange(0.1, 2.3, 0.05)\n",
    "# Gs = np.array([0.1])\n",
    "# par = dict(\n",
    "# weights=weights,\n",
    "# dt=2.5, # s\n",
    "# 
t_end=0.5 * 60 * 1000.0, # ms\n", "# t_cut=0.1 * 60 * 1000.0, # ms\n", "# G=Gs,\n", "# sigma=0.05, #0.05\n", "# I0=0.05, #0.05\n", "# tr=300.0, # s\n", "# device=\"cpu\",\n", "# n_sim=len(Gs),\n", "# integration_method=\"heun\",\n", "# )\n", "# ww = RWW_sde(par)\n", "# bold_t, bold_d, neural_t, neural_d = wrapper(\n", "# par, \n", "# verbose=True, \n", "# record_neural=True,\n", "# neural_subsample=2\n", "# )" ] }, { "cell_type": "code", "execution_count": null, "id": "23e2d39f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }