import logging
import numpy as np
import pandas as pd
from typing import Union, List
# Optional torch import
try:
import torch
from torch import Tensor
_TORCH_AVAILABLE = True
except ImportError:
_TORCH_AVAILABLE = False
# Create a dummy Tensor type for type hints
Tensor = type(None)
[docs]
def count_depth(ls):
    """
    Count the nesting depth of a list or tuple.

    Parameters
    ----------
    ls : any
        Object to inspect.

    Returns
    -------
    int
        0 if `ls` is not a list or tuple; otherwise 1 plus the depth of
        its most deeply nested item. An empty list or tuple has depth 1.
    """
    if not isinstance(ls, (list, tuple)):
        return 0
    # default=0 keeps max() from raising on an empty container.
    return 1 + max((count_depth(item) for item in ls), default=0)
[docs]
def make_mask(n, indices):
    """
    Make a mask matrix that selects the pairwise connections among the
    given indices.

    Parameters
    ----------
    n : int
        size of the (n, n) mask matrix
    indices : list, tuple or numpy.ndarray of int
        node indices to connect. Python ints and any numpy integer dtype
        are accepted. Only the upper bound (`i < n`) is validated.

    Returns
    -------
    mask : numpy.ndarray
        (n, n) int64 matrix with ones at every (i, j) where both i and j
        are in `indices` and i != j, zeros elsewhere.

    Raises
    ------
    ValueError
        If `indices` is not a list, tuple or numpy array, contains a
        non-integer element, or contains a value that is not smaller
        than `n`.
    """
    # check validity of indices
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    # np.integer covers every numpy integer width, signed or unsigned.
    if not all(isinstance(i, (int, np.integer)) for i in indices):
        raise ValueError("indices must be a list of integers.")
    if not all(i < n for i in indices):
        raise ValueError("indices must be smaller than n.")

    mask = np.zeros((n, n), dtype=np.int64)
    mask[np.ix_(indices, indices)] = 1
    # a node is never connected to itself
    np.fill_diagonal(mask, 0)
    return mask
[docs]
def get_intrah_mask(n_nodes):
    """
    Build the mask of intrahemispheric connections.

    The nodes are split at `n_nodes // 2`; the mask is one inside the two
    diagonal blocks (first part with itself, second part with itself)
    and zero elsewhere.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_intrah: 2d array
        (n_nodes, n_nodes) float mask for intrahemispheric connections.
    """
    half = n_nodes // 2
    mask_intrah = np.zeros((n_nodes, n_nodes))
    # top-left block: first part of the nodes with itself
    mask_intrah[:half, :half] = 1
    # bottom-right block: remaining nodes with themselves
    mask_intrah[half:, half:] = 1
    return mask_intrah
[docs]
def get_interh_mask(n_nodes):
    """
    Get a mask for interhemispheric connections.

    The nodes are split at `n_nodes // 2`; the mask is one in the two
    off-diagonal blocks (first part vs. second part and vice versa) and
    zero elsewhere. The result is always symmetric, including for an odd
    `n_nodes`, where the second part holds the extra node.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_interh: 2d array
        (n_nodes, n_nodes) float mask for interhemispheric connections.
    """
    half = n_nodes // 2
    # build on a zeros mask
    mask_interh = np.zeros((n_nodes, n_nodes))
    # Plain slicing keeps both blocks consistent for odd n_nodes, where
    # floor-dividing a negated count would shift one of the blocks.
    mask_interh[:half, half:] = 1
    mask_interh[half:, :half] = 1
    return mask_interh
[docs]
def get_masks(n_nodes, networks):
    """
    Build one mask per requested network and return them keyed by name.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.
    networks: str or list of str
        network name(s) to include in the dictionary; a single name is
        treated as a one-element list.
        'full': full-network connections
        'intrah': intrahemispheric connections
        'interh': interhemispheric connections
        For a custom mask built from specific indices, refer to
        `hbt.utility.make_mask(n, indices)`.

    Returns
    -------
    masks: dict
        dictionary mapping each requested network name to its
        (n_nodes, n_nodes) mask.

    Raises
    ------
    ValueError
        If a requested network is not one of the valid names.
    """
    valid_networks = ["full", "intrah", "interh"]
    # one mask factory per supported network name
    builders = {
        "full": lambda size: np.ones((size, size)),
        "intrah": get_intrah_mask,
        "interh": get_interh_mask,
    }

    requested = networks if is_sequence(networks) else [networks]

    masks = {}
    for ntw in requested:
        if ntw not in valid_networks:
            raise ValueError(
                f"Invalid network: {ntw}. Please choose from {valid_networks}."
            )
        masks[ntw] = builders[ntw](n_nodes)
    return masks
