{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Damped Oscillator - numba](https://github.com/Ziaeemehr/vbi_paper/blob/main/docs/examples/do_nb.ipynb)\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Ziaeemehr/vbi_paper/blob/main/docs/examples/do_nb.ipynb)\n",
    "\n",
    "Infer the parameters `a` and `b` of a damped oscillator (Numba implementation `DO_nb`) with simulation-based inference. The workflow has four steps:\n",
    "\n",
    "1. Simulate the model for parameters drawn from a uniform prior.\n",
    "2. Summarise each trajectory by the standard deviation and mean of its state variables.\n",
    "3. Train an SNPE posterior on these features.\n",
    "4. Sample the posterior for an observation generated at known parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from timeit import timeit\n",
    "import sbi.utils as utils\n",
    "import matplotlib.pyplot as plt\n",
    "from multiprocessing import Pool\n",
    "from sbi.analysis import pairplot\n",
    "from vbi.inference import Inference\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from vbi.models.numba.damp_oscillator import DO_nb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vbi import report_cfg\n",
    "from vbi import extract_features\n",
    "from vbi import get_features_by_domain, get_features_by_given_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 2\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulator\n",
    "\n",
    "These are the simulation settings shared by every run. `a` and `b` are the parameters to infer; each run passes its own values through `control`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    \"a\": 0.1,\n",
    "    \"b\": 0.05,\n",
    "    \"dt\": 0.05,\n",
    "    \"t_start\": 0,\n",
    "    \"method\": \"heun\",\n",
    "    \"t_end\": 2001.0,\n",
    "    \"t_cut\": 500,\n",
    "    \"output\": \"output\",\n",
    "    \"initial_state\": [0.5, 1.0],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "PLOT_EXAMPLE = False  # set to True to plot one example trajectory\n",
    "\n",
    "if PLOT_EXAMPLE:\n",
    "    ode = DO_nb(params)\n",
    "    control = {\"a\": 0.11, \"b\": 0.06}\n",
    "    t, x = ode.run(par=control)\n",
    "    plt.figure(figsize=(4, 3))\n",
    "    plt.plot(t, x[:, 0], label=\"$\\\\theta$\")\n",
    "    plt.plot(t, x[:, 1], label=\"$\\\\omega$\")\n",
    "    plt.xlabel(\"t\")\n",
    "    plt.ylabel(\"x\")\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def func(par):\n",
    "    \"\"\"Run one simulation with a = par[0], b = par[1] and return the trajectory `x`.\"\"\"\n",
    "    ode = DO_nb(params)\n",
    "    control = {\"a\": par[0], \"b\": par[1]}\n",
    "    t, x = ode.run(par=control)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Warm-up: the first call may include Numba JIT compilation. Call `func` once before timing it so that the average reflects only the simulation cost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average time for one run: 0.00765 s\n"
     ]
    }
   ],
   "source": [
    "func([0.1, 0.05])  # warm-up call, excluded from the timing below\n",
    "# timing\n",
    "number = 1000\n",
    "elapsed = timeit(lambda: func([0.1, 0.05]), number=number)\n",
    "print(f\"average time for one run: {elapsed / number:.5f} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary statistics\n",
    "\n",
    "Each simulation is reduced to the standard deviation and the mean of its state variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected features:\n",
      "------------------\n",
      "■ Domain: statistical\n",
      " ▢ Function: calc_std\n",
      " ▫ description: Computes the standard deviation of the signal.\n",
      " ▫ function : vbi.feature_extraction.features.calc_std\n",
      " ▫ parameters : {'indices': None, 'verbose': False}\n",
      " ▫ tag : all\n",
      " ▫ use : yes\n",
      " ▢ Function: calc_mean\n",
      " ▫ description: Computes the mean of the signal.\n",
      " ▫ function : vbi.feature_extraction.features.calc_mean\n",
      " ▫ parameters : {'indices': None, 'verbose': False}\n",
      " ▫ tag : all\n",
      " ▫ use : yes\n"
     ]
    }
   ],
   "source": [
    "cfg = get_features_by_domain(domain=\"statistical\")\n",
    "cfg = get_features_by_given_names(cfg, names=[\"calc_std\", \"calc_mean\"])\n",
    "report_cfg(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrapper(params, control, cfg, verbose=False):\n",
    "    \"\"\"Simulate the model for one parameter set and return its feature vector.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    params : dict\n",
    "        Simulation settings passed to `DO_nb`.\n",
    "    control : dict\n",
    "        Values of the inferred parameters, e.g. {\"a\": 0.1, \"b\": 0.05}.\n",
    "    cfg\n",
    "        Feature configuration built with `get_features_by_domain` /\n",
    "        `get_features_by_given_names`.\n",
    "    verbose : bool\n",
    "        Forwarded to `extract_features`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The first (and only) row of the extracted feature values.\n",
    "    \"\"\"\n",
    "    ode = DO_nb(params)\n",
    "    t, x = ode.run(par=control)\n",
    "\n",
    "    # extract features; x.T puts one state variable per row\n",
    "    fs = 1.0 / params[\"dt\"] * 1000  # [Hz]; assumes dt is given in ms\n",
    "    stat_vec = extract_features(\n",
    "        ts=[x.T], cfg=cfg, fs=fs, n_workers=1, verbose=verbose\n",
    "    ).values\n",
    "    return stat_vec[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_run(par, control_list, cfg, n_workers=1):\n",
    "    \"\"\"Compute the feature vector of every parameter set in `control_list` in parallel.\n",
    "\n",
    "    `par` holds the simulation settings shared by all runs; each element of\n",
    "    `control_list` is a parameter dict passed to `wrapper`. The returned list\n",
    "    is in the same order as `control_list`.\n",
    "\n",
    "    NOTE: `wrapper` is defined in this notebook, so worker processes can only\n",
    "    find it when they are forked from the kernel (default on Linux).\n",
    "    \"\"\"\n",
    "    with Pool(processes=n_workers) as pool:\n",
    "        with tqdm(total=len(control_list)) as pbar:\n",
    "            # the callback runs in the parent process each time a task finishes\n",
    "            async_results = [\n",
    "                pool.apply_async(\n",
    "                    wrapper, args=(par, control, cfg), callback=lambda _: pbar.update()\n",
    "                )\n",
    "                for control in control_list\n",
    "            ]\n",
    "            # collecting with .get() preserves submission order\n",
    "            stat_vec = [res.get() for res in async_results]\n",
    "    return stat_vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4.4408921e-16 2.2204460e-16 1.0530499e+00 8.8416451e-01]\n"
     ]
    }
   ],
   "source": [
    "control = {\"a\": 0.11, \"b\": 0.06}\n",
    "x_ = wrapper(params, control, cfg)\n",
    "print(x_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prior and training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_sim = 2000\n",
    "num_workers = 10\n",
    "a_min, a_max = 0.0, 1.0\n",
    "b_min, b_max = 0.0, 1.0\n",
    "prior_min = [a_min, b_min]\n",
    "prior_max = [a_max, b_max]\n",
    "theta_true = {\"a\": 0.1, \"b\": 0.05}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior = utils.torchutils.BoxUniform(\n",
    "    low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj = Inference()\n",
    "theta = obj.sample_prior(prior, num_sim)\n",
    "theta_np = theta.numpy().astype(float)\n",
    "control_list = [{\"a\": theta_np[i, 0], \"b\": theta_np[i, 1]} for i in range(num_sim)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [00:02<00:00, 859.29it/s]\n"
     ]
    }
   ],
   "source": [
    "stat_vec = batch_run(params, control_list, cfg, n_workers=num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "stat_vec_st = scaler.fit_transform(np.array(stat_vec))\n",
    "stat_vec_st = torch.tensor(stat_vec_st, dtype=torch.float32)\n",
    "os.makedirs(\"output\", exist_ok=True)  # the files below are written to output/\n",
    "torch.save(theta, \"output/theta.pt\")\n",
    "torch.save(stat_vec_st, \"output/stat_vec_st.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([2000, 2]), torch.Size([2000, 4]))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "theta.shape, stat_vec_st.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Neural network successfully converged after 280 epochs.train Done in 0 hours 0 minutes 37.246640 seconds\n"
     ]
    }
   ],
   "source": [
    "posterior = obj.train(\n",
    "    theta, stat_vec_st, prior, num_threads=8, method=\"SNPE\", density_estimator=\"maf\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"output/posterior.pkl\", \"wb\") as f:\n",
    "    pickle.dump(posterior, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: reload a previously saved posterior instead of retraining.\n",
    "# Only unpickle trusted files - pickle.load can execute arbitrary code.\n",
    "# with open(\"output/posterior.pkl\", \"rb\") as f:\n",
    "#     posterior = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Posterior for the observation at `theta_true`\n",
    "\n",
    "The observation is simulated at `theta_true` and standardised with the scaler fitted on the training features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "xo = wrapper(params, theta_true, cfg)\n",
    "xo_st = scaler.transform(xo.reshape(1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11e262d75d0740b0b09d41bf3a29e14b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Drawing 10000 posterior samples: 0%| | 0/10000 [00:00"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_samples = 10000\n",
    "xo_tensor = torch.as_tensor(xo_st, dtype=torch.float32)\n",
    "samples = posterior.sample((num_samples,), x=xo_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "limits = [[i, j] for i, j in zip(prior_min, prior_max)]\n",
    "fig, ax = pairplot(\n",
    "    samples,\n",
    "    points=[list(theta_true.values())],\n",
    "    figsize=(5, 5),\n",
    "    limits=limits,\n",
    "    labels=[\"a\", \"b\"],\n",
    "    upper=\"kde\",\n",
    "    diag=\"kde\",\n",
    "    fig_kwargs=dict(\n",
    "        points_offdiag=dict(marker=\"*\", markersize=10),\n",
    "        points_colors=[\"g\"],\n",
    "    ),\n",
    ")\n",
    "ax[0, 0].tick_params(labelsize=14)\n",
    "ax[0, 0].margins(y=0)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vbidevelop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}