Source code for vbi.feature_extraction.utility

"""Signal-processing and array utility functions for VBI feature extraction.

Provides Numba-accelerated and pure-NumPy helpers for windowing, filtering,
downsampling, and other preprocessing steps used in the feature pipeline.
"""
import logging
import numpy as np
import pandas as pd
from typing import Union, List

# Optional torch import
try:
    import torch
    from torch import Tensor
    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False
    # Create a dummy Tensor type for type hints
    Tensor = type(None)


def count_depth(ls):
    """
    Count the nesting depth of a list or tuple.

    Parameters
    ----------
    ls : any
        Object to inspect. Only ``list`` and ``tuple`` instances count as a
        nesting level; every other object (including ``numpy.ndarray``)
        has depth 0.

    Returns
    -------
    int
        0 for a non-list/tuple object, otherwise 1 plus the depth of the
        most deeply nested item. An empty list or tuple has depth 1.
    """
    if not isinstance(ls, (list, tuple)):
        return 0
    # ``default=0`` keeps an empty container from raising ValueError in max().
    return 1 + max((count_depth(item) for item in ls), default=0)
def prepare_input(ts, dtype=np.float32):
    """
    Convert the supported input containers to a 3-D ``numpy.ndarray``.

    Parameters
    ----------
    ts : numpy.ndarray, list, tuple or pandas.DataFrame
        Input from which the features are extracted.

        * 3-D array: returned unchanged.
        * 2-D array: a singleton axis is inserted in the middle.
        * any other array: two singleton axes are prepended.
        * list/tuple of arrays: stacked with ``dtype``, then expanded like
          the array cases according to the dimensionality of the first item.
        * nested list/tuple: converted with ``np.array`` and expanded
          according to its nesting depth.
        * flat list/tuple of scalars: converted and given two leading
          singleton axes.
        * DataFrame: columns are taken as time series and rows as time; the
          values are transposed and a singleton middle axis is inserted.

        Any other type is returned unchanged.
    dtype : numpy dtype, optional
        Data type used when stacking a list/tuple of arrays. It is not
        applied to the other input kinds.

    Returns
    -------
    ts : numpy.ndarray
        Formatted input.
    n_trial : int
        Always 0.

    Raises
    ------
    IndexError
        If ``ts`` is an empty list or tuple (its first item is inspected).
    """
    n_trial = 0

    if isinstance(ts, np.ndarray):
        if ts.ndim == 2:
            ts = ts[:, np.newaxis, :]
        elif ts.ndim != 3:
            ts = ts[np.newaxis, np.newaxis, :]

    elif isinstance(ts, pd.DataFrame):
        # Handled at the top level so the branch is reachable: a DataFrame is
        # never a list/tuple.
        ts = ts.values.T
        ts = ts[:, np.newaxis, :]

    elif isinstance(ts, (list, tuple)):
        first = ts[0]
        if isinstance(first, np.ndarray):
            ts = np.array(ts, dtype=dtype)
            if first.ndim == 1:
                ts = ts[:, np.newaxis, :]
            elif first.ndim != 2:
                ts = ts[np.newaxis, np.newaxis, :]
        elif isinstance(first, (list, tuple)):
            # Depth must be measured before the conversion to ndarray.
            depth = count_depth(ts)
            # NOTE: ``dtype`` is intentionally not applied here, to keep the
            # output dtype of nested-list input unchanged.
            ts = np.array(ts)
            if depth == 2:
                ts = ts[:, np.newaxis, :]
            elif depth != 3:
                ts = ts[np.newaxis, np.newaxis, :]
        else:
            # Flat sequence of scalars: previously fell through unconverted.
            ts = np.array(ts)[np.newaxis, np.newaxis, :]

    return ts, n_trial
def prepare_input_ts(ts, indices: List[int] = None):
    """
    Select rows of a time-series array and check that they are usable.

    Parameters
    ----------
    ts : array-like
        Time series; converted with ``np.array`` when not already an
        ndarray. The first axis is the one indexed by ``indices``.
    indices : list, tuple or numpy.ndarray of int, optional
        Positions along the first axis to keep. Defaults to all of them.

    Returns
    -------
    valid : bool
        False when the selection is empty or contains NaN/inf values,
        True otherwise.
    ts : numpy.ndarray
        The selected data; a 1-D selection is reshaped to ``(1, -1)``.

    Raises
    ------
    ValueError
        If ``indices`` is not a list/tuple/ndarray, contains non-integers,
        or contains a value not smaller than ``ts.shape[0]``.
    """
    if not isinstance(ts, np.ndarray):
        ts = np.array(ts)

    if indices is None:
        indices = np.arange(ts.shape[0], dtype=np.int32)

    # check indices validity
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    if not all(isinstance(i, (int, np.int64, np.int32, np.int16)) for i in indices):
        raise ValueError("indices must be a list of integers.")
    if not all(i < ts.shape[0] for i in indices):
        raise ValueError("indices must be smaller than the number of time series.")

    # Index with an integer array: indexing with a raw tuple would be read by
    # NumPy as one index per axis instead of a selection of rows.
    ts = ts[np.asarray(indices, dtype=np.intp)]

    if ts.ndim == 1:
        ts = ts.reshape(1, -1)

    if ts.size == 0:
        return False, ts
    if not np.isfinite(ts).all():
        return False, ts

    return True, ts
