import abc
import operator

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

from .transformations import PyTorchTransformationSet


class Dataset2D(Dataset):
    """Map-style dataset whose samples live on a ``len0`` x ``len1`` grid.

    Subclasses supply the grid shape (``len0``, ``len1``), the ``T`` property
    and ``getitem2d(i, j)``.  Samples can then be fetched either with a pair
    ``dataset[i, j]`` or with a flat row-major index ``dataset[k]`` where
    ``k = i * len1 + j``.

    NOTE(review): unless ``Dataset``'s metaclass is ``abc.ABCMeta`` the
    ``abc.abstractmethod`` markers below are documentation only and are not
    enforced when a subclass is instantiated.
    """

    @property
    @abc.abstractmethod
    def len0(self):
        """Size of the first grid dimension (number of rows)."""
        pass

    @property
    @abc.abstractmethod
    def len1(self):
        """Size of the second grid dimension (number of columns per row)."""
        pass

    @property
    @abc.abstractmethod
    def T(self):
        """Abstract property defined by subclasses.

        NOTE(review): meaning is not visible here -- presumably a transposed
        view of the grid; confirm against implementations.
        """
        pass

    def len(self):
        """Return the total number of samples, ``len0 * len1``."""
        return self.len0 * self.len1

    def __len__(self):
        # Lets ``len(dataset)``, samplers and ``DataLoader`` see the flat size;
        # ``len()`` is kept as the original public spelling.
        return self.len()

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        ``idx`` may be:

        * an integer-like flat index (``int``, NumPy integer, integer tensor
          with one element) interpreted in row-major order; negative values
          count back from ``self.len()``;
        * a 1-element sequence holding such a flat index;
        * a 2-element sequence ``(i, j)``, forwarded to ``getitem2d`` as is.

        @raise IndexError: flat index outside ``[-len, len)``
        @raise ValueError: any other form of index
        """
        # Integer-like indices must be tried first: the original code called
        # ``len(idx)`` on them, which raises TypeError for a plain int.
        try:
            flat = operator.index(idx)
        except TypeError:
            flat = None
        if flat is not None:
            return self._getitem_flat(flat)

        try:
            size = len(idx)
        except TypeError:
            raise ValueError(f"Invalid index: {idx}.") from None
        if size == 2:
            i, j = idx
            return self.getitem2d(i, j)
        if size == 1:
            try:
                flat = operator.index(idx[0])
            except TypeError:
                raise ValueError(f"Invalid index: {idx}.") from None
            return self._getitem_flat(flat)
        raise ValueError(f"Invalid index: {idx}.")

    def _getitem_flat(self, flat):
        """Map a row-major flat index to ``(i, j)`` and fetch that sample."""
        total = self.len()
        k = flat + total if flat < 0 else flat
        # Raising IndexError here also terminates plain ``for x in dataset``
        # iteration, which relies on the __getitem__ sequence protocol.
        if not 0 <= k < total:
            raise IndexError(f"Index {flat} out of range for dataset of length {total}.")
        i, j = divmod(k, self.len1)
        return self.getitem2d(i, j)

    @abc.abstractmethod
    def getitem2d(self, i, j):
        """Return the sample at grid position ``(i, j)``."""
        pass

    def row_dataset(self, row: int):
        """Return a 1-D ``RowDataset`` view over row ``row`` of this grid."""
        return RowDataset(self, row)


class STDataset(Dataset2D):
    """``Dataset2D`` built on top of a map-style dataset and a transformation set.

    Stores the wrapped ``dataset``, the ``transformations`` to apply to its
    samples, and a target ``device``.  This class itself does not define
    ``len0``, ``len1``, ``T`` or ``getitem2d``.
    """

    def __init__(
        self,
        dataset: Dataset,
        transformations: PyTorchTransformationSet,
        device: torch.device = torch.device("cpu"),
    ):
        """
        @param dataset: map-style (non-iterable) dataset that samples are drawn from
        @param transformations: set of transformations to apply to samples
        @param device: torch device kept on ``self.device``; CPU by default
        @raise ValueError: if ``dataset`` is an ``IterableDataset``
        """
        # Iterable-style datasets cannot be indexed, so reject them up front.
        if isinstance(dataset, IterableDataset):
            message = (
                f"{IterableDataset} not supported; must specify a map-style dataset "
                "(https://pytorch.org/docs/stable/data.html#dataset-types)"
            )
            raise ValueError(message)
        self.dataset, self.transformations, self.device = dataset, transformations, device

    def len_dataset(self):
        """Return the number of samples in the wrapped dataset."""
        size = len(self.dataset)
        return size


class RowDataset(Dataset):
    """One row of a ``Dataset2D`` exposed as a 1-D map-style dataset.

    ``RowDataset(d, row)[k]`` returns ``d.getitem2d(row, k)`` and
    ``len(RowDataset(d, row)) == d.len1``.
    """

    def __init__(self, d: "Dataset2D", row: int):
        """
        @param d: two-dimensional dataset to take the row from
        @param row: row index passed as ``i`` to ``d.getitem2d``
        """
        self.d = d
        self.row = row

    def __getitem__(self, item):
        """Return ``self.d.getitem2d(self.row, item)``.

        Integer-like ``item`` is bounds-checked against ``len(self)``, and
        negative values count back from the end.  Because this class defines
        no ``__iter__``, plain iteration uses this method and stops only on
        ``IndexError``; the explicit check guarantees that stop instead of
        relying on ``getitem2d`` to raise.  Non-integer ``item`` (e.g. a
        slice) is forwarded unchanged.

        @raise IndexError: integer ``item`` outside ``[-len, len)``
        """
        try:
            k = operator.index(item)
        except TypeError:
            k = None
        if k is None:
            return self.d.getitem2d(self.row, item)
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError(f"Index {item} out of range for row of length {n}.")
        return self.d.getitem2d(self.row, k)

    def __len__(self):
        """Number of columns in the underlying dataset (``d.len1``)."""
        return self.d.len1