import os
from typing import Callable, Optional
import numpy as np
from tonic.dataset import Dataset
from tonic.io import read_mnist_file
class NMNIST(Dataset):
    """`N-MNIST <https://www.garrickorchard.com/datasets/n-mnist>`_

    Events have (xytp) ordering.
    ::

        @article{orchard2015converting,
          title={Converting static image datasets to spiking neuromorphic datasets using saccades},
          author={Orchard, Garrick and Jayawant, Ajinkya and Cohen, Gregory K and Thakor, Nitish},
          journal={Frontiers in neuroscience},
          volume={9},
          pages={437},
          year={2015},
          publisher={Frontiers}
        }

    Parameters:
        save_to (string): Location to save files to on disk.
        train (bool): If True, uses training subset, otherwise testing subset.
        first_saccade_only (bool): If True, only work with events of the first of three saccades.
                                   Results in about a third of the events overall.
        stabilize (bool): If True, it stabilizes egomotion of the saccades, centering the digit.
        transform (callable, optional): A callable of transforms to apply to the data.
        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
        transforms (callable, optional): A callable of transforms that is applied to both data and
                                         labels at the same time.
    """

    # Download locations, archive names and checksums for the two subsets.
    base_url = "https://data.mendeley.com/public-files/datasets/468j46mzdv/files/"
    train_url = base_url + "39c25547-014b-4137-a934-9d29fa53c7a0/file_downloaded"
    train_filename = "train.zip"
    train_md5 = "20959b8e626244a1b502305a9e6e2031"
    train_folder = "Train"
    test_url = base_url + "05a4d654-7e03-4c15-bdfa-9bb2bcbea494/file_downloaded"
    test_filename = "test.zip"
    test_md5 = "69ca8762b2fe404d9b9bad1103e97832"
    test_folder = "Test"

    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    sensor_size = (34, 34, 2)
    dtype = np.dtype([("x", int), ("y", int), ("t", int), ("p", int)])
    ordering = dtype.names

    def __init__(
        self,
        save_to: str,
        train: bool = True,
        first_saccade_only: bool = False,
        stabilize: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(
            save_to,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        self.train = train
        self.first_saccade_only = first_saccade_only
        self.stabilize = stabilize

        # Select the archive description that matches the requested subset.
        if train:
            self.filename = self.train_filename
            self.url = self.train_url
            self.file_md5 = self.train_md5
            self.folder_name = self.train_folder
        else:
            self.filename = self.test_filename
            self.url = self.test_url
            self.file_md5 = self.test_md5
            self.folder_name = self.test_folder

        if not self._check_exists():
            self.download()

        # Index every ``.bin`` recording below the subset folder.
        # NOTE(review): ``self.data``, ``self.targets`` and
        # ``self.location_on_system`` are not set in this class; presumably the
        # ``Dataset`` base class initialises them -- confirm.
        file_path = os.path.join(self.location_on_system, self.folder_name)
        for path, dirs, files in os.walk(file_path):
            # Sorting ``dirs`` in place fixes the order in which os.walk
            # descends, so sample indices are reproducible across file systems.
            dirs.sort()
            files.sort()
            for file in files:
                # Same extension as the one counted in ``_check_exists``.
                if file.endswith(".bin"):
                    self.data.append(os.path.join(path, file))
                    # The containing directory is named after the digit class.
                    label_number = int(os.path.basename(path))
                    self.targets.append(label_number)

    def __getitem__(self, index):
        """
        Returns:
            a tuple of (events, target) where target is the index of the target
            class.
        """
        events = read_mnist_file(self.data[index], dtype=self.dtype)
        if self.first_saccade_only:
            # NOTE(review): the cut-off here is ``t < 1e5`` whereas
            # ``stabilize`` treats ``t <= 105e3`` as the first saccade --
            # confirm which boundary is intended.
            events = events[events["t"] < 1e5]
        if self.stabilize:
            # ``self.stabilize`` is the boolean flag; the call resolves to the
            # module-level ``stabilize`` function.
            events = stabilize(events)
        target = self.targets[index]
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transforms is not None:
            events, target = self.transforms(events, target)
        return events, target

    def __len__(self) -> int:
        """Return the number of indexed recordings."""
        return len(self.data)

    def _check_exists(self) -> bool:
        """Return True if the archive is present and at least 10000 ``.bin`` files are found."""
        return (
            self._is_file_present()
            and self._folder_contains_at_least_n_files_of_type(10000, ".bin")
        )
def stabilize(events):
    """
    Stabilize digits, code ported from https://www.garrickorchard.com/datasets/n-mnist

    The recording is treated as three consecutive saccades, split at
    ``t = 105e3`` and ``t = 210e3``. For each saccade a time-dependent offset
    is subtracted from the event coordinates, the result is rounded to the
    nearest integer, and events that fall outside ``[0, 33]`` in x or y are
    discarded.

    Parameters:
        events: structured numpy array with at least the fields "x", "y" and
                "t". The array passed in is not modified.

    Returns:
        stabilized events, removing the egomotion caused by saccades. This is a
        new array that may contain fewer events than the input.
    """
    t = events["t"]
    # np.array (not np.asarray) always copies, so the in-place updates below
    # can never write through to the caller's array.
    stab_x = np.array(events["x"], dtype=np.float64)
    stab_y = np.array(events["y"], dtype=np.float64)

    # original code might result in a small offset, fixed manually
    x_off = 4
    y_off = 2

    # First saccade.
    saccade_1_index = t <= 105e3
    stab_x[saccade_1_index] = x_off + stab_x[saccade_1_index] - \
        3.5*t[saccade_1_index]/105e3
    stab_y[saccade_1_index] = y_off + stab_y[saccade_1_index] - \
        7*t[saccade_1_index]/105e3

    # Second saccade.
    saccade_2_index = (t > 105e3) & (t <= 210e3)
    stab_x[saccade_2_index] = x_off + stab_x[saccade_2_index] - \
        3.5 - 3.5*(t[saccade_2_index] - 105e3)/105e3
    stab_y[saccade_2_index] = y_off + stab_y[saccade_2_index] - \
        7 + 7*(t[saccade_2_index] - 105e3)/105e3

    # Third saccade.
    saccade_3_index = t > 210e3
    stab_x[saccade_3_index] = x_off + stab_x[saccade_3_index] - \
        7 + 7*(t[saccade_3_index] - 210e3)/105e3
    # events["y"] remains almost unchanged because it is a horizontal saccade
    stab_y[saccade_3_index] = y_off + stab_y[saccade_3_index]

    # The bounds test uses the unrounded coordinates, so an event at e.g.
    # x = -0.3 is dropped even though it would round to 0.
    nulls = (stab_x < 0) | (stab_y < 0) | (stab_x > 33) | (stab_y > 33)
    keep = ~nulls

    # Boolean indexing returns a copy; the rounded coordinates are written into
    # that copy only, leaving the input array as it was.
    stabilized = events[keep]
    stabilized["x"] = np.round(stab_x[keep]).astype(np.int64)
    stabilized["y"] = np.round(stab_y[keep]).astype(np.int64)
    return stabilized