def make_mask(n, indices):
    """
    Build an ``n x n`` mask that marks every pair of the given indices.

    Parameters
    ----------
    n : int
        Size of the (square) mask matrix.
    indices : list, tuple or numpy.ndarray of int
        Node indices to connect with each other.

    Returns
    -------
    mask : numpy.ndarray
        ``int64`` matrix with 1 at every ``(i, j)`` where both ``i`` and
        ``j`` are in ``indices`` and ``i != j``; 0 elsewhere.

    Raises
    ------
    ValueError
        If ``indices`` is not a list/tuple/ndarray, contains non-integers,
        or contains a value not smaller than ``n``.
    """
    integer_types = (int, np.int64, np.int32, np.int16)

    # validate the container first, then the element types, then the bounds
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    if any(not isinstance(idx, integer_types) for idx in indices):
        raise ValueError("indices must be a list of integers.")
    if any(not idx < n for idx in indices):
        raise ValueError("indices must be smaller than n.")

    mask = np.zeros((n, n), dtype=np.int64)
    mask[np.ix_(indices, indices)] = 1
    # self-connections are never part of the mask
    np.fill_diagonal(mask, 0)
    return mask
def get_intrah_mask(n_nodes):
    """
    Build the mask of intrahemispheric connections.

    The first ``n_nodes // 2`` nodes form one group and the remaining nodes
    form the other; the mask is 1 for every pair of nodes within the same
    group (diagonal included) and 0 otherwise.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_intrah: 2d array
        ``(n_nodes, n_nodes)`` float mask for intrahemispheric connections.
    """
    half = n_nodes // 2
    mask_intrah = np.zeros((n_nodes, n_nodes))
    # the two diagonal blocks hold the within-group connections
    mask_intrah[:half, :half] = 1
    mask_intrah[half:, half:] = 1
    return mask_intrah
def get_interh_mask(n_nodes):
    """
    Build the mask of interhemispheric connections.

    The first ``n_nodes // 2`` nodes form one group and the remaining nodes
    form the other; the mask is 1 for every pair of nodes that belong to
    different groups and 0 otherwise. The result is symmetric for any
    ``n_nodes``, including odd values.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_interh: 2d array
        ``(n_nodes, n_nodes)`` float mask for interhemispheric connections.
    """
    half = n_nodes // 2
    mask_interh = np.zeros((n_nodes, n_nodes))
    # The two off-diagonal blocks hold the between-group connections. Plain
    # slices avoid the ``-n_nodes // 2`` floor-division pitfall, which gave
    # asymmetric blocks for odd ``n_nodes``.
    mask_interh[:half, half:] = 1
    mask_interh[half:, :half] = 1
    return mask_interh
def get_masks(n_nodes, networks):
    """
    Build one connectivity mask per requested network.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.
    networks: str or list of str
        network name(s) to include in the result; a single name is treated
        as a one-element list.
        'full': full-network connections
        'intrah': intrahemispheric connections
        'interh': interhemispheric connections
        For a custom mask built from specific indices, see
        `make_mask(n, indices)`.

    Returns
    -------
    masks: dict
        maps each requested network name to its ``(n_nodes, n_nodes)`` mask.

    Raises
    ------
    ValueError
        If a requested network is not one of the names listed above.
    """
    valid_networks = ["full", "intrah", "interh"]
    requested = networks if is_sequence(networks) else [networks]

    masks = {}
    for ntw in requested:
        # names are validated as they are reached, in request order
        if ntw not in valid_networks:
            raise ValueError(
                f"Invalid network: {ntw}. Please choose from {valid_networks}."
            )
        if ntw == "intrah":
            masks[ntw] = get_intrah_mask(n_nodes)
        elif ntw == "interh":
            masks[ntw] = get_interh_mask(n_nodes)
        else:
            # only "full" is left after the validation above
            masks[ntw] = np.ones((n_nodes, n_nodes))
    return masks
