# Source code for vbi.cde

"""
A module for conditional density estimation using Mixture Density
Networks (MDNs) and Masked Autoregressive Flows (MAFs).

Original Source:
    This file is based on code from the tvbl project:
    https://github.com/maedoc/tvbl/blob/main/content/cde.py

    Original Author: Marmaduke Woodman (maedoc)

Notes:
    - Adapted and modified for use in the VBI (Virtual Brain Inference) project
    - May include additional features, modifications, or optimizations

"""

import abc
import math
from dataclasses import dataclass, field

import autograd.numpy as anp
from autograd import grad
from autograd.scipy.special import logsumexp
from sklearn.datasets import make_moons
from scipy.stats import t
from tqdm.auto import trange
from typing import Optional


# =============================================================================
# == Base Class for Conditional Density Estimators
# =============================================================================


class ConditionalDensityEstimator(abc.ABC):
    """
    Abstract base class for conditional density estimators.

    This class provides a unified training interface using the Adam optimizer
    and standardizes the API for training, sampling, and log-probability
    evaluation.

    Parameters
    ----------
    param_dim : int, optional
        Dimensionality of the target variable (parameters to be estimated).
        If None, will be inferred from training data.
    feature_dim : int, optional
        Dimensionality of the conditional variable (features).
        If None, will be inferred from training data.
    """

    def __init__(
        self, param_dim: Optional[int] = None, feature_dim: Optional[int] = None
    ):
        self.param_dim = param_dim
        self.feature_dim = feature_dim
        self._dims_inferred = False
        self.weights = None
        self.loss_history = []

    def _infer_dimensions(self, params: anp.ndarray, features: anp.ndarray):
        """
        Infer parameter and feature dimensions from training data.

        The dimensions found in the data always win; a mismatch with
        user-provided values only prints a warning.

        Parameters
        ----------
        params : anp.ndarray
            Parameter array of shape (N, param_dim); 1-D input is treated as
            a single column.
        features : anp.ndarray
            Feature array of shape (N, feature_dim); 1-D input is treated as
            a single column.

        Raises
        ------
        ValueError
            If the inferred ``param_dim`` is not positive or the inferred
            ``feature_dim`` is negative.
        """
        params = anp.asarray(params)
        features = anp.asarray(features)
        if params.ndim == 1:
            params = params.reshape(-1, 1)
        if features.ndim == 1:
            features = features.reshape(-1, 1)

        inferred_param_dim = params.shape[1]
        inferred_feature_dim = features.shape[1]

        # Check if user-provided dimensions match inferred ones
        if self.param_dim is not None and self.param_dim != inferred_param_dim:
            print(
                f"Warning: Provided param_dim ({self.param_dim}) doesn't match data ({inferred_param_dim}). Using data dimensions."
            )
        if self.feature_dim is not None and self.feature_dim != inferred_feature_dim:
            print(
                f"Warning: Provided feature_dim ({self.feature_dim}) doesn't match data ({inferred_feature_dim}). Using data dimensions."
            )

        # Set the inferred dimensions
        self.param_dim = inferred_param_dim
        self.feature_dim = inferred_feature_dim
        self._dims_inferred = True
        print(
            f"Inferred dimensions: param_dim={self.param_dim}, feature_dim={self.feature_dim}"
        )

        # Validate inferred dimensions
        if self.param_dim <= 0:
            raise ValueError(
                f"Inferred param_dim must be positive, got {self.param_dim}"
            )
        if self.feature_dim < 0:
            raise ValueError(
                f"Inferred feature_dim must be non-negative, got {self.feature_dim}"
            )

    def _require_trained(self):
        """Raise RuntimeError unless weights exist and dimensions are known."""
        if self.weights is None:
            raise RuntimeError("Model has not been trained yet. Call train() first.")
        if not self._dims_inferred:
            raise RuntimeError("Model dimensions not inferred yet. Call train() first.")

    @abc.abstractmethod
    def _initialize_weights(self, rng: anp.random.RandomState) -> dict:
        """
        Initialize the trainable weights of the model.

        Parameters
        ----------
        rng : autograd.numpy.random.RandomState
            A random number generator for reproducible initialization.

        Returns
        -------
        dict
            A dictionary of initialized weight arrays.
        """

    @abc.abstractmethod
    def _loss_function(
        self, weights: dict, features: anp.ndarray, params: anp.ndarray
    ) -> float:
        """
        Compute the negative log-likelihood loss for a batch of data.

        Parameters
        ----------
        weights : dict
            A dictionary of the model's trainable weights.
        features : anp.ndarray
            A (N, feature_dim) array of conditional features.
        params : anp.ndarray
            A (N, param_dim) array of target parameters.

        Returns
        -------
        float
            The mean negative log-likelihood of the batch.
        """

    @abc.abstractmethod
    def sample(
        self, features: anp.ndarray, n_samples: int, rng: anp.random.RandomState
    ) -> anp.ndarray:
        """
        Generate samples from the learned conditional distribution
        p(params|features).

        The base implementation only checks that the model is trained;
        subclasses call it via ``super().sample(...)`` before sampling.

        Parameters
        ----------
        features : anp.ndarray
            A (n_conditions, feature_dim) array of features to condition on.
        n_samples : int
            The number of samples to generate for each condition.
        rng : autograd.numpy.random.RandomState
            A random number generator for sampling.

        Returns
        -------
        anp.ndarray
            An array of generated samples of shape
            (n_conditions, n_samples, param_dim).

        Raises
        ------
        RuntimeError
            If the model has no weights or its dimensions are unknown.
        """
        self._require_trained()

    @abc.abstractmethod
    def log_prob(self, features: anp.ndarray, params: anp.ndarray) -> anp.ndarray:
        """
        Compute the log-probability log p(params|features).

        The base implementation only checks that the model is trained;
        subclasses call it via ``super().log_prob(...)`` before evaluating.

        Parameters
        ----------
        features : anp.ndarray
            A (N, feature_dim) array of conditional features.
        params : anp.ndarray
            A (N, param_dim) array of target parameters.

        Returns
        -------
        anp.ndarray
            A (N,) array of log-probabilities.

        Raises
        ------
        RuntimeError
            If the model has no weights or its dimensions are unknown.
        """
        self._require_trained()

    ############################################################################

    def train(
        self,
        params: anp.ndarray,
        features: anp.ndarray,
        n_iter: int = 2000,
        learning_rate: float = 1e-3,
        seed: int = 0,
        use_tqdm: bool = True,
        patience: Optional[int] = None,
        min_delta: float = 0.0,
    ):
        """
        Trains the model using the Adam optimizer with optional early stopping
        for plateaus.

        Parameters
        ----------
        params : anp.ndarray
            An (N, param_dim) matrix of simulated parameters. A 1-D array is
            treated as a single column.
        features : anp.ndarray
            An (N, feature_dim) matrix of corresponding data features. A 1-D
            array is treated as a single column.
        n_iter : int, optional
            The number of gradient descent iterations.
        learning_rate : float, optional
            The learning rate for the Adam optimizer.
        seed : int, optional
            Seed for reproducible weight initialization and training.
        use_tqdm : bool, optional
            If True, displays a progress bar during training.
        patience : int, optional
            Number of iterations with no improvement after which training will
            be stopped. If None, early stopping is disabled.
        min_delta : float, optional
            Minimum change in the loss to qualify as an improvement.

        Raises
        ------
        ValueError
            If sample counts or dimensions do not match, if every row contains
            a non-finite value, or if ``patience`` is not positive.
        """
        # --- 1. Data Validation and Dimension Inference ---
        params = anp.asarray(params)
        features = anp.asarray(features)

        # _infer_dimensions accepts 1-D input, so promote here as well;
        # otherwise the shape[1] / axis=1 accesses below fail on 1-D arrays.
        if params.ndim == 1:
            params = params.reshape(-1, 1)
        if features.ndim == 1:
            features = features.reshape(-1, 1)

        if not self._dims_inferred:
            self._infer_dimensions(params, features)

        if params.shape[0] != features.shape[0]:
            raise ValueError(
                "Params and features must have the same number of samples."
            )
        if params.shape[1] != self.param_dim or features.shape[1] != self.feature_dim:
            raise ValueError(
                "Data dimensions do not match inferred/expected model dimensions."
            )

        # Keep only rows where both params and features are fully finite
        finite_idx = anp.all(anp.isfinite(params), axis=1) & anp.all(
            anp.isfinite(features), axis=1
        )
        params = params[finite_idx].astype("f")
        features = features[finite_idx].astype("f")

        if params.shape[0] == 0:
            raise ValueError("All data points contained non-finite values.")

        # Reject bad settings before any existing weights are overwritten
        if patience is not None and patience <= 0:
            raise ValueError("patience must be a positive integer.")

        # --- 2. Initialization ---
        rng = anp.random.RandomState(seed)
        self.weights = self._initialize_weights(rng)
        self.loss_history = []

        # Adam optimizer state (first and second moment estimates)
        m = {key: anp.zeros_like(val) for key, val in self.weights.items()}
        v = {key: anp.zeros_like(val) for key, val in self.weights.items()}
        beta1, beta2, epsilon = 0.9, 0.999, 1e-8

        # Early stopping state
        best_loss = anp.inf
        counter = 0

        # --- 3. Optimization Loop ---
        # value_and_grad yields loss and gradient from a single forward pass.
        from autograd import value_and_grad

        loss_and_grad = value_and_grad(self._loss_function)
        iterator = trange(n_iter, desc="Training", disable=not use_tqdm)

        for i in iterator:
            loss, g = loss_and_grad(self.weights, features, params)
            self.loss_history.append(loss)

            if not anp.isfinite(loss):
                print(
                    f"Warning: Loss is non-finite at iteration {i}. Stopping training."
                )
                break

            if use_tqdm:
                iterator.set_postfix(loss=f"{loss:.4f}")

            # Early stopping check
            if loss < best_loss - min_delta:
                best_loss = loss
                counter = 0
            else:
                counter += 1
            if patience is not None and counter >= patience:
                print(f"Early stopping at iteration {i} due to plateau.")
                break

            # Validate every gradient before touching any weight so a bad
            # gradient never leaves the weights partially updated.
            bad_key = next(
                (key for key in self.weights if not anp.all(anp.isfinite(g[key]))),
                None,
            )
            if bad_key is not None:
                print(
                    f"Warning: Non-finite gradient for '{bad_key}' at iteration {i}. Stopping."
                )
                return

            # Adam update step with bias-corrected moments
            for key in self.weights:
                m[key] = beta1 * m[key] + (1 - beta1) * g[key]
                v[key] = beta2 * v[key] + (1 - beta2) * (g[key] ** 2)
                m_hat = m[key] / (1 - beta1 ** (i + 1))
                v_hat = v[key] / (1 - beta2 ** (i + 1))
                self.weights[key] -= learning_rate * m_hat / (anp.sqrt(v_hat) + epsilon)
