Fast data loading with Tonic

When training spiking neural networks, we typically face long training times that depend on the number of time steps and the training algorithm used. Loading a sample, possibly with transforms applied, should not add to that cost. As a baseline, let's measure how long it takes to load and transform 100 NMNIST samples without any tricks.

import tonic
import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size
transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10000),
        transforms.ToFrame(sensor_size=sensor_size, n_time_bins=3),
    ]
)

dataset = tonic.datasets.NMNIST(save_to="./data", train=False, transform=transform)


def load_sample_simple():
    # Load and transform 100 samples one after another in the main process
    for i in range(100):
        frames, target = dataset[i]


%timeit -o load_sample_simple()
1.36 s ± 2.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Extrapolate to 60k samples (600 x 100) and 200 epochs; /3600 converts seconds to hours
print(
    f"Loading time for 60k samples and 200 epochs: ~{int(_.average*600*200/3600)} hours."
)
Loading time for 60k samples and 200 epochs: ~45 hours.
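The extrapolation above scales the time measured for 100 samples by 600 (to reach 60k samples) and by 200 epochs, then divides by 3600 to turn seconds into hours. A small helper makes the units explicit. It is a minimal sketch: `estimate_loading_hours` is a hypothetical name, not part of Tonic, and it assumes loading time grows linearly with the number of samples and is identical in every epoch.

```python
def estimate_loading_hours(seconds_per_100: float, n_samples: int, n_epochs: int) -> float:
    """Extrapolate a timing for 100 samples to a full training run.

    Hypothetical helper, not part of Tonic. Assumes loading time scales
    linearly with the number of samples and is the same in every epoch.
    """
    seconds_per_sample = seconds_per_100 / 100
    total_seconds = seconds_per_sample * n_samples * n_epochs
    return total_seconds / 3600  # seconds -> hours


# Using the uncached timing measured above (~1.36 s per 100 samples)
hours = estimate_loading_hours(1.36, n_samples=60_000, n_epochs=200)
print(f"~{hours:.1f} hours")  # prints "~45.3 hours"
```

Any linear extrapolation like this ignores effects such as caching, where the first epoch is slower than the ones that follow.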

Dataloaders with multiprocessing support

To speed things up, we can use a more sophisticated dataloader that supports pre-fetching, multiple worker processes, batching and more. Let's try the PyTorch DataLoader; the official PyTorch documentation lists all of its supported functionality.

from torch.utils.data import DataLoader

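# Two worker processes load and transform samples in parallel
dataloader = DataLoader(dataset, num_workers=2, shuffle=True)


def load_sample_pytorch():
    # Creates a fresh iterator (which starts the workers) and fetches only the
    # first sample, so this timing is not directly comparable to the
    # 100-sample loops in the other sections.
    return next(iter(dataloader))


%timeit load_sample_pytorch()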
99.4 ms ± 2.95 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
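Worker-based loading pays off because several samples are read and transformed at the same time while earlier ones are consumed. The standard-library sketch below illustrates only that idea; it is not how PyTorch implements its DataLoader. `load_sample` is a hypothetical function whose `time.sleep` stands in for disk I/O, and threads replace the worker processes that PyTorch uses.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def load_sample(index: int) -> int:
    """Hypothetical stand-in for reading one sample from disk.

    The sleep simulates I/O latency; it does not reflect real NMNIST timings.
    """
    time.sleep(0.01)
    return index


def load_sequential(n: int) -> list:
    # One sample after another, like indexing the dataset directly
    return [load_sample(i) for i in range(n)]


def load_parallel(n: int, num_workers: int) -> list:
    # Several workers load samples concurrently; map() preserves input order
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(load_sample, range(n)))


start = time.perf_counter()
sequential = load_sequential(50)
t_seq = time.perf_counter() - start

start = time.perf_counter()
parallel = load_parallel(50, num_workers=4)
t_par = time.perf_counter() - start

print(f"sequential: {t_seq:.2f} s, 4 workers: {t_par:.2f} s")
```

Threads are enough here because `time.sleep` releases Python's global interpreter lock; CPU-bound Python transforms generally benefit from separate processes instead, which is what `num_workers` starts. Because starting workers has a cost, `DataLoader` also accepts `persistent_workers=True` (together with `num_workers > 0`) to keep them alive across iterations over the loader.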

Dataset caching

Even with a smarter dataloader, two sources of overhead remain:

  1. We read from files that are slow to decode, for example because they use an inefficient binary format or are optimized for disk space rather than read speed.

  2. We re-apply our deterministic transform to every sample in every epoch. When working with events, we often want to preprocess them into a format that is better suited to current accelerators. There is no need to do that more than once: because the preprocessing is deterministic, it produces the same result for the same input and transform.

To address both issues, Tonic provides a DiskCachedDataset, which wraps around your dataset object of choice. The first time a sample is loaded, it applies the original transforms and saves the result to disk in an efficient and convenient format. The next time the same sample is requested, it is read from that file instead. In practice, this means that while the first epoch may be about as slow as before, subsequent epochs load much faster.
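The caching behaviour described above can be illustrated with a few lines of NumPy. This is a simplified sketch, not Tonic's implementation: `SimpleDiskCache`, `load_raw` and `transform` are hypothetical names, samples are stored as `.npy` files keyed by their index, and targets are omitted.

```python
import tempfile
from pathlib import Path

import numpy as np


class SimpleDiskCache:
    """Minimal disk cache for a deterministic per-sample transform.

    Illustrative sketch only; Tonic's DiskCachedDataset uses its own file
    format and layout. `load_raw` and `transform` are hypothetical callables.
    """

    def __init__(self, load_raw, transform, cache_dir):
        self.load_raw = load_raw
        self.transform = transform
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def __getitem__(self, index):
        path = self.cache_dir / f"{index}.npy"
        if path.exists():
            # Cache hit: skip both the slow read and the transform
            return np.load(path)
        # Cache miss: read and transform once, then persist the result
        result = self.transform(self.load_raw(index))
        np.save(path, result)
        return result


calls = {"transform": 0}


def load_raw(index):
    # Hypothetical raw sample
    return np.arange(10) + index


def transform(x):
    # Hypothetical deterministic transform that counts its invocations
    calls["transform"] += 1
    return x * 2


with tempfile.TemporaryDirectory() as tmp:
    cached = SimpleDiskCache(load_raw, transform, tmp)
    for epoch in range(3):
        samples = [cached[i] for i in range(5)]
    print(calls["transform"])  # prints 5: once per sample, not once per epoch
```

Because this sketch keys the cache only by sample index, changing the deterministic transform would require deleting the cache directory to avoid returning stale results.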

from tonic import DiskCachedDataset

cached_dataset = DiskCachedDataset(dataset, cache_path="./cache/fast_dataloading")
cached_dataloader = DataLoader(cached_dataset, num_workers=2)


def load_sample_cached():
    # Iterate over the first 100 samples; once cached, they are read from disk
    for i, (frames, target) in enumerate(cached_dataloader):
        if i >= 99:
            break


%timeit -o -r 20 load_sample_cached()
267 ms ± 6.33 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)
# Extrapolate to 60k samples (600 x 100) and 200 epochs; /3600 converts seconds to hours
print(
    f"Loading time for 60k samples and 200 epochs with cache: ~{int(_.average*600*200/3600)} hours."
)
Loading time for 60k samples and 200 epochs with cache: ~8 hours.

Augmentations on top of disk-cached data

If we also want to apply stochastic transforms, we can pass a second set of transforms to the DiskCachedDataset. These are applied to each sample after it has been read from the cache, so their results are not cached and differ on every load. In the following example, we convert our cached samples (which are already frames) to tensors and then rotate the whole recording by a random angle between -30 and 30 degrees.

import torch
import torchvision

# Stochastic transforms, applied after each sample is read from the cache
augmentation = tonic.transforms.Compose(
    [torch.tensor, torchvision.transforms.RandomRotation([-30, 30])]
)
augmented_dataset = DiskCachedDataset(
    dataset, cache_path="./cache/fast_dataloading2", transform=augmentation
)
augmented_dataloader = DataLoader(augmented_dataset, num_workers=2)


def load_sample_augmented():
    # Iterate over the first 100 samples; rotations are drawn anew on every load
    for i, (frames, target) in enumerate(augmented_dataloader):
        if i >= 99:
            break


%timeit -r 20 load_sample_augmented()
319 ms ± 5.86 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)
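The key design choice is that the deterministic stage (denoising and framing) sits before the cache and the random stage after it. A minimal NumPy sketch of the random stage follows; it illustrates the idea under stated assumptions rather than reproducing Tonic or torchvision code: `random_rotate_recording` is a hypothetical helper, `np.rot90` replaces `RandomRotation` (so only quarter turns are possible), and the frame layout (time bins, polarities, height, width) is assumed.

```python
import numpy as np


def random_rotate_recording(frames: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Rotate all frames of a recording by the same random multiple of 90 degrees.

    Hypothetical helper: np.rot90 stands in for torchvision's RandomRotation,
    which samples arbitrary angles. One k is drawn per call, so every time bin
    and polarity channel is rotated identically.
    """
    k = int(rng.integers(0, 4))  # 0, 1, 2 or 3 quarter turns
    # Rotate in the plane of the last two (spatial) axes; copy so that later
    # in-place edits cannot alter the cached array through a view
    return np.rot90(frames, k=k, axes=(-2, -1)).copy()


# Stand-in for one cached sample: (time bins, polarities, height, width),
# an assumed layout mirroring n_time_bins=3 on a 34 x 34 sensor
cached_frames = np.random.default_rng(1).integers(0, 5, size=(3, 2, 34, 34))

rng = np.random.default_rng(0)
# Each "epoch" reads the same cached frames but draws a fresh rotation
epoch_views = [random_rotate_recording(cached_frames, rng) for _ in range(3)]
```

Drawing one rotation per recording rather than per frame keeps the digit's motion coherent across time bins, mirroring how `RandomRotation` treats the whole tensor in the example above.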