def is_sequence(arg):
    """
    Tell whether the input is one of the supported sequence containers.

    Parameters
    ----------
    arg : any
        input to be checked.

    Returns
    -------
    bool
        True if the input is a ``list``, ``tuple`` or ``numpy.ndarray``,
        False otherwise (strings are therefore not sequences here).
    """
    sequence_types = (list, tuple, np.ndarray)
    return isinstance(arg, sequence_types)
def set_k_diagonals(A, k=0, value=0):
    """
    Set the main diagonal and the ``k`` diagonals on each side of it.

    Parameters
    ----------
    A : numpy.ndarray
        input matrix. An ndarray is modified in place; any other input is
        first copied into a new array.
    k : int
        number of diagonals on each side of the main diagonal to be set.
        The default is 0, which sets only the main diagonal.
    value : int or float, optional
        value to be set. The default is 0.

    Returns
    -------
    numpy.ndarray
        The matrix with every entry ``(i, j)`` satisfying ``|i - j| <= k``
        set to ``value``.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D, ``k`` is not an ``int``, ``value`` is not an
        ``int``/``float``, or ``k`` is not smaller than ``A.shape[0]``.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    if not isinstance(k, int):
        raise ValueError("k must be an integer.")
    if not isinstance(value, (int, float)):
        raise ValueError("value must be a number.")
    if k >= A.shape[0]:
        raise ValueError("k must be smaller than the size of A.")

    n_rows, n_cols = A.shape
    # Signed distance of every entry from the main diagonal. Using the real
    # shape (not just the row count) keeps this valid for non-square input.
    offsets = np.arange(n_cols)[np.newaxis, :] - np.arange(n_rows)[:, np.newaxis]
    A[np.abs(offsets) <= k] = value
    return A
def if_symmetric(A, tol=1e-8):
    """
    Check if the input matrix is symmetric.

    Parameters
    ----------
    A : numpy.ndarray
        input matrix; other array-likes are converted with ``np.array``.
    tol : float, optional
        absolute tolerance passed to ``np.allclose`` (``atol``); the
        relative tolerance of ``np.allclose`` keeps its default.
        The default is 1e-8.

    Returns
    -------
    bool
        True if the input matrix is square and equal to its transpose
        within tolerance, False otherwise.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    # A non-square matrix is never symmetric. Without this guard the
    # comparison with A.T would either fail to broadcast or, for shapes such
    # as (1, n), broadcast silently and could report True.
    if A.shape[0] != A.shape[1]:
        return False
    return np.allclose(A, A.T, atol=tol)
def scipy_iir_filter_data(
    x, sfreq, l_freq, h_freq, l_trans_bandwidth=None, h_trans_bandwidth=None, **kwargs
):
    """
    Custom, scipy based filtering function with basic butterworth filter.
    (Adapted from neurolib.)

    The filter type follows from the two cut-off frequencies:

    - low-pass: ``l_freq`` is None
    - high-pass: ``h_freq`` is None
    - band-pass: ``l_freq < h_freq``
    - band-stop: ``l_freq > h_freq``

    Parameters
    ----------
    x : np.ndarray
        data to be filtered, time is the last axis
    sfreq : float
        sampling frequency of the data in Hz
    l_freq : float|None
        frequency below which to filter the data in Hz
    h_freq : float|None
        frequency above which to filter the data in Hz
    l_trans_bandwidth : any
        unused; kept for compatibility with mne
    h_trans_bandwidth : any
        unused; kept for compatibility with mne
    **kwargs :
        ``order`` (int, default 8) is forwarded to `scipy.signal.butter` as
        ``N``; other keywords are ignored.

    Returns
    -------
    np.ndarray
        filtered data (zero-phase, via `scipy.signal.sosfiltfilt`)

    Raises
    ------
    ValueError
        If both frequencies are None or if they are equal.
    """
    from scipy.signal import butter, sosfiltfilt

    if l_freq is None and h_freq is None:
        raise ValueError("At least one of l_freq and h_freq must be given.")

    nyq = 0.5 * sfreq

    if l_freq is None:
        # only an upper cut-off: low-pass
        Wn = h_freq / nyq
        btype = "lowpass"
    elif h_freq is None:
        # only a lower cut-off: high-pass
        Wn = l_freq / nyq
        btype = "highpass"
    else:
        if l_freq == h_freq:
            raise ValueError("l_freq and h_freq must differ for a band filter.")
        btype = "bandpass" if l_freq < h_freq else "bandstop"
        # butter() needs the band edges in ascending order for both band
        # types; the band-stop case arrives here with l_freq > h_freq.
        Wn = sorted((l_freq / nyq, h_freq / nyq))

    # get butter coeffs
    sos = butter(N=kwargs.pop("order", 8), Wn=Wn, btype=btype, output="sos")
    return sosfiltfilt(sos, x, axis=-1)
