"""Signal-processing and array utility functions for VBI feature extraction.
Provides Numba-accelerated and pure-NumPy helpers for windowing, filtering,
downsampling, and other preprocessing steps used in the feature pipeline.
"""
import logging
import numpy as np
import pandas as pd
from typing import Union, List
# Optional torch import
try:
import torch
from torch import Tensor
_TORCH_AVAILABLE = True
except ImportError:
_TORCH_AVAILABLE = False
# Create a dummy Tensor type for type hints
Tensor = type(None)
[docs]
def count_depth(ls):
    """
    Count the nesting depth of a list or tuple.

    Parameters
    ----------
    ls : any
        Object to inspect. Only ``list`` and ``tuple`` instances count as a
        nesting level; every other object is treated as a leaf.

    Returns
    -------
    int
        0 for a leaf, otherwise 1 plus the depth of the most deeply nested
        item. An empty list or tuple has depth 1.
    """
    if not isinstance(ls, (list, tuple)):
        return 0
    # default=0 keeps max() from raising ValueError on an empty container.
    return 1 + max((count_depth(item) for item in ls), default=0)
[docs]
def make_mask(n, indices):
    """
    Build an ``n x n`` mask that selects all pairwise connections among
    the given indices, excluding the main diagonal.

    Parameters
    ----------
    n : int
        size of the mask matrix
    indices : list, tuple or numpy.ndarray of int
        node indices whose mutual connections are set to 1. Python ints and
        any NumPy integer type are accepted. Every index must be smaller
        than ``n``; negative values pass that check and are handed to NumPy
        indexing unchanged.

    Returns
    -------
    mask : numpy.ndarray
        ``(n, n)`` int64 matrix with ones at ``(i, j)`` for every pair of
        distinct entries ``i``, ``j`` of ``indices`` and zeros elsewhere.

    Raises
    ------
    ValueError
        If ``indices`` is not a list/tuple/array, contains a non-integer,
        or contains a value that is not smaller than ``n``.
    """
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    # np.integer covers every NumPy integer width/signedness, not only the
    # int16/int32/int64 subset.
    if not all(isinstance(i, (int, np.integer)) for i in indices):
        raise ValueError("indices must be a list of integers.")
    if not all(i < n for i in indices):
        raise ValueError("indices must be smaller than n.")
    mask = np.zeros((n, n), dtype=np.int64)
    mask[np.ix_(indices, indices)] = 1
    # Self-connections are never part of the mask.
    np.fill_diagonal(mask, 0)
    return mask
[docs]
def get_intrah_mask(n_nodes):
    """
    Build the mask of intrahemispheric (within-hemisphere) connections.

    The first ``n_nodes // 2`` nodes form one group and the remaining nodes
    the other; the mask is 1 inside each group's diagonal block and 0
    elsewhere.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_intrah: 2d array
        ``(n_nodes, n_nodes)`` float array, ones on the two diagonal blocks.
    """
    half = n_nodes // 2
    mask_intrah = np.zeros((n_nodes, n_nodes))
    # top-left block: first group with itself
    mask_intrah[:half, :half] = 1
    # bottom-right block: second group with itself
    mask_intrah[half:, half:] = 1
    return mask_intrah
[docs]
def get_interh_mask(n_nodes):
    """
    Build the mask of interhemispheric (between-hemisphere) connections.

    The first ``n_nodes // 2`` nodes form one group and the remaining nodes
    the other; the mask is 1 on the two off-diagonal blocks that connect
    the groups and 0 elsewhere. The split is the same for even and odd
    ``n_nodes``, so the result is always symmetric.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_interh: 2d array
        ``(n_nodes, n_nodes)`` float array, ones on the off-diagonal blocks.
    """
    half = n_nodes // 2
    mask_interh = np.zeros((n_nodes, n_nodes))
    # Slice the blocks directly. Deriving them from np.eye(..., k=-n_nodes // 2)
    # is fragile because the offset parses as (-n_nodes) // 2, which differs
    # from -(n_nodes // 2) for odd n_nodes.
    mask_interh[:half, half:] = 1
    mask_interh[half:, :half] = 1
    return mask_interh
