# Source code for tonic.audio_transforms

from dataclasses import dataclass
from typing import Iterator, List, Tuple, Union

import librosa
import numpy as np
from scipy.signal import butter, sosfilt


@dataclass
class FixLength:
    """Force one axis of a sample to a fixed number of entries.

    The actual resizing is delegated to ``librosa.util.fix_length``.

    Parameters:
        length (int): Desired length of the sample along ``axis``.
        axis (int, optional): Dimension along which the length is fixed.
            Defaults to 1.

    Args:
        data (np.ndarray): data sample

    Returns:
        np.ndarray: fixed length data sample
    """

    length: int
    axis: int = 1

    def __call__(self, data: np.ndarray) -> np.ndarray:
        """Return ``data`` resized to ``self.length`` along ``self.axis``."""
        target_size, target_axis = self.length, self.axis
        return librosa.util.fix_length(data=data, size=target_size, axis=target_axis)
@dataclass
class Bin:
    """Sum neighbouring samples along one axis to reach a lower frequency.

    The axis is cut into ``int(len / (orig_freq / new_freq))`` sections with
    ``np.array_split`` and every section is summed into a single entry.

    Parameters:
        orig_freq (float): Sampling frequency of the given data stream
        new_freq (float): Desired frequency after binning
        axis (int): Axis along which the data needs to be binned

    Args:
        data (np.ndarray): data sample

    Returns:
        np.ndarray: binned data sample
    """

    orig_freq: float
    new_freq: float
    axis: int

    def __call__(self, data: np.ndarray) -> np.ndarray:
        """Return ``data`` with ``self.axis`` reduced to one entry per bin."""
        samples_per_bin = self.orig_freq / self.new_freq
        num_bins = int(data.shape[self.axis] / samples_per_bin)

        binned = []
        for chunk in np.array_split(data, num_bins, axis=self.axis):
            # keepdims so the per-bin sums can be stitched back along the axis
            binned.append(np.sum(chunk, axis=self.axis, keepdims=True))
        return np.concatenate(binned, self.axis)
@dataclass
class SOSFilter:
    """Apply a filter given as second-order sections (SOS).

    Parameters
    ----------
    coeffs:
        Second-order-section coefficients, passed as the first argument
        of ``scipy.signal.sosfilt``
    axis:
        Axis along which the filter needs to be applied

    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfilt.html for more details
    """

    coeffs: np.ndarray
    axis: int

    def __call__(self, signal):
        """Return ``signal`` filtered along ``self.axis``."""
        filtered = sosfilt(self.coeffs, signal, axis=self.axis)
        return filtered
@dataclass
class ButterFilter:
    """Butterworth filter with optional rectification of the output.

    The filter is designed once at construction time with
    ``scipy.signal.butter(..., output="sos")`` and stored in ``self.filter``.

    Parameters
    ----------
    order:
        Order of filter to be used
    freq:
        Frequency for the filter (float or (float, float))
    analog:
        True if analog filter
    btype:
        Filter type, {'lowpass', 'highpass', 'bandpass', 'bandstop'}
    rectify:
        If true, the output is the absolute value of the filtered output
    axis:
        Axis along which the filter needs to be applied

    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html#scipy.signal.butter for more details on parameters.
    """

    order: int
    freq: Union[float, Tuple[float, float]]
    analog: bool
    btype: str
    rectify: bool
    axis: int

    def __post_init__(self):
        """Design the filter and wrap its coefficients in an ``SOSFilter``."""
        sos = butter(
            self.order,
            self.freq,
            btype=self.btype,
            analog=self.analog,
            output="sos",
        )
        self.filter = SOSFilter(sos, axis=self.axis)

    def __call__(self, data: np.ndarray) -> np.ndarray:
        """Filter ``data``; return the absolute value when ``rectify`` is set."""
        filtered = self.filter(data)
        return np.abs(filtered) if self.rectify else filtered
