# Source code for tonic.prototype.datasets.stmnist

import os
import zipfile
from pathlib import Path
from typing import Any, BinaryIO, Callable, Iterator, Optional, Tuple, Union

import numpy as np
from scipy.io import loadmat
from torchdata.datapipes.iter import (
    FileLister,
    FileOpener,
    Filter,
    IterDataPipe,
    Mapper,
    Saver,
    ZipArchiveLoader,
)

from .utils._dataset import Dataset, Sample
from .utils._utils import check_sha256


class STMNISTFileReader(IterDataPipe[Sample]):
    """Datapipe turning ST-MNIST ``.mat`` file paths into ``(events, target)`` samples.

    Parameters:
        dp: upstream datapipe yielding paths (``str``) of ``.mat`` files. The
            name of the directory containing each file must be the integer
            class label.
        sensor_size: ``(width, height, polarities)`` of the tactile sensor
            array. Used to decode flat taxel addresses into ``(x, y)``.
        dtype: structured dtype with ``x``, ``y``, ``t`` and ``p`` fields,
            used for the returned event arrays.
    """

    def __init__(
        self,
        dp: IterDataPipe[str],
        sensor_size: Optional[Tuple[int, int, int]] = (10, 10, 2),
        dtype: Optional[np.dtype] = np.dtype(
            [("x", int), ("y", int), ("t", int), ("p", int)]
        ),
    ) -> None:
        self.dp = dp
        self.dtype = dtype
        self.sensor_size = sensor_size

    def __iter__(self) -> Iterator[Sample]:
        """Yield ``(events, target)`` for every file path produced by ``dp``."""
        for fname in self.dp:
            yield (
                self._mat_to_array(fname),
                self._get_target(fname),
            )

    def _get_target(self, fname: str) -> int:
        """Return the class label, i.e. the name of the file's parent directory."""
        # Path understands both "/" and the native separator, unlike a plain
        # split on os.sep (which misses "/"-separated paths on Windows).
        return int(Path(fname).parent.name)

    def _mat_to_array(self, f) -> np.ndarray:
        """Load a ``.mat`` file and convert its ``spiketrain`` into a structured event array."""
        # Transposing since the order is (address, event),
        # but we like (event, address).
        mat = loadmat(f)
        spiketrain = mat["spiketrain"].T
        # Separating coordinates and timestamps: the last column holds the
        # timestamps, the remaining columns the per-address spike values.
        spikes, timestamps = spiketrain[:, :-1], spiketrain[:, -1]
        # Getting events addresses.
        # First entry -> Event number.
        # Second entry -> Event address in [0,100).
        events_nums, events_addrs = spikes.nonzero()
        # Mapping flat addresses to 2D coordinates in row-major order:
        # x = address % width, y = address // width. Both use the width
        # (sensor_size[0]); dividing by the height would be wrong for
        # non-square sensors.
        width = self.sensor_size[0]
        events = np.zeros((len(events_nums)), dtype=self.dtype)
        events["x"] = events_addrs % width
        events["y"] = events_addrs // width
        # Converting floating point seconds to integer microseconds.
        events["t"] = (timestamps[events_nums] * 1e6).astype(int)
        # Converting -1 polarities to 0.
        events["p"] = np.maximum(spikes[(events_nums, events_addrs)], 0).astype(int)
        return events
class STMNIST(Dataset):
    """`ST-MNIST <https://arxiv.org/abs/2005.04319>`_

    Neuromorphic Spiking Tactile MNIST (ST-MNIST) dataset, which comprises handwritten
    digits obtained by human participants writing on a neuromorphic tactile sensor array.
    The original paper can be found at https://arxiv.org/abs/2005.04319. Data is provided
    in the MAT format.

    The compressed dataset has to be downloaded by the user from
    https://scholarbank.nus.edu.sg/bitstream/10635/168106/2/STMNIST%20dataset%20NUS%20Tee%20Research%20Group.zip,
    where a form has to be completed. Then, the path to the folder containing the ZIP
    archive has to be provided as the STMNIST constructor's ``root`` argument.

    Events have (xytp) ordering.

    Parameters:
        root (string): Parent folder of 'STMNIST/STMNIST dataset NUS Tee Research Group.zip'.
            The STMNIST folder is named after the Tonic class and is currently required.
        keep_compressed (bool): If True, the archive is not extracted. Note that the
            datapipe only reads already-extracted ``.mat`` files below the root.
        skip_sha256_check (bool): If False, verify the archive's SHA256 checksum before
            doing anything else.
        shuffle (bool): Whether to shuffle the dataset. Shuffling is done on file paths,
            which is cheaper than shuffling loaded samples.

    Returns:
        dp (IterDataPipe[Sample]): Torchdata data pipe that yields a tuple of events
        (or transformed events) and target.
    """

    _DTYPE = np.dtype([("x", int), ("y", int), ("t", int), ("p", int)])
    _SHA256 = "825bb5a64753fff4a2a2c32e3497fa8a951d9c94993e03ba25a057e17d83b884"
    _FILENAME = "STMNIST dataset NUS Tee Research Group.zip"
    # Number of samples in the extracted dataset; used by both the
    # existence check and __len__.
    _NUM_SAMPLES = 6_953
    sensor_size = dict(x=10, y=10, p=2)

    def __init__(
        self,
        root: os.PathLike,
        keep_compressed: Optional[bool] = False,
        skip_sha256_check: Optional[bool] = True,
        shuffle: bool = False,
    ) -> None:
        super().__init__(
            Path(root, self.__class__.__name__),
            keep_compressed,
            skip_sha256_check,
        )
        zip_path = self._root / self._FILENAME
        if not skip_sha256_check:
            check_sha256(fpath=zip_path, sha256_provided=self._SHA256)
        if not self._check_exists():
            if not os.path.isfile(zip_path):
                # Raised explicitly rather than via ``assert`` so the check is
                # not stripped under ``python -O``; the exception type is kept
                # as AssertionError for backward compatibility.
                raise AssertionError(
                    "Error: root must point to parent folder of STMNIST/STMNIST dataset NUS Tee Research Group.zip."
                )
            if not keep_compressed:
                with zipfile.ZipFile(zip_path, "r") as zip_file:
                    zip_file.extractall(self._root)
        self.shuffle = shuffle

    def _check_exists(self) -> bool:
        """Return True if at least ``_NUM_SAMPLES`` sample files are present below the root."""
        dp = FileLister(str(self._root), recursive=True).filter(self._filter)
        return len(list(dp)) >= self._NUM_SAMPLES

    def __len__(self) -> int:
        return self._NUM_SAMPLES

    def _filter(self, fname: str) -> bool:
        """Keep ``.mat`` sample files, skipping any path containing "LUT"."""
        # Test only the part of the path below the dataset root. Otherwise a
        # root directory whose name happens to contain "LUT" would filter out
        # every sample.
        rel_path = os.path.relpath(fname, self._root)
        return rel_path.endswith(".mat") and ("LUT" not in rel_path)

    def _datapipe(self) -> IterDataPipe[Sample]:
        """Build the datapipe: list files, optionally shuffle, filter, and decode."""
        dp = FileLister(str(self._root), recursive=True)
        if self.shuffle:
            # Shuffling paths is cheaper than shuffling decoded samples.
            dp = dp.shuffle(buffer_size=10_000)
        dp = Filter(dp, self._filter)
        dp = STMNISTFileReader(
            dp,
            sensor_size=tuple(self.sensor_size.values()),
            dtype=self._DTYPE,
        )
        return dp