# ============================================================================= # == MDN Implementation # =============================================================================
@dataclass
class MDNEstimator(ConditionalDensityEstimator):
    """
    Mixture Density Network for conditional density estimation.

    An MLP maps features to the parameters of a Gaussian mixture: mixture
    weights, means, and an upper-triangular precision factor ``L`` per
    component whose diagonal is ``exp`` of a predicted log-diagonal.

    Parameters
    ----------
    param_dim : int, optional
        Dimensionality of the target variable. If None, inferred from training
        data.
    feature_dim : int, optional
        Dimensionality of the conditional variable. If None, inferred from
        training data.
    n_components : int, optional
        The number of Gaussian mixture components.
    hidden_sizes : tuple[int, ...], optional
        A tuple specifying the number of units in each hidden layer.
    """

    param_dim: Optional[int] = None
    feature_dim: Optional[int] = None
    n_components: int = 5
    hidden_sizes: tuple[int, ...] = (32, 32)

    def __post_init__(self):
        super().__init__(self.param_dim, self.feature_dim)
        # Built once dimensions are inferred; defined here so the attribute
        # always exists.
        self._offdiag_basis = None

    def _infer_dimensions(self, params: anp.ndarray, features: anp.ndarray):
        """Infer dimensions, then build the off-diagonal basis that needs them."""
        super()._infer_dimensions(params, features)
        self._offdiag_basis = self._create_offdiag_basis()

    def _create_offdiag_basis(self):
        """
        Build one-hot (n_off_diag, D, D) matrices, one per strictly
        upper-triangular entry; returns None when ``param_dim`` is 1.
        """
        n_off_diag = self.param_dim * (self.param_dim - 1) // 2
        if n_off_diag == 0:
            return None
        basis = anp.zeros((n_off_diag, self.param_dim, self.param_dim), dtype="f")
        rows, cols = anp.triu_indices(self.param_dim, k=1)
        basis[anp.arange(n_off_diag), rows, cols] = 1
        return basis

    def _initialize_weights(self, rng: anp.random.RandomState) -> dict:
        """Initializes weights for the MLP and GMM output layers."""
        weights = {}
        in_size = self.feature_dim
        for i, out_size in enumerate(self.hidden_sizes):
            weights[f"W{i}"] = (
                rng.randn(in_size, out_size) * anp.sqrt(2.0 / in_size)
            ).astype("f")
            weights[f"b{i}"] = anp.zeros(out_size, dtype="f")
            in_size = out_size

        last_hidden_size = (
            self.hidden_sizes[-1] if self.hidden_sizes else self.feature_dim
        )

        # GMM output layers
        K, D_out = self.n_components, self.param_dim
        weights["W_alpha"] = (rng.randn(last_hidden_size, K) * 0.01).astype("f")
        weights["b_alpha"] = anp.zeros(K, dtype="f")
        weights["W_mu"] = (rng.randn(last_hidden_size, K * D_out) * 0.01).astype("f")
        weights["b_mu"] = anp.zeros(K * D_out, dtype="f")
        weights["W_L_prec_log_diag"] = (
            rng.randn(last_hidden_size, K * D_out) * 0.01
        ).astype("f")
        weights["b_L_prec_log_diag"] = anp.zeros(K * D_out, dtype="f")

        n_off_diag = D_out * (D_out - 1) // 2
        if n_off_diag > 0:
            weights["W_L_prec_offdiag"] = (
                rng.randn(last_hidden_size, K * n_off_diag) * 0.01
            ).astype("f")
            weights["b_L_prec_offdiag"] = anp.zeros(K * n_off_diag, dtype="f")
        return weights

    def _forward_pass(self, weights: dict, features: anp.ndarray):
        """
        Maps input features to GMM parameters.

        Returns
        -------
        tuple
            ``(alpha, mu, L_prec, L_prec_log_diag)`` with shapes (N, K),
            (N, K, D), (N, K, D, D) and (N, K, D).
        """
        h = features
        for i in range(len(self.hidden_sizes)):
            h = anp.tanh(h @ weights[f"W{i}"] + weights[f"b{i}"])

        K, D_out = self.n_components, self.param_dim

        # Softmax over components, computed in log space
        log_alpha = h @ weights["W_alpha"] + weights["b_alpha"]
        alpha = anp.exp(log_alpha - logsumexp(log_alpha, axis=1, keepdims=True))

        mu = (h @ weights["W_mu"] + weights["b_mu"]).reshape(-1, K, D_out)

        L_prec_log_diag = (
            h @ weights["W_L_prec_log_diag"] + weights["b_L_prec_log_diag"]
        ).reshape(-1, K, D_out)
        # Place exp(log-diagonal) on the diagonal of a (D, D) matrix
        L_prec_diag_mat = anp.einsum(
            "nki,ij->nkij", anp.exp(L_prec_log_diag), anp.eye(D_out, dtype="f")
        )

        n_off_diag = D_out * (D_out - 1) // 2
        if n_off_diag > 0:
            L_prec_offdiag_vals = (
                h @ weights["W_L_prec_offdiag"] + weights["b_L_prec_offdiag"]
            ).reshape(-1, K, n_off_diag)
            L_prec_offdiag_mat = anp.einsum(
                "nkl,lij->nkij", L_prec_offdiag_vals, self._offdiag_basis
            )
            L_prec = L_prec_diag_mat + L_prec_offdiag_mat
        else:
            L_prec = L_prec_diag_mat

        return alpha, mu, L_prec, L_prec_log_diag

    def _mixture_log_prob(
        self, weights: dict, features: anp.ndarray, params: anp.ndarray
    ) -> anp.ndarray:
        """Per-sample log-density of ``params`` under the predicted mixture, shape (N,)."""
        alpha, mu, L_prec, L_prec_log_diag = self._forward_pass(weights, features)

        # Broadcast targets against the K mixture components
        delta = params[:, anp.newaxis, :] - mu

        # Whitened residual z = L @ delta; log|det L| is the sum of the log-diagonal
        z = anp.einsum("nkij,nkj->nki", L_prec, delta)
        quad_term = -0.5 * anp.sum(z**2, axis=2)
        log_det_term = anp.sum(L_prec_log_diag, axis=2)
        log_probs_k = (
            quad_term + log_det_term - 0.5 * self.param_dim * anp.log(2 * math.pi)
        )

        return logsumexp(anp.log(alpha + 1e-9) + log_probs_k, axis=1)

    def _loss_function(
        self, weights: dict, features: anp.ndarray, params: anp.ndarray
    ) -> float:
        """Computes the negative log-likelihood of the true parameters under the GMM."""
        return -anp.mean(self._mixture_log_prob(weights, features, params))

    def log_prob(self, features: anp.ndarray, params: anp.ndarray) -> anp.ndarray:
        """
        Computes the log-probability log p(params|features) for each sample.

        Parameters
        ----------
        features : anp.ndarray
            A (N, feature_dim) array of conditional features.
        params : anp.ndarray
            A (N, param_dim) array of target parameters.

        Returns
        -------
        anp.ndarray
            A (N,) array of log-probabilities.
        """
        super().log_prob(features, params)
        features = anp.asarray(features)
        params = anp.asarray(params)
        return self._mixture_log_prob(self.weights, features, params)

    def sample(
        self,
        features: anp.ndarray,
        n_samples: int,
        rng: anp.random.RandomState,
        log_prob_threshold: Optional[float] = None,
        oversample_factor: int = 5,
    ) -> anp.ndarray:
        """
        Generate samples from the learned conditional distribution
        p(params|features), with optional rejection sampling to filter
        low-probability outliers.

        Parameters
        ----------
        features : anp.ndarray
            A (n_conditions, feature_dim) array of features to condition on.
            A 1-D array is treated as a single condition.
        n_samples : int
            The number of samples to generate for each condition.
        rng : anp.random.RandomState
            A random number generator for sampling.
        log_prob_threshold : float, optional
            If provided, reject samples with log p(sample | features) <=
            threshold. Tune based on training data (e.g., min log-prob - 2).
        oversample_factor : int, optional
            Multiplier for initial sample generation to account for rejections.

        Returns
        -------
        anp.ndarray
            An array of generated samples of shape
            (n_conditions, n_samples, param_dim). If filtering is active and
            fewer than n_samples are retained, it is padded with the last
            valid sample (or the first component mean if none survive) and a
            warning is printed. All-NaN if the precision factor is singular.
        """
        super().sample(features, n_samples, rng)
        features = anp.asarray(features).astype("f")
        if features.ndim == 1:
            features = features.reshape(1, -1)

        n_cond, D_out = features.shape[0], self.param_dim
        K = self.n_components

        # Generate oversampled candidates
        n_candidates = n_samples * oversample_factor
        alpha, mu, L_prec, _ = self._forward_pass(self.weights, features)

        # Gumbel-max trick: argmax(log alpha + Gumbel noise) ~ Categorical(alpha)
        log_alpha = anp.log(alpha + 1e-9)
        gumbel_noise = -anp.log(-anp.log(rng.uniform(size=(n_cond, n_candidates, K))))
        component_indices = anp.argmax(
            log_alpha[:, anp.newaxis, :] + gumbel_noise, axis=2
        )

        cond_idx = anp.arange(n_cond)[:, anp.newaxis]
        chosen_mu = mu[cond_idx, component_indices]
        chosen_L_prec = L_prec[cond_idx, component_indices]

        try:
            L_cov_factor = anp.linalg.inv(chosen_L_prec)
        except anp.linalg.LinAlgError:
            print(
                "Warning: Singular precision matrix encountered during sampling. Returning NaNs."
            )
            return anp.full((n_cond, n_samples, D_out), anp.nan)

        z = rng.randn(n_cond, n_candidates, D_out)
        # The density uses z = L @ (y - mu), so y = mu + inv(L) @ z.
        # The contraction must run over the column index of inv(L).
        candidate_samples = chosen_mu + anp.einsum("ncij,ncj->nci", L_cov_factor, z)

        if log_prob_threshold is None:
            return candidate_samples[:, :n_samples, :]

        # Score every candidate under its own condition
        flat_features = anp.repeat(features, n_candidates, axis=0)
        flat_candidates = candidate_samples.reshape(-1, D_out)
        log_probs = self.log_prob(flat_features, flat_candidates).reshape(
            n_cond, n_candidates
        )

        # Filter per condition
        valid_samples = []
        for i in range(n_cond):
            accepted = candidate_samples[i][log_probs[i] > log_prob_threshold]
            n_accepted = len(accepted)
            if n_accepted == 0:
                print(f"Warning: No valid samples for condition {i}. Using mean.")
                # Fallback to first component mean
                accepted = anp.tile(mu[i, 0], (n_samples, 1))
            elif n_accepted < n_samples:
                # Pad by repeating the last accepted sample
                pad_len = n_samples - n_accepted
                accepted = anp.concatenate(
                    [accepted, anp.tile(accepted[-1:], (pad_len, 1))]
                )
                print(f"Warning: Only {n_accepted} valid samples for condition {i}; padded.")
            else:
                accepted = accepted[:n_samples]
            valid_samples.append(accepted)

        return anp.stack(valid_samples)