@dataclass
class ButterFilterBank:
    """Bank of band-pass Butterworth filters applied to the same input.

    One ``ButterFilter`` with ``btype="band"`` is built per entry of ``freq``.
    Calling the bank runs every filter on the input and concatenates the
    results along axis 0.

    Parameters
    ----------
    order:
        Order of filter to be used
    freq:
        Band edges, one (low, high) pair per filter
    rectify:
        If true, the output is the absolute value of the filtered output
    axis:
        Axis along which the filter needs to be applied
    analog:
        If true, the filter will be analog. False by default

    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html#scipy.signal.butter for more details on parameters.
    """

    order: int
    freq: List[Tuple[float, float]]
    rectify: bool
    axis: int
    analog: bool = False

    def __post_init__(self):
        """Build one band-pass filter per frequency band."""
        bank = []
        for band in self.freq:
            bank.append(
                ButterFilter(
                    self.order,
                    band,
                    analog=self.analog,
                    btype="band",
                    rectify=self.rectify,
                    axis=self.axis,
                )
            )
        self.filters = bank

    def __call__(self, data):
        """Run every filter on ``data`` and stack the outputs along axis 0."""
        outputs = [band_filter(data) for band_filter in self.filters]
        return np.concatenate(outputs, axis=0)
@dataclass
class LinearButterFilterBank:
    """Butter filter bank with linearly spaced bands.

    Parameters
    ----------
    order:
        Order of filter to be used
    low_freq:
        Lower edge of the first band, in the same units as ``sampling_freq``
    sampling_freq:
        Sampling frequency of the signal; half of it (the Nyquist frequency)
        bounds the highest band and is used to normalise the band edges.
    analog:
        True if analog filter
    num_filters:
        Number of bands in the bank. Each band spans
        ``[f, f * (1 + 2 / num_filters)]``.
    rectify:
        If true, the output is the absolute value of the filtered output
    axis:
        Axis along which the filter needs to be applied

    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html#scipy.signal.butter for more details on parameters.
    """

    order: int = 2
    low_freq: float = 100
    sampling_freq: float = 16000
    analog: bool = False
    num_filters: int = 64
    rectify: bool = True
    axis: int = -1

    def compute_freq_bands(self):
        """Return a ``(num_filters, 2)`` array of Nyquist-normalised band edges.

        Column 0 holds the lower edges, spaced linearly; column 1 holds the
        matching upper edges, ``lower * (1 + 2 / num_filters)``.
        """
        filter_bandwidth = 2 / self.num_filters
        nyquist = self.sampling_freq / 2
        # Highest lower edge: chosen so that its upper edge stays below Nyquist.
        high_freq = self.sampling_freq / 2 / (1 + filter_bandwidth) - 1
        freqs = np.linspace(self.low_freq, high_freq, self.num_filters)
        return np.array([freqs, freqs * (1 + filter_bandwidth)]).T / nyquist

    def __post_init__(self):
        """Build the underlying ``ButterFilterBank`` from the computed bands."""
        freq_bands = self.compute_freq_bands()
        self.filterbank = ButterFilterBank(
            order=self.order,
            freq=freq_bands,
            rectify=self.rectify,
            axis=self.axis,
            # Forward the analog flag; it used to be dropped, so analog=True
            # silently produced a digital filter bank.
            analog=self.analog,
        )

    def __call__(self, data):
        """Apply the filter bank to ``data``."""
        return self.filterbank(data)
