Batching multiple event frames

To reduce the time our GPU sits idle waiting for new data, we'll increase the batch size next. Event recordings all have different lengths, even if they differ by only a few microseconds. In a mini-batch, however, all tensors must have the same size. That is why we'll make use of a helper collate function that pads tensors with zeros, so that all the (transformed) recordings in the batch end up with the same shape.
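To make the padding idea concrete, here is a minimal sketch of such a collate function, assuming each sample is a `(frames, target)` pair with frames of shape `(time, channel, height, width)`. This is an illustration, not Tonic's actual implementation:

```python
import torch

def pad_collate(batch):
    # batch is a list of (frames, target) pairs; the time dimension
    # (first axis of frames) varies from recording to recording
    max_t = max(frames.shape[0] for frames, _ in batch)
    # allocate a zero tensor of the longest length and copy each sample in;
    # whatever is not overwritten stays zero, i.e. padding
    padded = torch.zeros(len(batch), max_t, *batch[0][0].shape[1:])
    for i, (frames, _) in enumerate(batch):
        padded[i, : frames.shape[0]] = frames
    targets = torch.tensor([target for _, target in batch])
    return padded, targets

# two toy "recordings" of different lengths (hypothetical data)
batch = [(torch.ones(3, 2, 4, 4), 0), (torch.ones(5, 2, 4, 4), 1)]
frames, targets = pad_collate(batch)
print(frames.shape)  # torch.Size([2, 5, 2, 4, 4])
```

The shorter recording is padded from 3 to 5 time bins with all-zero frames, so both samples can be stacked into one tensor.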

import tonic
import tonic.transforms as transforms
import torch
from torch.utils.data import DataLoader

torch.manual_seed(1234)

sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.ToFrame(sensor_size=sensor_size, time_window=10000)

dataset = tonic.datasets.NMNIST(
    save_to="./data", train=False, transform=frame_transform
)

dataloader_batched = DataLoader(
    dataset,
    shuffle=True,
    batch_size=10,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)

frames, targets = next(iter(dataloader_batched))

By default, the resulting tensor will be in the format (batch, time, channel, height, width).

frames.shape
torch.Size([10, 31, 2, 34, 34])
targets
tensor([6, 3, 8, 8, 7, 8, 1, 1, 7, 7])
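Because every recording was padded to the length of the longest one in the batch, shorter samples end in all-zero frames. If you need each sample's true length back, one simple sketch (assuming a frame with no events is purely padding, and using a toy tensor in place of a real batch) is:

```python
import torch

# toy padded batch: (batch, time, channel, height, width)
frames = torch.zeros(2, 5, 2, 4, 4)
frames[0, :3] = 1.0  # first sample has 3 real frames, 2 padded
frames[1, :5] = 1.0  # second sample fills the whole length

# a time bin counts as "real" if it contains any events at all
lengths = (frames.flatten(2).abs().sum(-1) > 0).sum(1)
print(lengths)  # tensor([3, 5])
```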

We can set batch_first=False in our collate class to change this behaviour, matching the convention of PyTorch RNNs.

dataloader_batched = DataLoader(
    dataset,
    shuffle=True,
    batch_size=10,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

frames, targets = next(iter(dataloader_batched))
frames.shape
torch.Size([31, 10, 2, 34, 34])
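With time as the leading dimension, stepping through the batch one frame at a time becomes a plain loop over the tensor, which is how a recurrent or spiking network typically consumes the data. A minimal sketch, using a zero tensor of the shape above and a placeholder sum in place of a real network step:

```python
import torch

frames = torch.zeros(31, 10, 2, 34, 34)  # (time, batch, channel, h, w)

outputs = []
for frame in frames:  # frame: (batch, channel, height, width)
    # placeholder "network" step: flatten each sample and sum its events
    outputs.append(frame.flatten(1).sum(1))
out = torch.stack(outputs)
print(out.shape)  # torch.Size([31, 10])
```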