def filter(
    ts: np.ndarray,
    fs: float,
    low_freq: float,
    high_freq: float,
    l_trans_bandwidth: str = "auto",
    h_trans_bandwidth: str = "auto",
    **kwargs,
):
    """
    Filter data with ``mne`` when available, otherwise with scipy.

    The filter type follows from the two cut-off frequencies:

    - low-pass (low_freq is None, high_freq is not None),
    - high-pass (high_freq is None, low_freq is not None),
    - band-pass (low_freq < high_freq),
    - band-stop (low_freq > high_freq)

    Parameters
    ----------
    ts: np.ndarray
        Time series data; time has to be the last axis.
    fs : float
        sampling frequency, forwarded to the backend as ``sfreq``.
    low_freq : float|None
        frequency below which to filter the data.
    high_freq : float|None
        frequency above which to filter the data.
    l_trans_bandwidth : float|str
        transition band width for low frequency
    h_trans_bandwidth : float|str
        transition band width for high frequency
    kwargs :
        extra keywords forwarded unchanged to the backend; for
        ``mne.filter.filter_data`` these can be e.g.
        filter_length="auto", method="fir", iir_params=None
        phase="zero", fir_window="hamming", fir_design="firwin"

    Returns
    -------
    np.ndarray
        filtered data

    Notes
    -----
    The function name shadows the builtin ``filter`` inside this module.
    """
    try:
        from mne.filter import filter_data as backend
    except ImportError:
        logging.warning(
            "`mne` module not found, falling back to basic scipy's function"
        )
        backend = scipy_iir_filter_data

    # both backends share the same keyword interface
    return backend(
        ts,
        sfreq=fs,
        l_freq=low_freq,
        h_freq=high_freq,
        l_trans_bandwidth=l_trans_bandwidth,
        h_trans_bandwidth=h_trans_bandwidth,
        **kwargs,
    )
def posterior_shrinkage(
    prior_samples: Union[Tensor, np.ndarray], post_samples: Union[Tensor, np.ndarray]
) -> Tensor:
    """
    Calculate the posterior shrinkage, quantifying how much
    the posterior distribution contracts from the initial
    prior distribution.

    Computed per parameter as ``1 - (post_std / prior_std) ** 2``.
    References: https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    prior_samples : array_like or torch.Tensor [n_samples, n_params]
        Samples from the prior distribution. Non-tensor input is converted
        to a ``float32`` tensor; 1-D input is treated as one parameter.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distribution. Converted like
        ``prior_samples``.

    Returns
    -------
    shrinkage : torch.Tensor [n_params]
        The posterior shrinkage.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    ValueError
        If either sample set is empty.
    """
    # torch is optional for this module, so fail with a clear message here
    # rather than with a NameError on first use.
    try:
        import torch
    except ImportError as err:
        raise ImportError(
            "posterior_shrinkage requires PyTorch (`torch`) to be installed."
        ) from err

    if len(prior_samples) == 0 or len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(prior_samples, torch.Tensor):
        prior_samples = torch.tensor(prior_samples, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # a 1-D input holds the samples of a single parameter
    if prior_samples.ndim == 1:
        prior_samples = prior_samples[:, None]
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    prior_std = torch.std(prior_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)

    return 1 - (post_std / prior_std) ** 2
def posterior_zscore(
    true_theta: Union[Tensor, np.ndarray, float],
    post_samples: Union[Tensor, np.ndarray],
) -> Tensor:
    """
    Calculate the posterior z-score, quantifying how much the posterior
    distribution of a parameter encompasses its true value.

    Computed per parameter as ``|post_mean - true_theta| / post_std``.
    References: https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    true_theta : float, array-like or torch.Tensor [n_params]
        The true value of the parameters. Non-tensor input is converted to
        a ``float32`` tensor; a scalar is treated as one parameter.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distributions. Non-tensor input is
        converted to a ``float32`` tensor; 1-D input is treated as one
        parameter.

    Returns
    -------
    z : Tensor [n_params]
        The z-score of the posterior distributions.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    ValueError
        If ``post_samples`` is empty.
    """
    # torch is optional for this module, so fail with a clear message here
    # rather than with a NameError on first use.
    try:
        import torch
    except ImportError as err:
        raise ImportError(
            "posterior_zscore requires PyTorch (`torch`) to be installed."
        ) from err

    if len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(true_theta, torch.Tensor):
        true_theta = torch.tensor(true_theta, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # Promote a scalar to shape (1,) while staying in torch: a NumPy
    # round-trip here would fail for tensors that require grad.
    if true_theta.ndim == 0:
        true_theta = true_theta.reshape(1)
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    post_mean = torch.mean(post_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)

    return torch.abs((post_mean - true_theta) / post_std)