Batching multiple event frames
To reduce the time our GPU sits idle waiting for new data, we’ll increase the batch size next. Event recordings all have slightly different lengths, even if they differ by only a few microseconds. All tensors in a mini-batch must have the same size, which is why we use a helper collate function that pads the (transformed) recordings with zeros so that every tensor in the batch ends up with the same shape.
import tonic
import tonic.transforms as transforms
import torch
from torch.utils.data import DataLoader
torch.manual_seed(1234)
sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.ToFrame(sensor_size=sensor_size, time_window=10000)
dataset = tonic.datasets.NMNIST(
    save_to="./data", train=False, transform=frame_transform
)
dataloader_batched = DataLoader(
    dataset,
    shuffle=True,
    batch_size=10,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)
frames, targets = next(iter(dataloader_batched))
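To see what such a collate function does conceptually, here is a minimal sketch of zero-padding collation. The pad_collate helper below is hypothetical and not Tonic’s actual PadTensors implementation; it assumes each sample is a (frames, target) pair whose frames array has shape (time, channel, height, width).
def pad_collate(batch):
    # Hypothetical sketch of zero-padding collation, not Tonic's PadTensors.
    frames, targets = zip(*batch)
    max_time = max(f.shape[0] for f in frames)  # length of the longest recording in the batch
    padded = torch.zeros(len(frames), max_time, *frames[0].shape[1:])
    for i, f in enumerate(frames):
        # copy the recording and leave the remaining time steps as zeros
        padded[i, : f.shape[0]] = torch.as_tensor(f, dtype=torch.float)
    return padded, torch.tensor(targets)
In practice, tonic.collation.PadTensors takes care of this for us.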
By default, the resulting tensor will be in the format (batch, time, channel, height, width).
frames.shape
targets
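Because shorter recordings are padded with all-zero frames at the end, we can get a rough idea of each recording’s original length by counting the time steps that still contain events. This is only an illustrative check (time windows without any events inside a recording also count as zero), not something the tutorial requires.
# Count time steps with at least one event per sample; trailing padded
# steps are all zeros, so shorter recordings yield smaller counts.
nonzero_steps = (frames.flatten(2).sum(dim=2) != 0).sum(dim=1)
nonzero_steps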
We can set batch_first=False in our collate class to change this behaviour, similar to PyTorch’s RNN modules.
dataloader_batched = DataLoader(
    dataset,
    shuffle=True,
    batch_size=10,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)
frames, targets = next(iter(dataloader_batched))
frames.shape
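A time-major layout is convenient when a model consumes the recording one time step at a time, as recurrent or spiking networks do. The sketch below uses a hypothetical placeholder network purely to show the loop structure; it is not part of Tonic or this tutorial’s model.
import torch.nn as nn
# Placeholder feed-forward layer, only to illustrate iterating over time steps.
net = nn.Sequential(nn.Flatten(), nn.Linear(2 * 34 * 34, 10))
outputs = []
for frame in frames:  # with batch_first=False, frame has shape (batch, channel, height, width)
    outputs.append(net(frame.float()))
outputs = torch.stack(outputs)  # shape (time, batch, 10)
outputs.shape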