[docs]
def get_masks(n_nodes, networks):
    """
    Build the masks for the requested networks, keyed by network name.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.
    networks: str or list of str
        network name(s) to include. A single non-sequence value is wrapped
        into a one-element list. Supported names:
        'full': full-network connections
        'intrah': intrahemispheric connections
        'interh': interhemispheric connections
        For a custom mask over specific indices, refer to
        `hbt.utility.make_mask(n, indices)`.

    Returns
    -------
    masks: dict
        maps each requested network name to its ``(n_nodes, n_nodes)`` mask.

    Raises
    ------
    ValueError
        If a requested name is not one of the supported networks.
    """
    # One builder per supported network; insertion order defines the list
    # shown in the error message.
    builders = {
        "full": lambda: np.ones((n_nodes, n_nodes)),
        "intrah": lambda: get_intrah_mask(n_nodes),
        "interh": lambda: get_interh_mask(n_nodes),
    }
    valid_networks = list(builders)

    requested = networks if is_sequence(networks) else [networks]

    masks = {}
    for ntw in requested:
        # Validate against the list (equality based) before the dict lookup.
        if ntw not in valid_networks:
            raise ValueError(
                f"Invalid network: {ntw}. Please choose from {valid_networks}."
            )
        masks[ntw] = builders[ntw]()
    return masks
[docs]
def is_sequence(arg):
    """
    Tell whether ``arg`` is one of the supported sequence containers.

    Parameters
    ----------
    arg : any
        input to be checked.

    Returns
    -------
    bool
        True when ``arg`` is a ``list``, ``tuple`` or ``numpy.ndarray``
        (strings and other iterables are not counted), False otherwise.
    """
    sequence_types = (list, tuple, np.ndarray)
    return isinstance(arg, sequence_types)
[docs]
def set_k_diagonals(A, k=0, value=0):
    """
    Set the main diagonal and the ``k`` diagonals on each side of it to a
    given value.

    Parameters
    ----------
    A : numpy.ndarray
        input matrix. An ``ndarray`` is modified in place; any other input
        is first converted with ``np.array`` and the new array is modified.
        Diagonal positions are computed from ``A.shape[0]``, i.e. as for a
        square matrix.
    k : int
        number of off-diagonals on each side to be set. The default is 0,
        which touches only the main diagonal. Python and NumPy integers
        are accepted.
    value : int or float, optional
        value to be set. The default is 0. Python and NumPy real scalars
        are accepted.

    Returns
    -------
    numpy.ndarray
        the modified matrix.

    Raises
    ------
    ValueError
        If ``A`` is not 2-d, ``k`` is not an integer, ``value`` is not a
        number, or ``k`` is not smaller than ``A.shape[0]``.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    if not isinstance(k, (int, np.integer)):
        raise ValueError("k must be an integer.")
    if not isinstance(value, (int, float, np.integer, np.floating)):
        raise ValueError("value must be a number.")
    # Plain int so that negating an unsigned NumPy scalar cannot wrap around.
    k = int(k)
    n = A.shape[0]
    if k >= n:
        raise ValueError("k must be smaller than the size of A.")
    for offset in range(-k, k + 1):
        # Diagonal `offset` of an n x n matrix has n - |offset| entries; it
        # starts at row |offset| when below the main diagonal and at column
        # offset when above it.
        steps = np.arange(n - abs(offset))
        rows = steps + max(0, -offset)
        cols = steps + max(0, offset)
        A[rows, cols] = value
    return A
[docs]
def if_symmetric(A, tol=1e-8):
    """
    Check if the input matrix is symmetric.

    Parameters
    ----------
    A : numpy.ndarray
        input matrix; non-array input is converted with ``np.array``.
    tol : float, optional
        absolute tolerance passed to ``np.allclose`` as ``atol``. The
        default is 1e-8. ``np.allclose``'s default relative tolerance
        still applies on top of it.

    Returns
    -------
    bool
        True if the input matrix is square and equal to its transpose
        within tolerance, False otherwise.

    Raises
    ------
    ValueError
        If ``A`` is not 2-dimensional.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    # A non-square matrix cannot be symmetric. Checking here also stops an
    # (n, 1) input from broadcasting against its (1, n) transpose.
    if A.shape[0] != A.shape[1]:
        return False
    return bool(np.allclose(A, A.T, atol=tol))
