Damped Oscillator - Numba/C++

Open In Colab

[1]:
import torch
import pickle
import numpy as np
import sbi.utils as utils
import matplotlib.pyplot as plt
from multiprocessing import Pool
from sbi.analysis import pairplot
from vbi.sbi_inference import Inference
from sklearn.preprocessing import StandardScaler
from vbi.models.cpp.damp_oscillator import DO

# Switching to the Numba version is also compatible with the current pipeline:
# from vbi.models.numba.damp_oscillator import DO
[2]:
from vbi import report_cfg
from vbi import extract_features
from vbi import get_features_by_domain, get_features_by_given_names
[3]:
# Seed NumPy and PyTorch with the same value so repeated runs draw the same numbers.
seed = 2
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(seed)
[6]:
# Default model/integration settings handed to the DO model constructor.
parameters = dict(
    a=0.1,
    b=0.05,
    dt=0.01,
    t_start=0,
    method="rk4",
    t_end=100.0,
    t_cut=20,
    output="output",
    initial_state=[0.5, 1.0],
)
[7]:
# Build the damped-oscillator model from the parameter dictionary.
ode = DO(parameters)
# Calling the instance and printing the result shows a model summary
# (in the recorded output: the model name followed by its parameter dict).
print(ode())
Damp Oscillator model
{'a': 0.1, 'b': 0.05, 'dt': 0.01, 't_start': 0, 'method': 'rk4', 't_end': 100.0, 't_cut': 20, 'output': 'output', 'initial_state': [0.5, 1.0]}
[8]:
# Run the model with its configured parameters and plot both state variables
# against time, then save the figure under output/.
sol = ode.run()
t = sol["t"]
x = sol["x"]
plt.figure(figsize=(4, 3))
# Raw strings keep the LaTeX backslashes literal: "\o" is not a valid Python
# escape sequence, so a plain "$\omega$" literal triggers an invalid-escape
# warning on recent Python versions.
plt.plot(t, x[:, 0], label=r"$\theta$")
plt.plot(t, x[:, 1], label=r"$\omega$")
plt.xlabel("t")
plt.ylabel("x")
plt.legend()
plt.tight_layout()
plt.savefig("output/damp_oscillator_ts.jpeg", dpi=300)
../_images/examples_damp_oscillator_6_0.png
[9]:
# Take the feature configuration of the "statistical" domain and narrow it down
# to the standard deviation and the mean, then print the resulting selection.
cfg = get_features_by_given_names(
    get_features_by_domain(domain="statistical"),
    names=["calc_std", "calc_mean"],
)
report_cfg(cfg)
Selected features:
------------------
■ Domain: statistical
 ▢ Function:  calc_std
   ▫ description:  Computes the standard deviation of the signal.
   ▫ function   :  vbi.feature_extraction.features.calc_std
   ▫ parameters :  {'indices': None, 'verbose': False}
   ▫ tag        :  all
   ▫ use        :  yes
 ▢ Function:  calc_mean
   ▫ description:  Computes the mean of the signal.
   ▫ function   :  vbi.feature_extraction.features.calc_mean
   ▫ parameters :  {'indices': None, 'verbose': False}
   ▫ tag        :  all
   ▫ use        :  yes
[10]:
def wrapper(par, control, cfg, verbose=False):
    """Simulate the model once and return the feature vector of that run.

    Parameters
    ----------
    par : dict
        Settings used to construct the ``DO`` model; ``par["dt"]`` is also
        used to derive the sampling frequency passed to the feature extractor.
    control : dict
        Passed to ``DO.run``; the callers in this notebook use it to override
        the parameters ``a`` and ``b`` for this run.
    cfg : dict
        Feature-extraction configuration forwarded to ``extract_features``.
    verbose : bool, optional
        Forwarded to ``extract_features``.

    Returns
    -------
    The first row of the ``.values`` array of the extracted features.
    """
    model = DO(par)
    simulation = model.run(control)

    # Sampling frequency labelled [Hz] in the original code; the factor 1000
    # presumably converts from a step given in ms -- TODO confirm the unit of dt.
    fs = 1.0 / par["dt"] * 1000

    # A single time series (transposed state array) is passed, so the result
    # holds one row of features.
    features = extract_features(
        ts=[simulation["x"].T],
        cfg=cfg,
        fs=fs,
        n_workers=1,
        verbose=verbose,
    )
    return features.values[0]
