# Source code for tonic.prototype.slicers

from dataclasses import dataclass

import numpy as np
import torchdata.datapipes.iter as pipes
from torchdata.datapipes import functional_datapipe


@dataclass
@functional_datapipe("slice_by_time")
class SliceByTime(pipes.IterDataPipe):
    """Slices event arrays along a fixed time window and overlap size.

    The number of bins depends on the length of the recording. Only works on
    numpy structured event arrays that contain a 't' or 'ts' field ('t' is
    used when both are present). Windows are half-open, ``[start, start + dt)``:
    an event lying exactly on a window's end time belongs to the next window.
    At least one slice is produced per recording, even when the recording is
    shorter than ``dt``.

    ::

        |<------ dt ------->|
        |      window1      |
                   |      window2      |
                   |<-ovlp->|

    Parameters:
        source_dp (IterDataPipe): datapipe yielding numpy structured event
            arrays. Timestamps must be sorted ascending, because window bounds
            are located with ``np.searchsorted``.
        dt (float): length of each time window (same unit as event timestamps).
        overlap (float): overlap between consecutive windows (same unit as
            event timestamps). Must be smaller than ``dt``.
        include_incomplete (bool): include the last incomplete slice that has
            shorter time.

    Raises:
        ValueError: if ``dt - overlap <= 0``, or if an event array has neither
            a 't' nor a 'ts' field.
    """

    source_dp: pipes.IterDataPipe
    dt: float
    overlap: float = 0.0
    include_incomplete: bool = False

    @staticmethod
    def _timestamps(events):
        """Return the 't' (preferred) or 'ts' column of a structured event array."""
        # dtype.names is None for non-structured arrays; treat as "no fields".
        names = events.dtype.names or ()
        for field in ("t", "ts"):
            if field in names:
                return events[field]
        raise ValueError(
            f"Event array needs a 't' or 'ts' field, got fields {names}."
        )

    def __iter__(self):
        stride = self.dt - self.overlap
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if stride <= 0:
            raise ValueError(
                "Inferred stride <= 0. Increase dt or decrease overlap."
            )
        rounding_fn = np.ceil if self.include_incomplete else np.floor

        for events in self.source_dp:
            t = self._timestamps(events)
            if len(t) == 0:
                # No timestamps to anchor windows on (t[-1] would raise);
                # keep the one-slice-per-recording rule with an empty slice.
                yield events[:0]
                continue

            n_slices = int(rounding_fn(((t[-1] - t[0]) - self.dt) / stride) + 1)
            n_slices = max(n_slices, 1)  # for windows longer than the recording

            window_start_times = np.arange(n_slices) * stride + t[0]
            window_end_times = window_start_times + self.dt
            # Default side='left' returns the first index with t >= bound,
            # giving half-open [start, end) windows.
            indices_start = np.searchsorted(t, window_start_times)
            indices_end = np.searchsorted(t, window_end_times)
            for start, end in zip(indices_start, indices_end):
                yield events[start:end]
@dataclass
@functional_datapipe("slice_by_event_count")
class SliceByEventCount(pipes.IterDataPipe):
    """Slices event arrays along a fixed number of events and overlap size.

    The number of bins depends on the amount of events in the recording. Only
    works on numpy event arrays. A recording with fewer than ``n`` events
    yields a single slice containing all of its events.

    Parameters:
        source_dp (IterDataPipe): datapipe yielding numpy event arrays.
        n (int): number of events for each bin.
        overlap (int): overlap in number of events between consecutive bins.
            Must be smaller than ``n``.
        include_incomplete (bool): include the last incomplete slice that has
            fewer events.

    Raises:
        ValueError: if ``n - overlap <= 0``.
    """

    source_dp: pipes.IterDataPipe
    n: int
    overlap: int = 0
    include_incomplete: bool = False

    def __iter__(self):
        stride = self.n - self.overlap
        # ValueError subclasses Exception, so callers catching Exception still work.
        if stride <= 0:
            raise ValueError(
                "Inferred stride <= 0. Increase n or decrease overlap."
            )
        rounding_fn = np.ceil if self.include_incomplete else np.floor

        for events in self.source_dp:
            n_events = len(events)
            # Clamp so short recordings still produce one (shorter) slice.
            event_count = min(self.n, n_events)
            n_slices = int(rounding_fn((n_events - event_count) / stride) + 1)
            for start in (np.arange(n_slices) * stride).astype(int):
                yield events[start : start + event_count]