{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# [Damped Oscillator - numba](https://github.com/Ziaeemehr/vbi_paper/blob/main/docs/examples/do_nb.ipynb)\n",
"\n",
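"Simulation-based inference of the parameters `a` and `b` of the damped-oscillator model (`DO_nb`), using summary statistics of the simulated time series and SNPE."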
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"from multiprocessing import Pool\n",
"from timeit import timeit\n",
"\n",
"import torch\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"import sbi.utils as utils\n",
"import matplotlib.pyplot as plt\n",
"from sbi.analysis import pairplot\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from vbi.inference import Inference\n",
"from vbi.models.numba.damp_oscillator import DO_nb"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from vbi import report_cfg\n",
"from vbi import extract_features\n",
"from vbi import get_features_by_domain, get_features_by_given_names"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# seed the global NumPy and PyTorch random number generators\n",
"seed = 2\n",
"torch.manual_seed(seed)\n",
"np.random.seed(seed);"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# settings handed to DO_nb(params); values for `a` and `b` are also supplied\n",
"# per run through ode.run(par=...) in the cells below\n",
"params = dict(\n",
"    a=0.1,\n",
"    b=0.05,\n",
"    dt=0.05,\n",
"    t_start=0,\n",
"    method=\"heun\",\n",
"    t_end=2001.0,\n",
"    t_cut=500,\n",
"    output=\"output\",\n",
"    initial_state=[0.5, 1.0],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: simulate once and plot the first two state columns.\n",
"# Disabled by default; switch the guard to `if 1:` to run it.\n",
"if 0:\n",
"    ode = DO_nb(params)\n",
"    control = {\"a\": 0.11, \"b\": 0.06}\n",
"    t, x = ode.run(par=control)\n",
"    plt.figure(figsize=(4, 3))\n",
"    # raw strings so the LaTeX backslashes are not parsed as string escapes\n",
"    plt.plot(t, x[:, 0], label=r\"$\\theta$\")\n",
"    plt.plot(t, x[:, 1], label=r\"$\\omega$\")\n",
"    plt.xlabel(\"t\")\n",
"    plt.ylabel(\"x\")\n",
"    plt.legend()\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def func(par):\n",
"    \"\"\"Simulate the model once for the parameter pair ``par``.\n",
"\n",
"    ``par[0]`` is passed as ``a`` and ``par[1]`` as ``b``; all other settings\n",
"    come from the notebook-level ``params`` dict. Only the second value\n",
"    returned by ``DO_nb.run`` is returned; the first one is discarded.\n",
"    \"\"\"\n",
"    a_val, b_val = par[0], par[1]\n",
"    _, states = DO_nb(params).run(par={\"a\": a_val, \"b\": b_val})\n",
"    return states"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Warm up: call the simulator once before timing it."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average time for one run: 0.00765 s\n"
]
}
],
"source": [
"# one untimed call first, so the timed runs below start after the warm-up\n",
"func([0.1, 0.05])\n",
"\n",
"# timing\n",
"number = 1000\n",
"t = timeit(\n",
"    stmt=lambda: func([0.1, 0.05]),\n",
"    number=number,\n",
")\n",
"print(f\"average time for one run: {t / number:.5f} s\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected features:\n",
"------------------\n",
"■ Domain: statistical\n",
" ▢ Function: calc_std\n",
" ▫ description: Computes the standard deviation of the signal.\n",
" ▫ function : vbi.feature_extraction.features.calc_std\n",
" ▫ parameters : {'indices': None, 'verbose': False}\n",
" ▫ tag : all\n",
" ▫ use : yes\n",
" ▢ Function: calc_mean\n",
" ▫ description: Computes the mean of the signal.\n",
" ▫ function : vbi.feature_extraction.features.calc_mean\n",
" ▫ parameters : {'indices': None, 'verbose': False}\n",
" ▫ tag : all\n",
" ▫ use : yes\n"
]
}
],
"source": [
"# start from the \"statistical\" feature domain and keep only the two named features\n",
"cfg = get_features_by_given_names(\n",
"    get_features_by_domain(domain=\"statistical\"),\n",
"    names=[\"calc_std\", \"calc_mean\"],\n",
")\n",
"report_cfg(cfg)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def wrapper(params, control, cfg, verbose=False):\n",
"    \"\"\"Run one simulation and return its feature vector.\n",
"\n",
"    Builds ``DO_nb(params)``, runs it with ``par=control`` and passes the\n",
"    transposed state array to ``extract_features`` using the feature\n",
"    configuration ``cfg``. The first row of the resulting ``.values`` array\n",
"    is returned.\n",
"    \"\"\"\n",
"    _, x = DO_nb(params).run(par=control)\n",
"\n",
"    # sampling rate for the feature extractor; the factor 1000 presumably\n",
"    # converts from a dt given in ms -- TODO confirm\n",
"    fs = 1.0 / params[\"dt\"] * 1000  # [Hz]\n",
"    features = extract_features(\n",
"        ts=[x.T], cfg=cfg, fs=fs, n_workers=1, verbose=verbose\n",
"    )\n",
"    return features.values[0]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def batch_run(par, control_list, cfg, n_workers=1):\n",
"    \"\"\"Run ``wrapper`` for every entry of ``control_list`` in a process pool.\n",
"\n",
"    One task is submitted per control dict via ``Pool.apply_async``; each\n",
"    completed task advances a tqdm progress bar. Results are collected in\n",
"    submission order, so entry ``i`` of the returned list belongs to\n",
"    ``control_list[i]``.\n",
"    \"\"\"\n",
"    with Pool(processes=n_workers) as pool, tqdm(total=len(control_list)) as pbar:\n",
"        pending = [\n",
"            pool.apply_async(\n",
"                wrapper,\n",
"                args=(par, control, cfg),\n",
"                callback=lambda _: pbar.update(),\n",
"            )\n",
"            for control in control_list\n",
"        ]\n",
"        # wait for all tasks while the pool and the progress bar are still open\n",
"        return [job.get() for job in pending]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4.4408921e-16 2.2204460e-16 1.0530499e+00 8.8416451e-01]\n"
]
}
],
"source": [
"# single run through the simulation + feature pipeline before the batch run\n",
"control = dict(a=0.11, b=0.06)\n",
"x_ = wrapper(params, control, cfg)\n",
"print(x_)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# simulation budget and number of worker processes\n",
"num_sim = 2000\n",
"num_workers = 10\n",
"\n",
"# prior bounds, ordered as [a, b]\n",
"a_min, a_max = 0.0, 1.0\n",
"b_min, b_max = 0.0, 1.0\n",
"prior_min, prior_max = [a_min, b_min], [a_max, b_max]\n",
"\n",
"# parameter values used to generate the observation\n",
"theta_true = dict(a=0.1, b=0.05)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# box-uniform prior between prior_min and prior_max\n",
"low, high = torch.as_tensor(prior_min), torch.as_tensor(prior_max)\n",
"prior = utils.torchutils.BoxUniform(low=low, high=high)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"obj = Inference()\n",
"theta = obj.sample_prior(prior, num_sim)\n",
"\n",
"# one control dict per sampled row: column 0 -> a, column 1 -> b\n",
"theta_np = theta.numpy().astype(float)\n",
"control_list = [dict(a=a_i, b=b_i) for a_i, b_i in theta_np[:num_sim]]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2000/2000 [00:02<00:00, 859.29it/s]\n"
]
}
],
"source": [
"# simulate every control dict in parallel; yields one feature vector per entry\n",
"stat_vec = batch_run(\n",
"    par=params,\n",
"    control_list=control_list,\n",
"    cfg=cfg,\n",
"    n_workers=num_workers,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# standardize the features; the fitted scaler is reused later for the observation\n",
"scaler = StandardScaler()\n",
"stat_vec_st = scaler.fit_transform(np.array(stat_vec))\n",
"stat_vec_st = torch.tensor(stat_vec_st, dtype=torch.float32)\n",
"\n",
"# torch.save does not create missing directories, so make sure the target exists\n",
"os.makedirs(\"output\", exist_ok=True)\n",
"torch.save(theta, \"output/theta.pt\")\n",
"torch.save(stat_vec_st, \"output/stat_vec_st.pt\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([2000, 2]), torch.Size([2000, 4]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# parameters and features should have the same number of rows\n",
"(theta.shape, stat_vec_st.shape)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Neural network successfully converged after 280 epochs.train Done in 0 hours 0 minutes 37.246640 seconds\n"
]
}
],
"source": [
"# train the density estimator on (parameters, standardized features) pairs\n",
"posterior = obj.train(\n",
"    theta,\n",
"    stat_vec_st,\n",
"    prior,\n",
"    num_threads=8,\n",
"    method=\"SNPE\",\n",
"    density_estimator=\"maf\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# store the trained posterior so it can be reloaded without retraining\n",
"with open(\"output/posterior.pkl\", \"wb\") as posterior_file:\n",
"    pickle.dump(posterior, posterior_file)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
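"# Optional: reload a previously saved posterior instead of retraining.\n",
"# Only unpickle files you created yourself; pickle can execute arbitrary code.\n",
"# with open(\"output/posterior.pkl\", \"rb\") as f:\n",
"#     posterior = pickle.load(f)"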
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# observation simulated at theta_true, scaled with the scaler fitted on the training features\n",
"xo = wrapper(params=params, control=theta_true, cfg=cfg)\n",
"xo_row = xo.reshape(1, -1)\n",
"xo_st = scaler.transform(xo_row)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11e262d75d0740b0b09d41bf3a29e14b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Drawing 10000 posterior samples: 0%| | 0/10000 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# draw posterior samples conditioned on the scaled observation and store them\n",
"samples = obj.sample_posterior(xo_st, 10_000, posterior)\n",
"torch.save(samples, \"output/samples.pt\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbgAAAHcCAYAAACkr7//AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHc9JREFUeJzt3XuMneV9J/DvmTu+jG2wATsYCA25IFLIAgZy2ZKNW7ppaZtdVTSpBGmbm7TNJnX+oagJq1YKXaWJ0ibsKoSNoBu1Tau0Sau2JBVpykbhEkJICWkJIcHmZhsTfAXP7Zz9YzzHczwe2zPzet7zvvP5SCPwmTOen4/OzFfP5fc8jVar1QoA1ExP2QUAwMkg4ACoJQEHQC0JOABqScABUEsCDoBaEnAA1JKAA6CW+sr85gdGD2TFzSuSJPt/d3+WDywvsxxYcn6251fLLgGO65+afzWvrzOCA6CWBBwAtSTgAKglAQdALQk4AGpJwAFQS10TcPf+aFfZJQBQI10TcPf9+CdllwBAjXRNwI2NN8suAYAa6ZqAGxFwABSoawJutCngAChO1wTc2Hir7BIAqJGuCThTlAAUqWsCbnR8ouwSAKiRLgo4IzgAitM1ATc2IeAAKE7XBNyogAOgQF0TcDaZAFCkrgk4a3AAFKlrAs5RXQAUqWsCzhQlAEXqmoCzyQSAInVPwBnBAVCg7gk4IzgACtQ9AWcEB0CBSg24iebhGwSarWTcKA6AgpQacEcez2WaEoCilBpwRwbayJiAA6AYRnAA1FKpATc+0XmLt40mABSlq0ZwTjMBoChdFnBu9QagGCUHnClKAE6OrhrBCTgAitJdAWcXJQAFKTfgxo3gADg5yg24ZucanF2UABSlu6YoBRwABdHoDUAtdVWbwIhNJgAUxBQlALXUXbcJOMkEgIIYwQFQSyVvMhFwAJwcXbXJRMABUJSuWoNzVBcARemqPriRMQEHQDG6aw3OCA6AgthFCUAtddUanMOWAShKd43gTFECUJCuahMYGXOSCQDF6KpdlEZwABSlq9bgbDIBoCjdNYITcAAUxCYTAGqpqwLOSSYAFKWrAs4IDoCilBtwTWtwAJwc3TWCE3AAFKTcNoHxmVOUrVZrlmcDwIkrNeAmmjPDzHmUABShq6YoExtNAChGuQF3lNGadTgAilDyUV2HpygHeidLEXAAFKHco7qmrcEN9Ak4AIrTNWtw/b2NJDaZAFCMrgm4QSM4AApUWsC1Wq2OC0/bU5QTLj0FYOFKC7jxI3rgpjaZmKIEoAilBdyRPXD9pigBKFB5ATd+xAiuzwgOgOKUFnBHnlgyqA8OgAJ10RRlbxIBB0Axuibg2ieZOIsSgAJ0TcBN9cGNjGkTAGDhyluDm2WTiREcAEUosQ/uiCnKvsmjuqzBAVCErpmidJsAAEXqninKqZNMTFECUIDuGcG1N5kIOAAWrmsCrt8mEwAK1DUBN6jRG4AClXhU1yxtAgIOgAKU1yZgFyUAJ1HXTFEevg/OSSYALFz3TVHaZAJAAUq8D+7oZ1GaogSgCF0zRdnvqC4ACtQ1AedGbwCK1D1rcD2mKAEojhEcALXUPX1wUyeZ2EUJQAFKHME5yQSAk6fENThtAgCcPF3TB9ff5yQTAIrTPZtMDu2ibLZmrs8BwFx1zRrcYP/hUmw0AWChumYNbuqw5cQ6HAAL1zVTlH29PemZPK1LwAGwYCX2wbVmPKbZG4CidM0UZTL9TjgBB4uup7fzAyqua6Yok2mnmQg4ABaoqwJu0KWnABSkr6xvPDY+cw3OaSZQnsbULq9DWn4MqbiuGsENOM0EgIJ01yYTIzgACtJdbQK9Ag6AYpS3Bne0TSb9NplAWVrj42WXAIXqrilKfXAAFKRLN5kIOAAWpmtuE0g0egNQnFICbqLZykTTJhMATp5SAu5o05PJtE0mAg6ABeqqgGuP4CY0egOwMKUE3NF64JLDR3WNjBnBAbAwpY
7gjjj67vBJJvrgAFigUgJuKsD6ezu/vU0mABSlpBHc5BTlwBEBZ5MJAEUpdYqyr7dzjtJJJgAUpZwpyvGpKcojAu5Qo7eAA2ChSh3BzViDs8kEgIKU0ybQPPoa3OH74PTBAbAw5Yzgxo++BjfowlMAClJqm0DfLCM4a3AALFSpbQJHrsEN6oMDoCClbjKZdQ3OJhMAFqjkXZRHtgkYwQFQjL4yvuno+NHX4AZdeApdo9E/0P7/1thoiZXA/HTVGpxNJgAUpaQ+uKNPUa46pT9JcmB0PC+Oji96XQDUR8lHdXV++1OXD2TdysG0Wsm/b99XRmkA1ERXTVEmyQXrh5Mk339m76LWBEtd75o1HR+N3p72B1RRV90mkCQXbJgMuH97VsABMH/l9sEdeaV3ktdMjeAEHAALUE6bwCy3CSSHpygf3b4vE81Weo8SgkDxJl54oewSoFAlHbZ8aA2ub+a3f/na5Rnq78mLoxPZ+vyBxS4NgJooeQ1u5rfv7WnkVWeapgRgYcrtg5tl+vGC9SuT2GgCwPyV1Ac3+xRlolUAgIUr+bDlo3/7qZ2U//asZm8A5qfr2gSS5NWHAm773oP5yQGHvAIwd123ySRJVgz25dzTliWxDgfA/JSzBneMo7qmvMY6HAALUFIf3KE1uFk2mSTTNpoYwQEwD13ZJpBM32gi4KAMX3nmodz04D3Zd8vfZfdZu8ouB+as1CnKox22PGXq0OUf7tyfkfGJRakL6PTp7+3JfTtH8uSmx8suBeasa6co168ayqpT+jPebOWxHfsXqzTgkF0vTeSLP5r82dt5wVMZPWWk5Ipgbkrug5t9BNdoNNrrcPf+6PlFqQs47E9/sDfNycmWtBqtPHvR1nILgjkq5TaBw31wx87Xn7/wzNzzo+fzJ3c9ll+6aENOHx5ajPJgyTm48qWMLj/Y8dhH/uFgmmuSNCY/nrrs8ax5Yl3HcwYODGVo3ymLVyjMQUkBd/w2gST59cvPzhcffCr/+tSe/I+/eyT/69cvWYzyYMn53n+9L7vPOWIjSSuT4ZbJ/7605kDuf+9dHU9Z/cTaXHrHVYtQIcxdSZtMjt3oPaWvtyd/+F9+Or09jfzDw9vzlUe2L0Z5sORsePDl6RnrmQy1KUeuIEz/cyvpGevJhu+8fBGqg/kp96iuvuNfZnrBhuG89z+elyT5yJe/l70Hx05qbbAUbfjXc7Lp1s1Z9vyKpHmcJzeTZc+vzKZbN2fDv56zKPXBfJTTBzfVJnCcNbgp//0t5+fla5dnx96RfPhL32sHJFCcFbuGc/mtm3PGI2cd83lnPLIxl3/mLVmxa3iRKoP5KXWK8lhtAtMN9ffmo297bRqN5MsPPZNrP3NPnt790sksEZak3rG+rNm2rnOqcrpWsmbr2vSOl7J8D3Oy6AHXarWmHbZ8/CnKKVf+1Gn537/+H7JyqC8PbtudX/iT/5e7/m3HySoTlqy9619Io3noZ3Mq6A79t9FsZO+G3WWUBXO26AE30WyldeiH5XhtAkf6+QvX5+/f/6b89FmrsvvFsfzWHQ/kvf/3gfxgh3vjoCh7zvpJWr2tNCYa6Rnvydn3nJ+e8Z40Jhpp9bay5yx9qVTDogfcgZHDx24dr03gaM4+bVn+6n1X5jff8PL0NJKvPLIjV3/y7mz5wkN5TNDBgkz0TuTA2smfo1NeWJFNt27OK796UTbdujmnvLA8SXJg7b5M9Do+j+63qBPpu18czW/d8UCSZN3KwQz1987r7xns681Hrrkgb9+0MR//6g9y5yPb89ffeTp//Z2nc8H64fzSxRvyn159eob6etNoJI1GcurygSwbsG4Ax9Lsn8iKncNZ+eyavPofLm6vtU1tQPn3//yd7Dtzd5p9E+mdmN/PLyyWRqvVmm05uVDb9xzMdZ+7Lz/YsT/DQ335P++8LBdsGMyKm1ckSfb/7v4sH1g+r7/7u0/uzqe+9sN8/dGdGW/O/s8ZHurL+lWn5PThwaxbMZ
i1KwezdsVAzlx1Sl62eigbVp+S05YPpr+3kUbjxNcHoap+tudXZzzWSiuNGU1wJ/55KNo/Nf9qXl+34CHNli88dELPu/dHz+eZPQdzxvBg7vjNTXn1mcM5MHpgod8+SXLRxtW57fpL88KB0fzj97bnyw89nUee2ZtmqzX50Zzcubn34Hj2HtyXR48zldloTE6fDvb2pKenkd6eRnoaSU9j8v8n/9zIyPhEXhqdyMGxZgb6enL68GDOWDmU01YMpO/QVUCNhl8F3eAT115cdgmVcbx3rHc0VbHggPvr7zx9ws89b+3y/OlvbcpZa5Yt9Nse1ZrlA3nH5WfnHZefPeNz+w6OZfueg3l2z8Hs2Hswzx8Yza59I3lu/0ie3XMwz+x+Kdv3HMz4oU0wo+PNjI6feL/d6EQz+58bz4+eKya0KZaAg6VnwQF341tffULPO2WgL9f89PqsXjaw0G85LyuH+rNyqD/nn7Fy1udMNFvZf3A8IxMT7YBrtiZbG5qtyYtam81kotXKRLOVof6enNLfm6H+3hwcm8iOvSPZue9gdu0fTas1GZStWRuKADiZFhxw7/mPP1VEHV2ht6eRVcv6k/TP6+vPW7ei2IIAmLdSTjIBgJNNwAFQSwIOgFoScADUkoADoJYW7SQTAFhMRnAA1JKAA6CWBBwAtSTgAKglAQdALS3oLMpWq5V9+9yiTTWsXLnSPX+whCwo4Hbt2pXTTz+9qFrgpNq5c2fWrVtXdhnAIllQwA0MTF598+STT2Z4eLiQgqBoe/fuzcaNG9vvV2BpWFDATU33DA8PCzi6nulJWFpsMgGglgQcALW0oIAbHBzMTTfdlMHBwaLqgcJ5n8LS5LBlAGrJFCUAtSTgAKglAQdALQk4AGpJwAFQS/MKuG9961t561vfmtWrV2f58uW54oor8pd/+ZdF1wbz9vnPfz7vfe97c+mll2ZwcDCNRiO333572WUBi2jOR3X98z//c66++uoMDQ3l137t17Jy5cp88YtfzLXXXpsnn3wyH/rQh05GnTAnv/d7v5etW7dm7dq1Wb9+fbZu3Vp2ScAim9MIbnx8PO9+97vT09OTu+++O7feems+/vGP57vf/W5e+cpX5sYbb/SLhK5w22235Yknnshzzz2X973vfWWXA5RgTgH3ta99LY8//nje8Y535OKLL24/vmrVqtx4440ZHR3NHXfcUXSNMGebN2/OOeecU3YZQInmFHBf//rXkyQ/93M/N+NzV199dZLkX/7lXxZeFQAs0JwC7rHHHkuSnH/++TM+d+aZZ2bFihXt5wBAmeYUcHv27EkyOSV5NMPDw+3nAECZ9MEBUEtzCripkdtso7S9e/fOOroDgMU0p4CbWns72jrb9u3bs3///qOuzwHAYptTwP3Mz/xMkuSrX/3qjM995Stf6XgOAJRpTgH3lre8Jeedd17+7M/+LA899FD78T179uSjH/1oBgYGct111xVdIwDM2Zxv9J7tqK6tW7fmj/7ojxzVRVe47bbb8o1vfCNJ8vDDD+fBBx/MG97whrziFa9IkrzxjW/Mu971rjJLBE6yOQdcktx///256aab8s1vfjNjY2N57Wtfmy1btuTaa689GTXCnL3zne885qk6119/vcOXoebmFXAA0O30wQFQSwIOgFoScADUkoADoJYEHAC1JOAAqCUBB0AtCTgAaknAAVBLAg6AWhJwANSSgOsid955Z974xjdm9erVOe200/KLv/iLefzxx8suC6CSBFwXOXDgQLZs2ZIHHnggd911V3p6evK2t70tzWaz7NIAKsdtAl1s165dWbduXR5++OFceOGFZZcDUClGcF3ksccey9vf/vacd955GR4ezrnnnpsk2bZtW7mFAVRQX9kFcNg111yTc845J5/97GezYcOGNJvNXHjhhRkdHS27NIDKEXBd4vnnn8+jjz6az372s3nTm96UJPnGN75RclUA1SXgusSaNWty2mmn5dZbb8369euzbd
u23HDDDWWXBVBZ1uC6RE9PT/7iL/4i3/72t3PhhRfmd37nd/Kxj32s7LIAKssuSgBqyQgOgFoScADUkoADoJYEHAC1JOAAqCUBB0AtCTgAaknAAVBLAg6AWhJwANSSgAOglgQcALUk4ACoJQEHQC0JOABqScABUEsCDoBaEnAA1JKAA6CW+souAOhOrVYrL469mCRZ1r8sjUaj5IpgbozggKN6cezFrLh5RVbcvKIddFAlAg6AWhJwANSSgAOglgQcALUk4ACoJW0CwAz7R8Zz3xO7yi4DFkTAATP8z3/899xx76PJKWVXAvNnihKY4ZndL5VdAiyYgANmGJ1oll0CLJiAA2YYGRdwVJ+AA2YYFXDUgIADZhBw1IGAA2awBkcdCDhghpHxibJLgAUTcMAMpiipAwEHzCDgqAMBB8wg4KgDAQfMYJMJdSDggA7NZitjE62yy4AFE3BAB6M36kLAAR0EHHUh4IAONphQFwIO6OCgZepCwAEdjOCoCwEHdBBw1IWAAzoIOOpCwAEdRicctEw9CDigg00m1IWAAzqYoqQuBBzQQcBRFwIO6OAkE+pCwAEdjOCoCwEHdLDJhLoQcEAHIzjqQsABHQQcdSHggA42mVAXAg7oYA2OuhBwQAdTlNSFgAM6CDjqQsABHRy2TF0IOKCDERx1IeCADgKOuhBwQAe7KKkLAQd0MIKjLgQc0EGjN3Uh4IAOpiipCwEHdJiaomw0Si4EFkjAAR2mAm7FQF/JlcDCCDigw9Qa3LKB3pIrgYURcECHqRHc8kEBR7UJOKBDO+CG+kuuBBZGwAEdRsYnz6JcYYqSihNwQIepEdwym0yoOAEHdJjaZLJiUMBRbQIOaGs2WxmbaCWxyYTqE3BA2/RjuozgqDoBB7RND7hlAo6KE3BA2/SbBJbbRUnFCTigbSrgBnp7MtAn4Kg2AQe0tQOuryf9vX49UG3ewUDb1BrcZMC5ToBqE3BA28jY4SnKwT6/Hqg272CgbXRi8piugb6e9BnBUXECDmibus170BocNeAdDLRN32QyIOCoOO9goM0uSurEOxhoa++i7BVwVJ93MNDWMUVpFyUV5x0MtI12bDKxi5JqE3BAW2ejt18PVJt3MNA2vdHbFCVV5x0MtHWM4HoO/3qYaLbKKgnmTcABbSPT2wSmjeDGpt0TB1Uh4IC2w5tMejsavUcFHBUk4IC2zkbvw7soBRxVJOCAtvZhy709aTQOB9y4gKOCBBzQNn0EN93YuE0mVI+AA9qmN3p3PH5oZAdVIuCAtultAh2PG8FRQQIOaGtPUR5xiok1OKpIwAFtI7OswdlFSRUJOKBttoDT6E0VCTigbXqjd8fjAo4KEnBAmzYB6kTAAW3Tb/SebmxCwFE9Ag5om3UEZ4qSChJwQNvsjd4CjuoRcEDb7I3eAo7qEXBA26yN3k0BR/UIOKDNGhx1IuCAJEmz2TJFSa0IOCBJ50aSmZtMtAlQPQIOSNIZcEeO4By2TBUJOCBJ5zTkkZtMTFFSRQIOSNK5g7LRaHR8ziYTqkjAAUlm30GZJGPaBKggAQckmb3JO3GjN9Uk4IAkszd5Jw5bppoEHJBk9stOk2TMJhMqSMABSZKR8Ykks0xR2mRCBQk4IMmxpyj1wVFFAg5IMu2qnP6jbTIRcFSPgAOSzH6bd5KMahOgggQckOQ4fXCmKKkgAQckmf027yQZ0wdHBQk4IMmxG72N4KgiAQckOfYuSptMqCIBByQ5TqO3ERwVJOCAJMfeZOLCU6pIwAFJpo3gentnfE6jN1Uk4IAkx270dtgyVSTggCTJ6MShsyiPtslkoplWS8hRLQIOSHLsNbjEKI7qEXBAkmM3eid2UlI9Ag5IcuxG70QvHNUj4IAkx270TozgqB4BByQ5dqN34tJTqkfAAUlsMqF+BByQ5Nj3wSXW4KgeAQckSUbGjjeCE3BUi4ADkhwewQ32zTyqa/rnoSoEHJDk+GtwpiipGgEHJNHoTf0IOC
DJ8Ru9BRxVI+CAJMdv9DZFSdUIOCDJCazB6YOjYgQckFardfwpSiM4KkbAAR0tAI7qoi4EHNA+hzJx2DL1IeCAjg0ks7UJ2GRC1Qg4oGMHZaPROPpzjOCoGAEHZM9LY0mSof7ZfyWMjdtFSbUIOCAPbnshSXLhy1bN+hxrcFSNgAPyrR//JEly6bmnzvocU5RUjYAD8q0nJkdwm44VcDaZUDECDpa4p3e/lKd3v5TenkZed/bqWZ9nipKqEXCwxD3wxOT05IUbhrN8sG/W5xnBUTUCDpa4+09g/S0xgqN6BBwscd86NIK77LgBp02AahFwsITtfnE0P9ixP0ly2blrjvncEVOUVIyAgyXsgUO7J89btzynrRg85nNNUVI1Ag6WsKnpyWO1B0wRcFSNgIMl7P4TXH9L7KKkegQcLGHfe3pPkmTTy43gqB8BB0vY2EQrZwwP5qw1pxz3uaN2UVIxAg6WuMvOPXXWK3KmGx2fWIRqoDizH1sALAknMj2ZJI8/dyDn3vD3J7kamOmJP/yFeX2dgIMl7Ev/7Q3ZsHrouM87bXl/XjiwCAVBgQQcLGEXb1x9Qs/72ofenGZz4OQWAwUTcMBxDfT1ZPnAsRvBodvYZAJALQk4AGpJwAFQSwIOgFoScADUUqPVajl/B5ih1WrlxbEXkyTL+ped0Gkn0E0EHAC1ZIoSgFoScADUkoADoJYEHAC1JOAAqCUBB0AtCTgAaknAAVBLAg6AWhJwANSSgAOglvrKLgAoR6vVyr59+8ouA07IypUr53zgt4CDJWrXrl05/fTTyy4DTsjOnTuzbt26OX2NgIMlamBgIEny5JNPZnh4uORq6mHv3r3ZuHGj17RAU6/p1Pt1LgQcLFFT0z3Dw8N+GRfMa1q8+dxHaJMJALUk4ACoJQEHS9Tg4GBuuummDA4Oll1KbXhNi7eQ17TRarVaJ6EmACiVERwAtSTgAKglAQdALQk4AGpJwMESdcstt+Tcc8/N0NBQLr/88tx///1ll1QJc3ndbr/99jQajY6PoaGhRay2uu6+++5cc8012bBhQxqNRr70pS/N+e8QcLAEfeELX8iWLVty00035cEHH8xFF12Uq6++Ojt37iy7tK42n9dteHg4zz77bPtj69ati1hxdR04cCAXXXRRbrnllnn/HdoEYAm6/PLLc9lll+XTn/50kqTZbGbjxo15//vfnxtuuKHk6rrXXF+322+/PR/84Aeze/fuRa60XhqNRv7mb/4mv/IrvzKnrzOCgyVmdHQ03/72t7N58+b2Yz09Pdm8eXPuueeeEivrbvN93fbv359zzjknGzduzC//8i/nkUceWYxyiYCDJWfXrl2ZmJjIGWec0fH4GWecke3bt5dUVfebz+v2qle9Kp/73Ofy5S9/OZ///OfTbDbz+te/Pk899dRilLzkuU0A4CS58sorc+WVV7b//PrXvz6vec1r8pnPfCZ/8Ad/UGJlS4MRHCwxa9euTW9vb3bs2NHx+I4dO3LmmWeWVFX3K+J16+/vz+te97r88Ic/PBklcgQBB0vMwMBALrnkktx1113tx5rNZu66666O0QadinjdJiYm8vDDD2f9+vUnq0ymMUUJS9CWLVty/fXX59JLL82mTZvyyU9+MgcOHMhv/MZvlF1aVzve63bdddflZS97WW6++eYkye///u/niiuuyCte8Yrs3r07H/vYx7J169a8613vKvOfUQn79+/vGOn++Mc/zkMPPZRTTz01Z5999gn9HQIOlqBrr702zz33XD7ykY9k+/btufjii3PnnXfO2EBBp+O9btu2bUtPz+GJsRdeeCHvfve7s3379qxZsyaXXHJJvvnNb+aCCy4o659QGQ888EDe/OY3t/+8ZcuWJMn111+f22+//YT+Dn1wANSSNTgAaknAAVBLAg6AWhJwANSSgAOglgQcALUk4ACoJQEHQC0JOIAFuuqqq/LBD36w7DI4goADoJYEHAC1JOAACjA+Pp7f/u
3fzqpVq7J27dp8+MMfjqN+yyXgAApwxx13pK+vL/fff3/++I//OJ/4xCdy2223lV3WkuY2AYAFuuqqq7Jz58488sgjaTQaSZIbbrghf/u3f5vvf//7JVe3dBnBARTgiiuuaIdbklx55ZV57LHHMjExUWJVS5uAA6CWBBxAAe67776OP9977705//zz09vbW1JFCDiAAmzbti1btmzJo48+mj//8z/Ppz71qXzgAx8ou6wlra/sAgDq4LrrrstLL72UTZs2pbe3Nx/4wAfynve8p+yyljS7KAGoJVOUANSSgAOglgQcALUk4ACoJQEHQC0JOABqScABUEsCDoBaEnAA1JKAA6CWBBwAtfT/AXtDd+n8IplBAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# axis limits follow the prior bounds; the star marks theta_true\n",
"limits = [list(bounds) for bounds in zip(prior_min, prior_max)]\n",
"fig, ax = pairplot(\n",
"    samples,\n",
"    points=[list(theta_true.values())],\n",
"    figsize=(5, 5),\n",
"    limits=limits,\n",
"    labels=[\"a\", \"b\"],\n",
"    upper=\"kde\",\n",
"    diag=\"kde\",\n",
"    fig_kwargs={\n",
"        \"points_offdiag\": {\"marker\": \"*\", \"markersize\": 10},\n",
"        \"points_colors\": [\"g\"],\n",
"    },\n",
")\n",
"first_panel = ax[0, 0]\n",
"first_panel.tick_params(labelsize=14)\n",
"first_panel.margins(y=0)\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "vbidevelop",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}