[11]:
def batch_run(par, control_list, cfg, n_workers=1):
    """Run ``wrapper`` for every entry of ``control_list`` in a process pool.

    Parameters
    ----------
    par : dict
        Base model parameters shared by all runs.
    control_list : list of dict
        One ``control`` dictionary per simulation.
    cfg : dict
        Feature-extraction configuration shared by all runs.
    n_workers : int, optional
        Number of worker processes in the pool.

    Returns
    -------
    list
        One ``wrapper`` result per entry of ``control_list``, in input order.
    """
    jobs = [(par, control, cfg) for control in control_list]
    with Pool(processes=n_workers) as pool:
        # starmap blocks until every job has finished, so returning from
        # inside the context manager is safe.
        return pool.starmap(wrapper, jobs)
[12]:
# Smoke test: one simulation with a and b overridden, printing its feature vector.
control = dict(a=0.11, b=0.06)
x_ = wrapper(parameters, control, cfg)
print(x_)
[0.12421611 0.10675827 1.0513536  0.8769057 ]
[13]:
# Simulation budget and degree of parallelism.
num_sim = 2000
num_workers = 10

# Lower/upper bounds of the prior, ordered as (a, b).
a_min = b_min = 0.0
a_max = b_max = 1.0
prior_min = [a_min, b_min]
prior_max = [a_max, b_max]

# Ground-truth parameters, used later to simulate the observation and to mark
# the true point in the posterior plot.
theta_true = dict(a=0.1, b=0.05)
[14]:
# Box-shaped uniform prior over (a, b), bounded by prior_min and prior_max.
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max)
)
[15]:
# Draw num_sim parameter vectors from the prior and turn each (a, b) row into
# the control dictionary expected by the simulation wrapper.
obj = Inference()
theta = obj.sample_prior(prior, num_sim)
theta_np = theta.numpy().astype(float)
control_list = [dict(a=a_i, b=b_i) for a_i, b_i in theta_np]
[16]:
# Simulate every sampled parameter set in parallel. The pool size is taken from
# the `num_workers` setting defined alongside `num_sim`, so the degree of
# parallelism is controlled in one place rather than by a hard-coded value here.
# The results do not depend on the worker count: batch_run returns them in
# input order.
stat_vec = batch_run(parameters, control_list, cfg, n_workers=num_workers)
[17]:
# Standardize the simulated features and keep the fitted scaler so the
# observation can be transformed the same way later on.
scaler = StandardScaler()
stat_vec_st = torch.tensor(
    scaler.fit_transform(np.array(stat_vec)),
    dtype=torch.float32,
)

# Store the training set (parameters and standardized features) on disk.
torch.save(theta, "output/theta.pt")
torch.save(stat_vec_st, "output/stat_vec_st.pt")
[18]:
# Sanity check: parameters and standardized features should have the same number of rows.
theta.shape, stat_vec_st.shape
[18]:
(torch.Size([2000, 2]), torch.Size([2000, 4]))
[19]:
# Train the posterior estimator on the (parameters, standardized features)
# pairs, using the "SNPE" method with a "maf" density estimator.
posterior = obj.train(
    theta,
    stat_vec_st,
    prior,
    method="SNPE",
    density_estimator="maf",
    num_threads=4,
)
 Neural network successfully converged after 315 epochs.train Done in 0 hours 0 minutes 39.261273 seconds
[20]:
# Persist the trained posterior to disk with pickle.
with open("output/posterior.pkl", "wb") as f:
    pickle.dump(posterior, f)
[21]:
# with open("output/posterior.pkl", "rb") as f:
#     posterior = pickle.load(f)
[22]:
# Synthetic "observation": features simulated at the ground-truth parameters,
# standardized with the scaler fitted on the training simulations. The scaler
# expects a 2-D array, hence the added leading axis.
xo = wrapper(parameters, theta_true, cfg)
xo_st = scaler.transform(xo[np.newaxis, :])
[23]:
# Draw 10000 samples from the posterior conditioned on the standardized
# observation, and save them to disk.
samples = obj.sample_posterior(xo_st, 10000, posterior)
torch.save(samples, "output/samples.pt")
[24]:
# Pair plot of the posterior samples over the prior range, with the
# ground-truth parameters marked as a green star.
limits = [list(bounds) for bounds in zip(prior_min, prior_max)]
true_point = list(theta_true.values())
fig, ax = pairplot(
    samples,
    points=[true_point],
    figsize=(5, 5),
    limits=limits,
    labels=["a", "b"],
    upper="kde",
    diag="kde",
    fig_kwargs={
        "points_offdiag": {"marker": "*", "markersize": 10},
        "points_colors": ["g"],
    },
)

# Cosmetic tweaks on the top-left panel before saving.
ax[0, 0].tick_params(labelsize=14)
ax[0, 0].margins(y=0)
plt.tight_layout()
plt.savefig("output/do_cpp.jpeg", dpi=100)
../_images/examples_damp_oscillator_22_0.png