[docs]
def scipy_iir_filter_data(
    x, sfreq, l_freq, h_freq, l_trans_bandwidth=None, h_trans_bandwidth=None, **kwargs
):
    """
    Custom, scipy based filtering function with a basic Butterworth filter
    applied forward and backward (zero phase). Adapted from neurolib.

    The filter type follows from which cut-offs are given:

    - ``l_freq is None``  -> low-pass at ``h_freq``
    - ``h_freq is None``  -> high-pass at ``l_freq``
    - ``l_freq < h_freq`` -> band-pass between ``l_freq`` and ``h_freq``
    - ``l_freq > h_freq`` -> band-stop between ``h_freq`` and ``l_freq``

    Parameters
    ----------
    x : np.ndarray
        data to be filtered, time is the last axis
    sfreq : float
        sampling frequency of the data in Hz
    l_freq : float|None
        lower cut-off frequency in Hz
    h_freq : float|None
        upper cut-off frequency in Hz
    l_trans_bandwidth : unused, kept for signature compatibility with mne
    h_trans_bandwidth : unused, kept for signature compatibility with mne
    **kwargs :
        only ``order`` (Butterworth order, default 8) is read; any other
        keyword is ignored.

    Returns
    -------
    np.ndarray
        filtered data

    Raises
    ------
    ValueError
        If both cut-offs are None, or if they are equal.
    """
    from scipy.signal import butter, sosfiltfilt

    if l_freq is None and h_freq is None:
        raise ValueError("At least one of l_freq and h_freq must be given.")

    nyq = 0.5 * sfreq
    if l_freq is None:
        btype = "lowpass"
        Wn = h_freq / nyq
    elif h_freq is None:
        btype = "highpass"
        Wn = l_freq / nyq
    elif l_freq == h_freq:
        raise ValueError("l_freq and h_freq must differ for a band filter.")
    else:
        btype = "bandpass" if l_freq < h_freq else "bandstop"
        # butter() needs the band edges in ascending order; for band-stop the
        # caller passes them reversed (l_freq > h_freq), so sort them.
        Wn = sorted((l_freq / nyq, h_freq / nyq))

    # Second-order sections keep higher-order filters numerically stable.
    sos = butter(N=kwargs.pop("order", 8), Wn=Wn, btype=btype, output="sos")
    return sosfiltfilt(sos, x, axis=-1)
[docs]
def filter(
    ts: np.ndarray,
    fs: float,
    low_freq: float,
    high_freq: float,
    l_trans_bandwidth: str = "auto",
    h_trans_bandwidth: str = "auto",
    **kwargs,
):
    """
    Filter time-series data with ``mne.filter.filter_data`` when `mne` can
    be imported, otherwise with the scipy-based ``scipy_iir_filter_data``
    fallback (a warning is logged in that case).

    The filter type is selected by the cut-off arguments:
        - low-pass (low_freq is None, high_freq is not None),
        - high-pass (high_freq is None, low_freq is not None),
        - band-pass (low_freq < high_freq),
        - band-stop (low_freq > high_freq)

    Parameters
    ----------
    ts: np.ndarray
        Time series data; time has to be the last axis.
    fs : float
        sampling frequency, forwarded to the backend as ``sfreq``.
    low_freq : float|None
        lower cut-off, forwarded as ``l_freq``.
    high_freq : float|None
        upper cut-off, forwarded as ``h_freq``.
    l_trans_bandwidth : float|str
        transition band width for low frequency
    h_trans_bandwidth : float|str
        transition band width for high frequency
    kwargs : extra keywords forwarded unchanged to the backend, e.g. for mne:
        filter_length="auto",
        method="fir",
        iir_params=None
        phase="zero",
        fir_window="hamming",
        fir_design="firwin"

    Returns
    -------
    np.ndarray
        filtered data
    """
    # Prefer mne's implementation; fall back to the local scipy one.
    try:
        from mne.filter import filter_data as apply_filter
    except ImportError:
        logging.warning(
            "`mne` module not found, falling back to basic scipy's function"
        )
        apply_filter = scipy_iir_filter_data

    backend_options = dict(
        sfreq=fs,
        l_freq=low_freq,
        h_freq=high_freq,
        l_trans_bandwidth=l_trans_bandwidth,
        h_trans_bandwidth=h_trans_bandwidth,
    )
    return apply_filter(ts, **backend_options, **kwargs)