[docs]
def is_sequence(arg):
    """
    Tell whether the input is one of the supported sequence containers.

    Parameters
    ----------
    arg : any
        input to be checked.

    Returns
    -------
    bool
        True if `arg` is a list, a tuple or a numpy array (or an instance
        of a subclass of those), False otherwise. Strings are not
        treated as sequences.
    """
    sequence_types = (list, tuple, np.ndarray)
    return isinstance(arg, sequence_types)
[docs]
def set_k_diagonals(A, k=0, value=0):
    """
    Set the main diagonal and the `k` diagonals on each side of it to a
    given value.

    Parameters
    ----------
    A : numpy.ndarray
        input 2d matrix. A numpy array is modified in place; any other
        input is first converted to a new array.
    k : int
        number of off-diagonals on each side of the main diagonal to be
        set. The default is 0, which sets only the main diagonal.
        A negative `k` leaves the matrix unchanged.
    value : int or float, optional
        value to be set. The default is 0.

    Returns
    -------
    A : numpy.ndarray
        the matrix with every entry satisfying |row - col| <= k set to
        `value`.

    Raises
    ------
    ValueError
        If `A` is not 2d, `k` is not an integer, `value` is not a number,
        or `k` is not smaller than the number of rows of `A`.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    if not isinstance(k, (int, np.integer)):
        raise ValueError("k must be an integer.")
    if not isinstance(value, (int, float, np.integer, np.floating)):
        raise ValueError("value must be a number.")
    if k >= A.shape[0]:
        raise ValueError("k must be smaller than the size of A.")

    # Select the band directly from the index grid: deterministic and
    # valid for non-square matrices as well.
    rows, cols = np.indices(A.shape)
    A[np.abs(rows - cols) <= k] = value
    return A
[docs]
def if_symmetric(A, tol=1e-8):
    """
    Check if the input matrix is symmetric.

    Parameters
    ----------
    A : numpy.ndarray
        input 2d matrix; other array-likes are converted with `np.array`.
    tol : float, optional
        absolute tolerance passed to `np.allclose` (its default relative
        tolerance still applies). The default is 1e-8.

    Returns
    -------
    bool
        True if the input matrix is square and equal to its transpose
        within tolerance, False otherwise.

    Raises
    ------
    ValueError
        If `A` is not a 2d array.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    # A non-square matrix can never be symmetric. Without this guard the
    # comparison with A.T either broadcasts (e.g. a 1 x n row) or raises.
    if A.shape[0] != A.shape[1]:
        return False
    return bool(np.allclose(A, A.T, atol=tol))
[docs]
def scipy_iir_filter_data(
    x, sfreq, l_freq, h_freq, l_trans_bandwidth=None, h_trans_bandwidth=None, **kwargs
):
    """
    Custom, scipy based filtering function with basic butterworth filter.
    Adapted from neurolib.

    The filter type follows from the two cut-off frequencies:
        - low-pass (l_freq is None, h_freq is not None),
        - high-pass (h_freq is None, l_freq is not None),
        - band-pass (l_freq < h_freq),
        - band-stop (l_freq > h_freq).

    Parameters
    ----------
    x : np.ndarray
        data to be filtered, time is the last axis
    sfreq : float
        sampling frequency of the data in Hz
    l_freq : float|None
        frequency below which to filter the data in Hz
    h_freq : float|None
        frequency above which to filter the data in Hz
    l_trans_bandwidth : ignored, kept for compatibility with mne
    h_trans_bandwidth : ignored, kept for compatibility with mne
    **kwargs :
        only `order` (default 8) is used, as the order `N` given to
        `scipy.signal.butter`; any other keyword is ignored.

    Returns
    -------
    np.ndarray
        filtered data

    Raises
    ------
    ValueError
        If both `l_freq` and `h_freq` are None, or if they are equal.
    """
    from scipy.signal import butter, sosfiltfilt

    if l_freq is None and h_freq is None:
        raise ValueError("At least one of l_freq and h_freq must be given.")

    nyq = 0.5 * sfreq
    if l_freq is None:
        # we have a low-pass
        Wn = h_freq / nyq
        btype = "lowpass"
    elif h_freq is None:
        # so we have a high-pass filter
        Wn = l_freq / nyq
        btype = "highpass"
    else:
        # so we have band filter
        if l_freq == h_freq:
            raise ValueError("l_freq and h_freq must differ for a band filter.")
        btype = "bandpass" if l_freq < h_freq else "bandstop"
        # scipy requires the band edges in increasing order, which is not
        # the order of (l_freq, h_freq) for a band-stop filter.
        Wn = sorted((l_freq / nyq, h_freq / nyq))

    # get butter coeffs
    sos = butter(N=kwargs.get("order", 8), Wn=Wn, btype=btype, output="sos")
    return sosfiltfilt(sos, x, axis=-1)
