Source code for tonic.slicers

from dataclasses import dataclass
from typing import Any, List, Tuple

import numpy as np
from typing_extensions import Protocol, runtime_checkable


[docs]@runtime_checkable class Slicer(Protocol): """Base protocol class for slicers in Tonic. That means that you don't have to directly inherit from it, but just implement its methods. """
[docs] def get_slice_metadata(self, data: Any, targets: Any) -> List[Tuple[Any]]: """This method returns the metadata for each recording that helps with slicing, for example the indices or timestamps at which the data would be sliced. The return value is typically a list of tuples that contain start and stop information for each slice. Parameters: data: Normally a tuple of data pieces. target: Normally a tuple of target pieces. Returns: metadata as a list of tuples of start and end indices, timestamps, etc. """ ...
[docs] def slice_with_metadata(self, data: Any, targets: Any, metadata: Any) -> List[Any]: """Given a piece of data and/or targets, cut out a certain part of it based on the start/end information given in metadata. Parameters: data: Normally a tuple of data pieces. target: Normally a tuple of target pieces. metadata: An array that contains start and stop information about one slice. Returns: A subset of the original data/targets which is a slice. """ ...
[docs] def slice(self, data: Any, targets: Any) -> List[Any]: """Generate metadata and return all slices at once. Parameters: data: Normally a tuple of data pieces. target: Normally a tuple of target pieces. Returns: The whole data and targets sliced into smaller slices. """ ...
[docs]@dataclass(frozen=True) class SliceByTime: """Slices an event array along fixed time window and overlap size. The number of bins depends on the length of the recording. Targets are copied. > <overlap> >| window1 | > | window2 | Parameters: time_window (int): time for window length (same unit as event timestamps) overlap (int): overlap (same unit as event timestamps) include_incomplete (bool): include the last incomplete slice that has shorter time """ time_window: float overlap: float = 0.0 include_incomplete: bool = False
[docs] def slice(self, data: np.ndarray, targets: int) -> List[np.ndarray]: metadata = self.get_slice_metadata(data, targets) return self.slice_with_metadata(data, targets, metadata)
[docs] def get_slice_metadata( self, data: np.ndarray, targets: int ) -> List[Tuple[int, int]]: t = data["t"] stride = self.time_window - self.overlap assert stride > 0 if self.include_incomplete: n_slices = int(np.ceil(((t[-1] - t[0]) - self.time_window) / stride) + 1) else: n_slices = int(np.floor(((t[-1] - t[0]) - self.time_window) / stride) + 1) n_slices = max(n_slices, 1) # for strides larger than recording time window_start_times = np.arange(n_slices) * stride + t[0] window_end_times = window_start_times + self.time_window indices_start = np.searchsorted(t, window_start_times)[:n_slices] indices_end = np.searchsorted(t, window_end_times)[:n_slices] return list(zip(indices_start, indices_end))
[docs] @staticmethod def slice_with_metadata( data: np.ndarray, targets: int, metadata: List[Tuple[int, int]] ): return [data[start:end] for start, end in metadata], targets
[docs]@dataclass(frozen=True) class SliceByTimeBins: """ Slices data and targets along fixed number of bins of time length time_duration / bin_count * (1 + overlap). This method is good if your recordings all have roughly the same time length and you want an equal number of bins for each recording. Targets are copied. Parameters: bin_count (int): number of bins overlap (float): overlap specified as a proportion of a bin, needs to be smaller than 1. An overlap of 0.1 signifies that the bin will be enlarged by 10%. Amount of bins stays the same. """ bin_count: int overlap: float = 0
[docs] def slice(self, data: np.ndarray, targets: int) -> List[np.ndarray]: metadata = self.get_slice_metadata(data, targets) return self.slice_with_metadata(data, targets, metadata)
[docs] def get_slice_metadata( self, data: np.ndarray, targets: int ) -> List[Tuple[int, int]]: events = data assert "t" in events.dtype.names assert self.overlap < 1 times = events["t"] time_window = (times[-1] - times[0]) // self.bin_count * (1 + self.overlap) stride = time_window * (1 - self.overlap) window_start_times = np.arange(self.bin_count) * stride + times[0] window_end_times = window_start_times + time_window indices_start = np.searchsorted(times, window_start_times) indices_end = np.searchsorted(times, window_end_times) return list(zip(indices_start, indices_end))
[docs] @staticmethod def slice_with_metadata( data: np.ndarray, targets: int, metadata: List[Tuple[int, int]] ): return [data[start:end] for start, end in metadata], targets
[docs]@dataclass(frozen=True) class SliceByEventCount: """Slices data and targets along a fixed number of events and overlap size. The number of bins depends on the amount of events in the recording. Targets are copied. Parameters: event_count (int): number of events for each bin overlap (int): overlap in number of events include_incomplete (bool): include the last incomplete slice that has fewer events """ event_count: int overlap: int = 0 include_incomplete: bool = False
[docs] def slice(self, data: np.ndarray, targets: int) -> List[np.ndarray]: metadata = self.get_slice_metadata(data, targets) return self.slice_with_metadata(data, targets, metadata)
[docs] def get_slice_metadata( self, data: np.ndarray, targets: int ) -> List[Tuple[int, int]]: n_events = len(data) event_count = min(self.event_count, n_events) stride = self.event_count - self.overlap if stride <= 0: raise Exception("Inferred stride <= 0") if self.include_incomplete: n_slices = int(np.ceil((n_events - event_count) / stride) + 1) else: n_slices = int(np.floor((n_events - event_count) / stride) + 1) indices_start = (np.arange(n_slices) * stride).astype(int) indices_end = indices_start + event_count return list(zip(indices_start, indices_end))
[docs] @staticmethod def slice_with_metadata( data: np.ndarray, targets: int, metadata: List[Tuple[int, int]] ): return [data[start:end] for start, end in metadata], targets
[docs]@dataclass(frozen=True) class SliceByEventBins: """ Slices an event array along fixed number of bins that each have n_events // bin_count * (1 + overlap) events. This slicing method is good if you recordings have all roughly the same amount of overall activity in the scene and you want an equal number of bins for each recording. Targets are copied. Parameters: bin_count (int): number of bins overlap (float): overlap in proportion of a bin, needs to be smaller than 1. An overlap of 0.1 signifies that the bin will be enlarged by 10%. Amount of bins stays the same. """ bin_count: int overlap: float = 0
[docs] def slice(self, data: np.ndarray, targets: int) -> List[np.ndarray]: metadata = self.get_slice_metadata(data, targets) return self.slice_with_metadata(data, targets, metadata)
[docs] def get_slice_metadata( self, data: np.ndarray, targets: int ) -> List[Tuple[int, int]]: n_events = len(data) spike_count = int(n_events // self.bin_count * (1 + self.overlap)) stride = int(spike_count * (1 - self.overlap)) indices_start = np.arange(self.bin_count) * stride indices_end = indices_start + spike_count return list(zip(indices_start, indices_end))
[docs] @staticmethod def slice_with_metadata( data: np.ndarray, targets: int, metadata: List[Tuple[int, int]] ): return [data[start:end] for start, end in metadata], targets
[docs]@dataclass class SliceAtIndices: """Slices data at the specified event indices. Targets are copied. Parameters: start_indices (list): List of start indices end_indices (list): List of end indices (exclusive) """ start_indices: np.ndarray end_indices: np.ndarray
[docs] def slice(self, data: np.ndarray, targets: int) -> List[np.ndarray]: metadata = self.get_slice_metadata(data, targets) return self.slice_with_metadata(data, targets, metadata)
[docs] def get_slice_metadata( self, data: np.ndarray, targets: int ) -> List[Tuple[int, int]]: return list(zip(self.start_indices, self.end_indices))
[docs] @staticmethod def slice_with_metadata( data: np.ndarray, targets: int, metadata: List[Tuple[int, int]] ): return [data[start:end] for start, end in metadata], targets
[docs]@dataclass class SliceAtTimePoints: """Slice the data at the specified time points. Parameters: tw_start (list): List of start times tw_end (list): List of end times """ start_tw: np.ndarray end_tw: np.ndarray
[docs] def slice(self, data: np.ndarray, targets: int) -> List[np.ndarray]: metadata = self.get_slice_metadata(data, targets) return self.slice_with_metadata(data, targets, metadata)
[docs] def get_slice_metadata( self, data: np.ndarray, targets: int ) -> List[Tuple[int, int]]: t = data["t"] indices_start = np.searchsorted(t, self.start_tw) indices_end = np.searchsorted(t, self.end_tw) return list(zip(indices_start, indices_end))
[docs] @staticmethod def slice_with_metadata( data: np.ndarray, targets: int, metadata: List[Tuple[int, int]] ): return [data[start:end] for start, end in metadata], targets
[docs]def slice_events_by_time( events: np.ndarray, time_window: int, overlap: int = 0, include_incomplete: bool = False, ): return SliceByTime( time_window=time_window, overlap=overlap, include_incomplete=include_incomplete ).slice(events, None)[0]
[docs]def slice_events_by_time_bins(events: np.ndarray, bin_count: int, overlap: float = 0.0): return SliceByTimeBins(bin_count=bin_count, overlap=overlap).slice(events, None)[0]
[docs]def slice_events_by_count( events: np.ndarray, event_count: int, overlap: int = 0, include_incomplete: bool = False, ): return SliceByEventCount( event_count=event_count, overlap=overlap, include_incomplete=include_incomplete ).slice(events, None)[0]
[docs]def slice_events_by_event_bins( events: np.ndarray, bin_count: int, overlap: float = 0.0 ): return SliceByEventBins(bin_count=bin_count, overlap=overlap).slice(events, None)[0]
[docs]def slice_events_at_indices(events: np.ndarray, start_indices, end_indices): return SliceAtIndices(start_indices=start_indices, end_indices=end_indices).slice( events, None )[0]
[docs]def slice_events_at_timepoints( events: np.ndarray, start_tw: np.ndarray, end_tw: np.ndarray ) -> List[np.ndarray]: return SliceAtTimePoints(start_tw=start_tw, end_tw=end_tw).slice(events, None)[0]