[docs]
def posterior_shrinkage(
    prior_samples: Union[Tensor, np.ndarray], post_samples: Union[Tensor, np.ndarray]
) -> Tensor:
    """
    Calculate the posterior shrinkage, quantifying how much
    the posterior distribution contracts from the initial
    prior distribution.

    Computed per parameter as ``1 - (std(posterior) / std(prior)) ** 2``,
    with standard deviations taken over the sample axis (axis 0).

    References:
    https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    prior_samples : array_like or torch.Tensor [n_samples, n_params]
        Samples from the prior distribution. A 1-d input is treated as
        samples of a single parameter.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distribution. A 1-d input is treated as
        samples of a single parameter.

    Returns
    -------
    shrinkage : torch.Tensor [n_params]
        The posterior shrinkage.

    Raises
    ------
    ImportError
        If `torch` is not installed.
    ValueError
        If either sample set is empty.
    """
    # torch is an optional dependency of this module; import it here so a
    # missing install gives a clear ImportError instead of a NameError.
    try:
        import torch
    except ImportError as err:
        raise ImportError(
            "posterior_shrinkage requires the optional dependency `torch`."
        ) from err

    if len(prior_samples) == 0 or len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(prior_samples, torch.Tensor):
        prior_samples = torch.tensor(prior_samples, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # Promote 1-d sample vectors to a single-parameter column.
    if prior_samples.ndim == 1:
        prior_samples = prior_samples[:, None]
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    prior_std = torch.std(prior_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)
    return 1 - (post_std / prior_std) ** 2
[docs]
def posterior_zscore(
    true_theta: Union[Tensor, np.ndarray, float],
    post_samples: Union[Tensor, np.ndarray],
):
    """
    Calculate the posterior z-score, quantifying how much the posterior
    distribution of a parameter encompasses its true value.

    Computed per parameter as ``|mean(posterior) - true_theta| /
    std(posterior)``, with statistics taken over the sample axis (axis 0).

    References:
    https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    true_theta : float, array-like or torch.Tensor [n_params]
        The true value of the parameters. A scalar is promoted to a
        one-element tensor.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distributions. A 1-d input is treated as
        samples of a single parameter.

    Returns
    -------
    z : Tensor [n_params]
        The z-score of the posterior distributions.

    Raises
    ------
    ImportError
        If `torch` is not installed.
    ValueError
        If ``post_samples`` is empty.
    """
    # torch is an optional dependency of this module; import it here so a
    # missing install gives a clear ImportError instead of a NameError.
    try:
        import torch
    except ImportError as err:
        raise ImportError(
            "posterior_zscore requires the optional dependency `torch`."
        ) from err

    if len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(true_theta, torch.Tensor):
        true_theta = torch.tensor(true_theta, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # Stay in torch: np.atleast_1d would turn the tensor back into a NumPy
    # array and break the tensor arithmetic below.
    true_theta = torch.atleast_1d(true_theta)
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    post_mean = torch.mean(post_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)
    return torch.abs((post_mean - true_theta) / post_std)