# ============================================================================= # == MAF Implementation # =============================================================================
[docs] @dataclass class MAFEstimator(ConditionalDensityEstimator): """ Masked Autoregressive Flow for conditional density estimation. Parameters ---------- param_dim : int, optional Dimensionality of the target variable. If None, inferred from training data. feature_dim : int, optional Dimensionality of the conditional variable. If None, inferred from training data. n_flows : int Number of autoregressive transforms (a.k.a. num_transforms). hidden_units : int Hidden features per MADE block (a.k.a. hidden_features). activation : str 'tanh' (default), 'relu', or 'elu'. z_score_theta : bool Standardize parameters (θ) internally. z_score_x : bool Standardize features (x) internally. use_actnorm : bool Insert ActNorm between flows with data-dependent init. embedding_dim : Optional[int] If set (E), PCA-reduce features x -> R^E before conditioning. """ param_dim: int = None feature_dim: int = None n_flows: int = 4 hidden_units: int = 64 activation: str = "tanh" z_score_theta: bool = True z_score_x: bool = True use_actnorm: bool = True embedding_dim: Optional[int] = None # PCA embedding for features actnorm_eps: float = 1e-6 # internal state (filled after prepare_* calls / training init) def __post_init__(self): super().__init__(self.param_dim, self.feature_dim) self._dims_inferred = False # Ensure attribute always exists self.model_constants = None # masks, perms, inv_perms self._actnorm_initialized = [False] * self.n_flows # Standardization stats self.theta_mean = None self.theta_std = None self.x_mean = None self.x_std = None # PCA embedding (optional) self._use_pca = False self._pca_components = None # (C, E) def _warmup_actnorm(self, features: anp.ndarray, params: anp.ndarray): """ Run one forward pass to initialize ActNorm with data-dependent stats, outside the autograd graph (avoids mutating weights during grad). 
""" if not self.use_actnorm: return # Use a small subset to estimate mean/std (like Glow’s data-dependent init) n = features.shape[0] k = min(512, n) # warmup batch size _ = self._get_log_prob(self.weights, features[:k], params[:k]) # After this call, self._actnorm_initialized[*] are True and actnorm params set. # ---------- public helpers: call these once before training ----------
[docs] def prepare_normalizers(self, features: anp.ndarray, params: anp.ndarray, rng=None): """Compute z-score stats (like sbi) and optional PCA projection for features.""" assert params.ndim == 2 and params.shape[1] == self.param_dim if self.feature_dim > 0: assert features.ndim == 2 and features.shape[1] == self.feature_dim if self.z_score_theta: self.theta_mean = anp.mean(params, axis=0) self.theta_std = anp.std(params, axis=0) + 1e-8 else: self.theta_mean = anp.zeros(self.param_dim) self.theta_std = anp.ones(self.param_dim) if self.feature_dim > 0: if self.z_score_x: self.x_mean = anp.mean(features, axis=0) self.x_std = anp.std(features, axis=0) + 1e-8 else: self.x_mean = anp.zeros(self.feature_dim) self.x_std = anp.ones(self.feature_dim) if self.embedding_dim is not None and self.embedding_dim < self.feature_dim: # PCA via SVD on standardized features X = (features - self.x_mean) / self.x_std U, S, Vt = anp.linalg.svd(X, full_matrices=False) E = self.embedding_dim self._pca_components = Vt[:E, :].T # (C, E) self._use_pca = True else: self._use_pca = False self._pca_components = None else: self.x_mean = anp.zeros(0) self.x_std = anp.ones(0) self._use_pca = False self._pca_components = None
# ---------- internal transforms ---------- def _act(self, x): if self.activation == "relu": return anp.maximum(0.0, x) elif self.activation == "elu": return anp.where(x > 0.0, x, anp.exp(x) - 1.0) else: return anp.tanh(x) def _z_theta(self, params): return (params - self.theta_mean) / self.theta_std def _inv_z_theta(self, z): return z * self.theta_std + self.theta_mean def _z_x(self, features): if self.feature_dim == 0: return features X = (features - self.x_mean) / self.x_std if self._use_pca: X = anp.dot(X, self._pca_components) # (N, E) return X def _ctx_dim(self): if self.feature_dim == 0: return 0 return ( self.embedding_dim if (self._use_pca and self.embedding_dim is not None) else self.feature_dim ) # ---------- parameters & constants ---------- def _initialize_weights(self, rng: anp.random.RandomState) -> dict: """Initializes weights, masks, permutations, and ActNorm (if enabled).""" weights = {} layers = [] D, C_in, H = self.param_dim, self._ctx_dim(), self.hidden_units for k in range(self.n_flows): # Degrees / masks (classic MADE) m_in = anp.arange(1, D + 1) # draw hidden degrees in [1, D] inclusive; use D+1 as high because # numpy randint is exclusive at the upper bound. This avoids # ValueError when D == 1 (low >= high). 
m_hidden = rng.randint(1, D + 1, size=H) M1 = (m_in[None, :] <= m_hidden[:, None]).astype("f") m_out = m_in.copy() M2 = (m_hidden[None, :] < m_out[:, None]).astype("f") # Permutation perm = rng.permutation(D) inv_perm = anp.empty(D, dtype=int) inv_perm[perm] = anp.arange(D) layers.append({"M1": M1, "M2": M2, "perm": perm, "inv_perm": inv_perm}) # Trainable parameters w_std = 0.01 weights[f"W1y_{k}"] = (rng.randn(H, D) * w_std).astype("f") weights[f"W1c_{k}"] = ( (rng.randn(H, C_in) * w_std).astype("f") if C_in > 0 else anp.zeros((H, C_in), dtype="f") ) weights[f"b1_{k}"] = anp.zeros(H, dtype="f") # Output heads (mu, log_scale) # Keep W2/W2c small; set log-scale bias negative (stable) weights[f"W2_{k}"] = anp.zeros((2 * D, H), dtype="f") weights[f"W2c_{k}"] = anp.zeros((2 * D, C_in), dtype="f") b2 = anp.zeros(2 * D, dtype="f") b2[D:] = -2.0 # log_sigma bias ~ exp(-2) start weights[f"b2_{k}"] = b2.astype("f") # ActNorm (per-dim scale & bias) if self.use_actnorm: weights[f"act_s_{k}"] = anp.ones(D, dtype="f") # scale weights[f"act_b_{k}"] = anp.zeros(D, dtype="f") # bias (pre-scale) else: weights[f"act_s_{k}"] = None weights[f"act_b_{k}"] = None self.model_constants = {"layers": layers} return weights # ---------- building blocks ---------- def _made_forward(self, y, ctx, layer_const, k, weights): """Single forward pass through a MADE block; returns mu, log_sigma.""" M1, M2 = layer_const["M1"], layer_const["M2"] W1y, W1c, b1 = weights[f"W1y_{k}"], weights[f"W1c_{k}"], weights[f"b1_{k}"] W2, W2c, b2 = weights[f"W2_{k}"], weights[f"W2c_{k}"], weights[f"b2_{k}"] y_h = anp.dot(y, (W1y * M1).T) c_h = anp.dot(ctx, W1c.T) if self._ctx_dim() > 0 else 0.0 h = self._act(y_h + c_h + b1) M2_tiled = anp.concatenate([M2, M2], axis=0) out = anp.dot(h, (W2 * M2_tiled).T) if self._ctx_dim() > 0: out = out + anp.dot(ctx, W2c.T) out = out + b2 mu = out[:, : self.param_dim] log_sigma = anp.clip(out[:, self.param_dim :], -7.0, 7.0) return mu, log_sigma def _apply_actnorm(self, u, k, 
weights, maybe_data_init=None): """ActNorm: y = (u + b) * s ; log_det += sum(log|s|).""" if not self.use_actnorm: return u, 0.0 s = weights[f"act_s_{k}"] b = weights[f"act_b_{k}"] # Data-dependent init (first batch) if not self._actnorm_initialized[k] and maybe_data_init is not None: m = anp.mean(maybe_data_init, axis=0) v = anp.std(maybe_data_init, axis=0) + self.actnorm_eps b = -m s = 1.0 / v weights[f"act_s_{k}"] = s.astype("f") weights[f"act_b_{k}"] = b.astype("f") self._actnorm_initialized[k] = True y = (u + b) * s log_abs_s = anp.log(anp.abs(s) + 1e-12) log_det = anp.sum(log_abs_s) # per-sample constant; broadcast by caller return y, log_det # ---------- core log_prob ---------- def _get_log_prob(self, weights: dict, features: anp.ndarray, params: anp.ndarray): """Computes log probability under the flow (with preprocessing).""" # Preprocess to sbi-like standardized spaces x = ( self._z_x(features).astype("f") if self._ctx_dim() > 0 else features.astype("f") ) u = self._z_theta(params).astype("f") batch = u.shape[0] log_det = anp.zeros(batch, dtype="f") for k, layer_const in enumerate(self.model_constants["layers"]): # Permute u = u[:, layer_const["perm"]] # ActNorm (data-dependent init on first batch seen) v, ln_det = self._apply_actnorm(u, k, weights, maybe_data_init=u) if self.use_actnorm: log_det = log_det + ln_det # add same constant per sample # MADE transform mu, log_sigma = self._made_forward(v, x, layer_const, k, weights) u = (v - mu) * anp.exp(-log_sigma) log_det = log_det - anp.sum(log_sigma, axis=1) base_logp = -0.5 * anp.sum(u**2, axis=1) - 0.5 * self.param_dim * anp.log( 2.0 * anp.pi ) return base_logp + log_det # ---------- public API ----------
[docs] def log_prob(self, features: anp.ndarray, params: anp.ndarray) -> anp.ndarray: super().log_prob(features, params) return self._get_log_prob(self.weights, features, params)
[docs] def sample( self, features: anp.ndarray, n_samples: int, rng: anp.random.RandomState ) -> anp.ndarray: """ Samples from p(theta | features). Returns shape (n_cond, n_samples, D) in original θ space. """ super().sample(features, n_samples, rng) if features.ndim == 1 and self.feature_dim > 0: features = features.reshape(1, -1) n_cond = 1 if self.feature_dim == 0 else features.shape[0] # Preprocess features x = ( self._z_x(features).astype("f") if self._ctx_dim() > 0 else ( features.astype("f") if self.feature_dim > 0 else anp.zeros((n_cond, 0), dtype="f") ) ) out = anp.zeros((n_cond, n_samples, self.param_dim), dtype="f") for c in range(n_cond): z = rng.randn(n_samples, self.param_dim).astype("f") y = z # Invert the flow stack (reverse order) for k, layer_const in reversed( list(enumerate(self.model_constants["layers"])) ): u_perm = y # current state in permuted coordinates we will fill autoregressively v = anp.zeros_like(u_perm) # Invert autoregressive transform sequentially for i in range(self.param_dim): mu, log_sigma = self._made_forward( v, x[c : c + 1].repeat(n_samples, axis=0), layer_const, k, self.weights, ) v[:, i] = u_perm[:, i] * anp.exp(log_sigma[:, i]) + mu[:, i] # Invert ActNorm: u = v / s - b if self.use_actnorm: s = self.weights[f"act_s_{k}"] b = self.weights[f"act_b_{k}"] v = v / (s + 1e-12) - b # Invert permutation y = v[:, layer_const["inv_perm"]] # Map back from z-space to original θ space out[c] = self._inv_z_theta(y) return out
# ---------- convenience to (re)build ----------
[docs] def reinitialize(self, rng: Optional[anp.random.RandomState] = None): """(Re)build masks/weights; call after prepare_normalizers().""" if rng is None: rng = anp.random.RandomState(0) self.weights = self._initialize_weights(rng) self._actnorm_initialized = [False] * self.n_flows
def _loss_function(
    self, weights: dict, features: anp.ndarray, params: anp.ndarray
) -> float:
    """
    Mean negative log-likelihood of ``params`` given ``features``.

    Parameters
    ----------
    weights : dict
        Weight dictionary forwarded unchanged to ``_get_log_prob``.
    features : anp.ndarray
        Conditioning variables.
    params : anp.ndarray
        Target variables.

    Returns
    -------
    float
        Negated mean of the per-sample log-probabilities.
    """
    per_sample_logp = self._get_log_prob(weights, features, params)
    mean_logp = anp.mean(per_sample_logp)
    return -mean_logp
def train(
    self,
    params: anp.ndarray,
    features: anp.ndarray,
    n_iter: int = 2000,
    learning_rate: float = 1e-3,
    seed: int = 0,
    use_tqdm: bool = True,
    # --- new knobs ---
    validation_fraction: float = 0.1,
    stop_after_epochs: int = 20,  # patience (like sbi)
    early_stopping_delta: float = 0.0,  # required improvement
    clip_max_norm: Optional[float] = 5.0,  # set None to disable
):
    """
    Fit the flow with full-batch Adam on the mean negative log-likelihood.

    Rows containing non-finite values are dropped, the remaining rows are
    shuffled and split into a training and a validation subset,
    ``prepare_normalizers`` and ``_warmup_actnorm`` are called on the
    training subset, and the weights are re-initialized from ``seed``
    before optimization starts.

    Parameters
    ----------
    params : array-like, shape (N, param_dim) or (N,)
        Target variables. A 1-D input is treated as a single column.
    features : array-like, shape (N, feature_dim) or (N,)
        Conditioning variables. A 1-D input is treated as a single column.
    n_iter : int
        Maximum number of optimization steps.
    learning_rate : float
        Adam step size.
    seed : int
        Seeds both the train/validation shuffle and weight initialization.
    use_tqdm : bool
        Show a progress bar with running losses.
    validation_fraction : float
        Fraction of rows held out, in [0, 1). With 0 rows held out there
        is no validation and no early stopping.
    stop_after_epochs : int
        Stop after this many consecutive steps without improvement of the
        validation loss.
    early_stopping_delta : float
        Minimum decrease of the validation loss that counts as improvement.
    clip_max_norm : float or None
        Global L2-norm threshold for gradient clipping; None disables it.

    Raises
    ------
    ValueError
        If sample counts differ, dimensions do not match the model, all
        rows are non-finite, or ``validation_fraction`` is outside [0, 1).

    Notes
    -----
    Populates ``loss_history``, ``val_loss_history``, ``best_epoch`` and
    ``best_val_loss``. With validation enabled, the weights with the
    lowest validation loss are restored when training ends. If a
    non-finite gradient appears, training stops without applying that
    step: the best validation checkpoint is restored when validation is
    enabled, otherwise the weights of the last completed step are kept.
    """
    import autograd.numpy as anp
    from autograd import value_and_grad

    try:
        from tqdm import trange
    except ImportError:

        def trange(N, **kw):
            return range(N)

    # --- 1) arrays + infer dims ---
    params = anp.asarray(params)
    features = anp.asarray(features)
    # _infer_dimensions accepts 1-D inputs as single columns; do the same
    # here so the shape[1] checks below do not raise IndexError.
    if params.ndim == 1:
        params = params.reshape(-1, 1)
    if features.ndim == 1:
        features = features.reshape(-1, 1)

    if not self._dims_inferred:
        self._infer_dimensions(params, features)

    if params.shape[0] != features.shape[0]:
        raise ValueError(
            "Params and features must have the same number of samples."
        )
    if params.shape[1] != self.param_dim or features.shape[1] != self.feature_dim:
        raise ValueError(
            "Data dimensions do not match inferred/expected model dimensions."
        )

    finite_idx = anp.all(anp.isfinite(params), axis=1) & anp.all(
        anp.isfinite(features), axis=1
    )
    params = params[finite_idx].astype("f")
    features = features[finite_idx].astype("f")
    if params.shape[0] == 0:
        raise ValueError("All data points contained non-finite values.")

    N = params.shape[0]
    rng_np = anp.random.RandomState(seed)

    # --- 2) train/val split ---
    if not (0.0 <= validation_fraction < 1.0):
        raise ValueError("validation_fraction must be in [0,1).")
    n_val = int(N * validation_fraction)
    has_val = n_val > 0
    perm = rng_np.permutation(N)
    val_idx = perm[:n_val] if has_val else anp.array([], dtype=int)
    train_idx = perm[n_val:]

    params_tr, feats_tr = params[train_idx], features[train_idx]
    if has_val:
        params_val, feats_val = params[val_idx], features[val_idx]
    else:
        params_val, feats_val = None, None

    # --- 3) compute normalizers on TRAIN ONLY (sbi-style) ---
    self.prepare_normalizers(feats_tr, params_tr)

    # --- 4) init weights/masks/perms; reset actnorm flags ---
    rng = anp.random.RandomState(seed)
    self.weights = self._initialize_weights(rng)
    self.loss_history = []
    self.val_loss_history = []
    self._actnorm_initialized = [False] * self.n_flows

    # --- 4.5) one-time ActNorm warmup on TRAIN subset ---
    self._warmup_actnorm(feats_tr, params_tr)

    # --- 5) Adam state ---
    m = {k: anp.zeros_like(w) for k, w in self.weights.items()}
    v = {k: anp.zeros_like(w) for k, w in self.weights.items()}
    beta1, beta2, epsilon = 0.9, 0.999, 1e-8

    # One traced pass yields both the loss value and its gradient.
    loss_and_grad = value_and_grad(self._loss_function)
    iterator = trange(n_iter, desc="Training", disable=not use_tqdm)

    def report(**stats):
        # The plain-range fallback has no set_postfix; skip reporting then.
        if use_tqdm and hasattr(iterator, "set_postfix"):
            iterator.set_postfix(**stats)

    # early stopping bookkeeping
    best_weights = {k: w.copy() for k, w in self.weights.items()}
    best_val = anp.inf if has_val else None
    epochs_no_improve = 0
    self.best_epoch = -1
    self.best_val_loss = None

    for epoch in iterator:
        # ---- forward/backward on TRAIN ----
        train_loss, g = loss_and_grad(self.weights, feats_tr, params_tr)
        self.loss_history.append(float(train_loss))

        # (optional) grad clipping by global L2 norm over all tensors
        if clip_max_norm is not None:
            total_sq = 0.0
            for key in g:
                total_sq += anp.sum(g[key] ** 2)
            global_norm = anp.sqrt(total_sq + 1e-12)
            if global_norm > clip_max_norm:
                scale = clip_max_norm / (global_norm + 1e-12)
                for key in g:
                    g[key] = g[key] * scale

        # Check every gradient BEFORE touching any weight, so a bad step
        # can never leave the weights half-updated.
        bad_key = next(
            (key for key in self.weights if not anp.all(anp.isfinite(g[key]))),
            None,
        )
        if bad_key is not None:
            print(
                f"Warning: Non-finite gradient for '{bad_key}' at epoch {epoch}. Stopping."
            )
            # Without validation, best_weights is still the initial
            # snapshot; restoring it would throw away all progress.
            if has_val:
                self.weights = best_weights
            return

        # Adam update
        for key in self.weights:
            m[key] = beta1 * m[key] + (1 - beta1) * g[key]
            v[key] = beta2 * v[key] + (1 - beta2) * (g[key] ** 2)
            m_hat = m[key] / (1 - beta1 ** (epoch + 1))
            v_hat = v[key] / (1 - beta2 ** (epoch + 1))
            self.weights[key] -= learning_rate * m_hat / (anp.sqrt(v_hat) + epsilon)

        # ---- validation & early stopping ----
        if not has_val:
            report(train=f"{train_loss:.4f}")
            continue

        val_loss = self._loss_function(self.weights, feats_val, params_val)
        self.val_loss_history.append(float(val_loss))

        if (best_val - val_loss) > early_stopping_delta:
            best_val = float(val_loss)
            best_weights = {k: w.copy() for k, w in self.weights.items()}
            epochs_no_improve = 0
            self.best_epoch = int(epoch)
            self.best_val_loss = float(val_loss)
        else:
            epochs_no_improve += 1

        report(
            train=f"{train_loss:.4f}",
            val=f"{val_loss:.4f}",
            patience=f"{epochs_no_improve}/{stop_after_epochs}",
        )

        if epochs_no_improve >= stop_after_epochs:
            break

    # With validation enabled, always finish on the best checkpoint
    # (covers both early stopping and running out of iterations).
    if has_val:
        self.weights = best_weights
@dataclass
class MAFEstimator0(ConditionalDensityEstimator):
    """
    Masked Autoregressive Flow for conditional density estimation.

    Each flow layer permutes the target dimensions and applies one MADE
    block that outputs a shift ``mu`` and a log-scale ``alpha`` per
    dimension. Inputs are used as given (no standardization, no ActNorm).

    Parameters
    ----------
    param_dim : int
        Dimensionality of the target variable.
    feature_dim : int
        Dimensionality of the conditional variable.
    n_flows : int, optional
        The number of flow layers (MADE blocks).
    hidden_units : int, optional
        The number of hidden units in each MADE block.
    """

    param_dim: int
    feature_dim: int
    n_flows: int = 4
    hidden_units: int = 64

    def __post_init__(self):
        super().__init__(self.param_dim, self.feature_dim)
        self.model_constants = None  # For non-trainable parts like masks

    def _initialize_weights(self, rng: anp.random.RandomState) -> dict:
        """
        Initializes weights and model constants (masks, permutations).

        Stores the per-layer masks and permutations in
        ``self.model_constants['layers']`` and returns the trainable
        weights as a dict keyed by ``'<name>_<layer index>'``.
        """
        weights = {}
        layers = []
        D, C, H = self.param_dim, self.feature_dim, self.hidden_units
        w_std = 0.01

        for k in range(self.n_flows):
            # MADE degrees: inputs/outputs are numbered 1..D.
            m_in = anp.arange(1, D + 1)
            # NOTE(review): degrees are drawn from [1, D]; a hidden unit with
            # degree D is excluded from every output by the strict
            # inequality in M2, so it contributes nothing. Left unchanged
            # to keep seeded initializations reproducible.
            m_hidden = rng.randint(1, D + 1, size=H)
            # Hidden unit h may read input j iff degree(j) <= degree(h).
            M1 = (m_in[None, :] <= m_hidden[:, None]).astype("f")
            # Output i may read hidden unit h iff degree(h) < degree(i).
            M2 = (m_hidden[None, :] < m_in[:, None]).astype("f")

            perm = rng.permutation(D)
            inv_perm = anp.empty(D, dtype=int)
            inv_perm[perm] = anp.arange(D)
            layers.append({"M1": M1, "M2": M2, "perm": perm, "inv_perm": inv_perm})

            # Trainable parameters. The output layer starts at zero, so
            # every layer initially produces mu = 0 and alpha = 0.
            weights[f"W1y_{k}"] = (rng.randn(H, D) * w_std).astype("f")
            weights[f"W1c_{k}"] = (
                (rng.randn(H, C) * w_std).astype("f")
                if C > 0
                else anp.zeros((H, C), dtype="f")
            )
            weights[f"b1_{k}"] = anp.zeros(H, dtype="f")
            weights[f"W2_{k}"] = anp.zeros((2 * D, H), dtype="f")
            weights[f"W2c_{k}"] = anp.zeros((2 * D, C), dtype="f")
            weights[f"b2_{k}"] = anp.zeros(2 * D, dtype="f")

        self.model_constants = {"layers": layers}
        return weights

    def _made_forward(self, y, ctx, layer_const, k, weights):
        """
        Single forward pass through a MADE block.

        Returns ``(mu, alpha)``, each with one column per target dimension;
        ``alpha`` (the log-scale) is clipped to [-7, 7].
        """
        M1, M2 = layer_const["M1"], layer_const["M2"]
        W1y, W1c, b1 = weights[f"W1y_{k}"], weights[f"W1c_{k}"], weights[f"b1_{k}"]
        W2, W2c, b2 = weights[f"W2_{k}"], weights[f"W2c_{k}"], weights[f"b2_{k}"]
        has_ctx = self.feature_dim > 0

        y_h = anp.dot(y, (W1y * M1).T)
        c_h = anp.dot(ctx, W1c.T) if has_ctx else 0.0
        h = anp.tanh(y_h + c_h + b1)

        # The same output mask applies to the mu rows and the alpha rows.
        M2_tiled = anp.concatenate([M2, M2], axis=0)
        out = anp.dot(h, (W2 * M2_tiled).T)
        if has_ctx:
            out = out + anp.dot(ctx, W2c.T)
        out = out + b2

        mu = out[:, : self.param_dim]
        alpha = anp.clip(out[:, self.param_dim :], -7.0, 7.0)
        return mu, alpha

    def _get_log_prob(self, weights: dict, features: anp.ndarray, params: anp.ndarray):
        """Computes log probability for the MAF (one value per input row)."""
        u = params
        log_det = anp.zeros(params.shape[0])
        for k, layer_const in enumerate(self.model_constants["layers"]):
            u = u[:, layer_const["perm"]]
            mu, alpha = self._made_forward(u, features, layer_const, k, weights)
            u = (u - mu) * anp.exp(-alpha)
            # Out-of-place accumulation: autograd does not support in-place
            # updates of arrays that take part in differentiation.
            log_det = log_det - anp.sum(alpha, axis=1)

        base_logp = -0.5 * anp.sum(u**2, axis=1) - 0.5 * self.param_dim * anp.log(
            2.0 * anp.pi
        )
        return base_logp + log_det

    def _loss_function(
        self, weights: dict, features: anp.ndarray, params: anp.ndarray
    ) -> float:
        """Mean negative log-likelihood of ``params`` given ``features``."""
        return -anp.mean(self._get_log_prob(weights, features, params))

    def log_prob(self, features: anp.ndarray, params: anp.ndarray) -> anp.ndarray:
        """Evaluate log p(params | features) with the stored weights."""
        # Parent hook runs first; its return value is unused here.
        super().log_prob(features, params)
        return self._get_log_prob(self.weights, features, params)

    def sample(
        self, features: anp.ndarray, n_samples: int, rng: anp.random.RandomState
    ) -> anp.ndarray:
        """
        Draw ``n_samples`` samples for every conditioning row.

        Parameters
        ----------
        features : array-like
            Shape (n_cond, feature_dim), or (feature_dim,) for one condition.
        n_samples : int
            Number of draws per conditioning row.
        rng : anp.random.RandomState
            Source of the base Gaussian noise.

        Returns
        -------
        anp.ndarray
            Array of shape (n_cond, n_samples, param_dim).
        """
        # Parent hook runs first; its return value is unused here.
        super().sample(features, n_samples, rng)
        features = anp.asarray(features).astype("f")
        if features.ndim == 1:
            features = features.reshape(1, -1)
        n_cond = features.shape[0]

        # One context row per draw, grouped by condition. The latent noise
        # must have the same number of rows; drawing only n_samples rows
        # breaks whenever more than one condition is supplied.
        ctx = anp.repeat(features, n_samples, axis=0)
        x = rng.randn(n_cond * n_samples, self.param_dim).astype("f")

        # Invert the flow stack (last layer first).
        for k, layer_const in reversed(list(enumerate(self.model_constants["layers"]))):
            # x is the layer output; u is the layer input, recovered one
            # dimension at a time because dimension i depends on u[:, :i].
            u = anp.zeros_like(x)
            for i in range(self.param_dim):
                mu, alpha = self._made_forward(u, ctx, layer_const, k, self.weights)
                u[:, i] = x[:, i] * anp.exp(alpha[:, i]) + mu[:, i]
            x = u[:, layer_const["inv_perm"]]

        return x.reshape(n_cond, n_samples, self.param_dim)