[docs]
def filter(
    ts: np.ndarray,
    fs: float,
    low_freq: float,
    high_freq: float,
    l_trans_bandwidth: str = "auto",
    h_trans_bandwidth: str = "auto",
    **kwargs,
):
    """
    Filter time series data with `mne.filter.filter_data` when `mne` is
    importable, otherwise with the scipy based fallback
    `scipy_iir_filter_data`.

    The filter type depends on the cut-off frequencies:
        - low-pass (low_freq is None, high_freq is not None),
        - high-pass (high_freq is None, low_freq is not None),
        - band-pass (low_freq < high_freq),
        - band-stop (low_freq > high_freq)

    Parameters
    ----------
    ts: np.ndarray
        Time series data; time has to be the last axis.
    fs : float
        sampling frequency of the data, forwarded as `sfreq`.
    low_freq : float|None
        frequency below which to filter the data.
    high_freq : float|None
        frequency above which to filter the data.
    l_trans_bandwidth : float|str
        transition band width for low frequency
    h_trans_bandwidth : float|str
        transition band width for high frequency
    kwargs : extra keywords forwarded to the filtering backend, e.g. for
        mne:
        filter_length="auto",
        method="fir",
        iir_params=None
        phase="zero",
        fir_window="hamming",
        fir_design="firwin"

    Returns
    -------
    np.ndarray
        filtered data
    """
    try:
        from mne.filter import filter_data as backend
    except ImportError:
        backend = scipy_iir_filter_data
        logging.warning(
            "`mne` module not found, falling back to basic scipy's function"
        )

    # Both backends take the same keyword names.
    return backend(
        ts,
        sfreq=fs,
        l_freq=low_freq,
        h_freq=high_freq,
        l_trans_bandwidth=l_trans_bandwidth,
        h_trans_bandwidth=h_trans_bandwidth,
        **kwargs,
    )
[docs]
def posterior_shrinkage(
    prior_samples: Union[Tensor, np.ndarray], post_samples: Union[Tensor, np.ndarray]
) -> Tensor:
    """
    Calculate the posterior shrinkage, quantifying how much
    the posterior distribution contracts from the initial
    prior distribution.

    Computed per parameter as ``1 - (post_std / prior_std) ** 2``.

    References:
    https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    prior_samples : array_like or torch.Tensor [n_samples, n_params]
        Samples from the prior distribution. A 1d input is treated as
        samples of a single parameter.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distribution. A 1d input is treated as
        samples of a single parameter.

    Returns
    -------
    shrinkage : torch.Tensor [n_params]
        The posterior shrinkage.

    Raises
    ------
    ImportError
        If `torch` is not installed.
    ValueError
        If either set of samples is empty.
    """
    # torch is an optional dependency of this module; fail with a clear
    # message instead of a NameError when it is missing.
    try:
        import torch
    except ImportError as err:
        raise ImportError("`torch` is required for posterior_shrinkage.") from err

    if len(prior_samples) == 0 or len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    def _as_float_2d(samples):
        # Convert to a tensor torch.std can handle, shaped [n_samples, n_params].
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples, dtype=torch.float32)
        elif not (samples.is_floating_point() or samples.is_complex()):
            # torch.std rejects integer/bool tensors
            samples = samples.to(torch.float32)
        if samples.ndim == 1:
            samples = samples[:, None]
        return samples

    prior_std = torch.std(_as_float_2d(prior_samples), dim=0)
    post_std = torch.std(_as_float_2d(post_samples), dim=0)
    return 1 - (post_std / prior_std) ** 2
[docs]
def posterior_zscore(
    true_theta: Union[Tensor, np.ndarray, float],
    post_samples: Union[Tensor, np.ndarray],
) -> Tensor:
    """
    Calculate the posterior z-score, quantifying how much the posterior
    distribution of a parameter encompasses its true value.

    Computed per parameter as ``|post_mean - true_theta| / post_std``.

    References:
    https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    true_theta : float, array-like or torch.Tensor [n_params]
        The true value of the parameters.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distributions. A 1d input is treated
        as samples of a single parameter.

    Returns
    -------
    z : Tensor [n_params]
        The z-score of the posterior distributions.

    Raises
    ------
    ImportError
        If `torch` is not installed.
    ValueError
        If `post_samples` is empty.
    """
    # torch is an optional dependency of this module; fail with a clear
    # message instead of a NameError when it is missing.
    try:
        import torch
    except ImportError as err:
        raise ImportError("`torch` is required for posterior_zscore.") from err

    if len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(true_theta, torch.Tensor):
        true_theta = torch.tensor(true_theta, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # Stay in torch: a numpy round-trip fails for tensors that require
    # grad and silently leaves the torch type system.
    true_theta = torch.atleast_1d(true_theta)
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    post_mean = torch.mean(post_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)
    return torch.abs((post_mean - true_theta) / post_std)