# Source code for tmeasures.pytorch.transformations

import abc
from typing import Iterable, List, Tuple

import torch

from tmeasures import Transformation, TransformationSet


class PyTorchTransformation(Transformation, torch.nn.Module):
    """A Transformation that is also a ``torch.nn.Module``.

    Concrete subclasses must provide :meth:`parameters`.
    """

    # NOTE(review): this name shadows torch.nn.Module.parameters(recurse=True),
    # which normally yields the module's learnable nn.Parameter objects. Code
    # that calls module.parameters() expecting that iterator (e.g. when building
    # an optimizer) will receive this method's return value instead.
    @abc.abstractmethod
    def parameters(self) -> torch.Tensor:
        """Return the parameters of this transformation as a tensor."""
class PyTorchTransformationSet(TransformationSet, Iterable[PyTorchTransformation]):
    """A TransformationSet whose members are iterated as PyTorchTransformation objects."""

    def parameter_range(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the element-wise minimum and maximum of the members' parameters.

        Each member's ``parameters()`` value is passed through
        ``torch.as_tensor``, so plain Python numbers (e.g. ``0``) are accepted
        as well as tensors. The results are stacked along a new leading
        dimension, and min/max are reduced over that dimension.

        Returns:
            ``(mi, ma)``: two tensors, each shaped like a single member's
            parameters.

        Raises:
            ValueError: if the set yields no transformations.
        """
        per_member = [torch.as_tensor(t.parameters()) for t in self]
        if not per_member:
            # torch.stack would fail with an opaque error on an empty list.
            raise ValueError("parameter_range() requires at least one transformation")
        stacked = torch.stack(per_member, dim=0)
        # min/max with dim= return (values, indices) namedtuples; only values matter.
        mi = stacked.min(dim=0).values
        ma = stacked.max(dim=0).values
        return mi, ma
class IdentityTransformation(PyTorchTransformation):
    """Transformation that returns its input unchanged."""

    def parameters(self) -> torch.Tensor:
        """Return the trivial parameter of the identity transformation.

        The original returned the Python int ``0``, which contradicts the
        ``torch.Tensor`` annotation and cannot be passed to ``torch.stack``.
        A 0-dim tensor holding 0 is returned instead; it still compares equal
        to ``0`` and converts with ``int()``.
        """
        return torch.tensor(0)

    def __call__(self, x):
        # Identity: hand the input back untouched.
        return x
class IdentityTransformationSet(PyTorchTransformationSet):
    """Transformation set containing a single IdentityTransformation."""

    def __init__(self):
        members = [IdentityTransformation()]
        super().__init__(members)