# Source code for tmeasures.pytorch.transformations
import abc
from typing import Iterable, List, Tuple

import torch

from tmeasures import Transformation, TransformationSet
class PyTorchTransformationSet(TransformationSet, Iterable[PyTorchTransformation]):
    """A ``TransformationSet`` whose members are ``PyTorchTransformation`` objects.

    Iterating the set yields its transformations. Each one exposes its
    parameters as a tensor through ``parameters()``.
    """

    def parameter_range(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the element-wise minimum and maximum of all members' parameters.

        The ``parameters()`` tensor of every transformation in the set is
        stacked along a new leading dimension, then reduced over that
        dimension. Every member must therefore return a tensor of the same
        shape.

        Returns:
            A ``(minimum, maximum)`` pair of tensors. Each has the shape of a
            single member's ``parameters()`` tensor.

        Raises:
            RuntimeError: if the set is empty (there is nothing to reduce),
                or if the members' parameter tensors differ in shape
                (raised by ``torch.stack``).
        """
        per_transformation = [t.parameters() for t in self]
        if not per_transformation:
            # torch.stack would raise a RuntimeError here anyway. Keep the
            # same exception type, but report the actual cause.
            raise RuntimeError(
                "parameter_range() requires a non-empty transformation set"
            )
        parameters = torch.stack(per_transformation, dim=0)
        # Tensor.min/max with `dim` return (values, indices); only values are needed.
        mi, _ = parameters.min(dim=0)
        ma, _ = parameters.max(dim=0)
        return mi, ma
class IdentityTransformationSet(PyTorchTransformationSet):
    """Transformation set built from exactly one ``IdentityTransformation``.

    The constructor takes no arguments. It passes a one-element list holding
    a fresh ``IdentityTransformation`` to the base-class constructor.
    """

    def __init__(self):
        sole_member = IdentityTransformation()
        super().__init__([sole_member])