@dataclass
class MelButterFilterBank(LinearButterFilterBank):
    """Butter filter bank with frequencies along the mel scale.

    Parameters
    ----------
    order:
        Order of filter to be used
    low_freq:
        Lower edge of the first band, in the same units as ``sampling_freq``
    sampling_freq:
        Sampling frequency of the signal; half of it (the Nyquist frequency)
        bounds the highest band and is used to normalise the band edges.
    analog:
        True if analog filter
    num_filters:
        Number of bands in the bank
    rectify:
        If true, the output is the absolute value of the filtered output
    axis:
        Axis along which the filter needs to be applied

    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html#scipy.signal.butter for more details on parameters.
    """

    @staticmethod
    def hz2mel(freq):
        """Convert a frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + freq / 700)

    @staticmethod
    def mel2hz(freq):
        """Convert a mel-scale value back to Hz (inverse of ``hz2mel``)."""
        return 700 * (10 ** (freq / 2595) - 1)

    def compute_freq_bands(self):
        """Return a ``(num_filters, 2)`` array of Nyquist-normalised band edges.

        Lower edges are spaced evenly on the mel scale between ``low_freq``
        and the highest frequency whose upper edge still lies below Nyquist;
        each upper edge is ``lower * (1 + 2 / num_filters)``.
        """
        filter_bandwidth = 2 / self.num_filters
        nyquist = self.sampling_freq / 2
        # Space the lower edges uniformly in mel, then map them back to Hz.
        low_mel = self.hz2mel(self.low_freq)
        high_mel = self.hz2mel(self.sampling_freq / 2 / (1 + filter_bandwidth) - 1)
        freqs = self.mel2hz(np.linspace(low_mel, high_mel, self.num_filters))
        return np.array([freqs, freqs * (1 + filter_bandwidth)]).T / nyquist
@dataclass
class AddNoise:
    """Add noise to data.

    Params:
        dataset: A dataset object that supports ``len()`` and integer
            indexing and returns a tuple per item, the first element of
            which is the audio signal to be used for noise.
        snr: Desired signal to noise ratio in dB
        normed: If set to false, the signal max value will not be normalized.
            True by default.
    """

    dataset: Iterator
    snr: float
    normed: bool = True

    def get_noise_sample(self, sample_len: int) -> np.ndarray:
        """Get a random noise sample from the dataset.

        Noise recordings are indexed along axis 1, so they are expected to
        have at least two dimensions. The returned array has exactly
        ``sample_len`` entries along axis 1 when the chosen recording is
        longer than that; otherwise the recording is returned whole.
        """
        # Draw recordings at random until one is long enough.
        # NOTE(review): this never terminates if no recording in the dataset
        # has at least ``sample_len`` entries along axis 1.
        while True:
            noise_idx = np.random.randint(0, len(self.dataset), (1,)).item()
            noise = self.dataset[noise_idx][0]
            if noise.shape[1] >= sample_len:
                break

        # Sample a random part of the data recording
        noise_signal_len = noise.shape[1]
        if noise_signal_len > sample_len:
            # randint's upper bound is exclusive; +1 makes the last valid
            # window (ending at the end of the recording) reachable.
            start_t = np.random.randint(
                0, noise_signal_len - sample_len + 1, (1,)
            ).item()
            noise = noise[:, start_t : start_t + sample_len]
        return noise

    def __call__(self, signal):
        """Mix ``signal`` with a random noise excerpt at ``self.snr`` dB.

        Returns the mixture, passed through ``normalize`` when ``normed``
        is set.
        """
        # randomly pick a piece of noise data
        noise = self.get_noise_sample(sample_len=signal.shape[1])

        # mix signal with noise with given SNR
        signal_power = (signal**2).mean()
        noise_power = (noise**2).mean()
        if noise_power > 0:
            # SNR is a power ratio, while the gain below multiplies the noise
            # amplitude, hence the square root.
            noise_scale = np.sqrt(
                (signal_power / noise_power) * 10 ** (-self.snr / 10)
            )
        else:
            # Silent noise cannot be scaled to any SNR; add nothing instead
            # of dividing by zero.
            noise_scale = 0.0
        signal_with_snr = signal + noise_scale * noise

        # Normalize if specified
        if self.normed:
            return normalize(signal_with_snr)
        return signal_with_snr
def normalize(signal):
    """Centre the signal on zero and scale it to a peak magnitude of one.

    The array is modified in place and the same object is returned. A signal
    whose peak magnitude after centring is not positive is only centred.
    """
    signal -= signal.mean()
    peak = np.max(np.abs(signal))
    if not peak > 0:
        # Nothing to scale by; avoid dividing by zero.
        return signal
    signal /= peak
    return signal
# @dataclass # class DivisiveNormalization: # frame_dt: float # Frame clock step # num_frames_avg: int # Number of frames to average over # gating_clock_dt: float # Clock frequency of gating E(t) # dt: float = 1 # Global clock step # # def __call__(self, events: np.ndarray): # raise NotImplementedError