from __future__ import annotations
from torch import nn
import torch
import abc
from typing import Callable, List,MutableMapping,Any, Tuple,Union
from tmeasures.utils import get_all
class ActivationsModule(nn.Module):
    """Interface for models that expose named intermediate activations.

    Implementations report activation names via ``activation_names`` and compute
    the matching tensors via ``forward_activations``; the two lists are expected
    to line up index by index.

    Note: ``nn.Module`` does not use ``abc.ABCMeta``, so the ``abstractmethod``
    markers below document the contract but do not prevent instantiation;
    unimplemented methods raise ``NotImplementedError`` only when called.
    """

    @abc.abstractmethod
    def activation_names(self) -> List[str]:
        """Return the names of the exposed activations."""
        raise NotImplementedError()

    @abc.abstractmethod
    def forward_activations(self, args) -> List[torch.Tensor]:
        """Run the model on ``args`` and return one tensor per activation name."""
        raise NotImplementedError()

    def n_activations(self):
        """Return how many activations this module exposes."""
        names = self.activation_names()
        return len(names)
# Predicate used by FilteredActivationsModule: called with the inner
# ActivationsModule and one of its activation names, it returns True to keep
# that activation and False to drop it.
ActivationFilter = Callable[[ActivationsModule,str],bool]
class FilteredActivationsModule(ActivationsModule):
    """Wrap an ``ActivationsModule`` and expose only the activations a filter accepts.

    The kept activations preserve their original relative order.

    Args:
        inner_model: The module whose activations are filtered.
        activations_filter: Called as ``activations_filter(inner_model, name)`` for
            every activation name of ``inner_model``; True keeps the activation.

    Raises:
        AssertionError: if the filter rejects every activation.
    """

    def __init__(self, inner_model: ActivationsModule, activations_filter: ActivationFilter):
        super().__init__()
        self.inner_model = inner_model
        original_names = inner_model.activation_names()
        # Record positions while scanning rather than via ``list.index``: ``index``
        # is O(n) per lookup and maps a duplicated name to its first occurrence.
        self.indices = [
            i for i, name in enumerate(original_names)
            if activations_filter(inner_model, name)
        ]
        self.names = [original_names[i] for i in self.indices]
        assert len(self.names) > 0, "activations_filter rejected every activation of the inner model"

    def forward(self, *args, **kwargs):
        """Delegate a normal forward pass to the inner model."""
        # Call the module (not ``.forward``) so the inner model's hooks also run.
        return self.inner_model(*args, **kwargs)

    def activation_names(self) -> List[str]:
        """Return the names of the kept activations."""
        return self.names

    def eval(self):
        """Switch this wrapper and the inner model to evaluation mode; return ``self``.

        ``super().eval()`` flips ``self.training`` and that of all registered
        submodules; ``inner_model.eval()`` is then called explicitly (last) so any
        custom ``eval`` override on the inner model takes effect.
        """
        super().eval()
        self.inner_model.eval()
        return self

    def forward_activations(self, args) -> List[torch.Tensor]:
        """Compute the inner model's activations and keep only the filtered ones."""
        activations = self.inner_model.forward_activations(args)
        return [activations[i] for i in self.indices]
def intersect_lists(a: List, b: List):
    """Return the distinct elements present in both ``a`` and ``b``.

    The result is built from a set, so its order is unspecified. Elements must
    be hashable; any iterables (e.g. dict key views) are accepted.
    """
    shared = set(a).intersection(set(b))
    return list(shared)
class DuplicateKeyError(ValueError):
    """Raised when building a flattened/named structure would produce the same key twice.

    Constructor arguments are stored unchanged in ``args``, exactly as for ``ValueError``.
    """
# Values accepted by the flatten_* helpers: a (possibly nested) mapping or a
# leaf value. Because ``Any`` is a member of the union, type checkers accept any
# value here; the alias mainly documents intent.
Input = Union[MutableMapping,Any]
def flatten_dict(d_or_val: Input, prefix='', sep='/', allow_repeated=False) -> dict[str, Any]:
    """Flatten a nested mapping into a single-level dict with ``sep``-joined keys.

    Mapping keys are converted with ``str`` and appended to the parent path with
    ``sep``; any non-mapping value is a leaf stored under its accumulated path.

    Args:
        d_or_val: A (possibly nested) ``MutableMapping`` or a leaf value.
        prefix: Path of ``d_or_val`` within the outer structure; ``''`` at the root.
        sep: Separator placed between path components.
        allow_repeated: If False, raise when two entries flatten to the same key.
            If True, later entries silently overwrite earlier ones.

    Returns:
        A dict from flattened key to leaf value. A non-mapping input yields
        ``{prefix: d_or_val}``; an empty mapping yields ``{}``.

    Raises:
        DuplicateKeyError: if ``allow_repeated`` is False and two entries collide,
            e.g. ``{"a/b": 1, "a": {"b": 2}}`` (both flatten to ``"a/b"``).
    """
    if not isinstance(d_or_val, MutableMapping):
        return {prefix: d_or_val}
    # At the root there is no parent path, so no leading separator.
    prefix_sep = prefix + sep if prefix else ""
    result: dict[str, Any] = {}
    for k, v in d_or_val.items():
        # Forward every option so nested levels honour the caller's choices.
        flattened_child = flatten_dict(v, prefix_sep + str(k), sep, allow_repeated)
        if not allow_repeated:
            # A collision means a key this child would overwrite in ``result``.
            repeated = result.keys() & flattened_child.keys()
            if repeated:
                raise DuplicateKeyError(sorted(repeated))
        result.update(flattened_child)
    return result
def flatten_dict_list(d_or_val: Input, key="", full_name=True) -> List[Tuple[str, Any]]:
    """Flatten a nested mapping into ``(path, leaf)`` pairs, in mapping iteration order.

    With ``full_name=True`` each level's path is ``key + "/" + child_key``, so from
    the default ``key=""`` paths start with ``"/"`` (e.g. ``"/a/b"``). With
    ``full_name=False`` only the top-level keys are used bare; deeper levels are
    still ``"/"``-joined (e.g. ``"a/b"``) because the recursive call does not
    forward ``full_name``.

    A non-mapping input yields ``[(key, d_or_val)]``; an empty mapping yields ``[]``.
    Keys are expected to be ``str`` -- joining a non-str key raises ``TypeError``.
    """
    if not isinstance(d_or_val, MutableMapping):
        return [(key, d_or_val)]
    pairs: List[Tuple[str, Any]] = []
    for child_key, child in d_or_val.items():
        path = "/".join((key, child_key)) if full_name else child_key
        # NOTE(review): ``full_name`` is not forwarded, so it only affects the top
        # level; activation names elsewhere depend on this -- confirm before changing.
        pairs.extend(flatten_dict_list(child, key=path))
    return pairs
def named_children_deep(m: torch.nn.Module):
    """Build a nested dict mirroring the module hierarchy of ``m``.

    Each direct child becomes a key. Children whose names are purely numeric
    (e.g. positions inside ``nn.Sequential``) are keyed ``<ClassName>_<name>`` so
    the key says what the layer is; other children keep their attribute name.
    Recursion stops at modules without children, which are returned as-is, so the
    leaves of the resulting tree are ``nn.Module`` objects.

    Returns:
        ``m`` itself if it has no children, otherwise a dict of key -> subtree.

    Raises:
        DuplicateKeyError: if two children of the same module map to the same key
            (e.g. a child named ``"Linear_0"`` next to a ``Linear`` at name ``"0"``).
    """
    children = dict(m.named_children())
    if not children:
        return m
    output = {}
    for name, child in children.items():
        key = f"{child.__class__.__name__}_{name}" if name.isnumeric() else name
        if key in output:
            raise DuplicateKeyError(f"Duplicate child key {key!r} in {type(m).__name__}")
        output[key] = named_children_deep(child)
    return output
class AutoActivationsModule(ActivationsModule):
    """Expose the outputs of the leaf submodules of ``module`` as named activations.

    The module tree is built with ``named_children_deep`` and flattened with
    ``flatten_dict_list``; every leaf module accepted by ``filter`` gets a forward
    hook that stores its detached output. ``forward_activations`` runs the wrapped
    module once and returns those outputs in ``activation_names`` order.

    Args:
        module: Model to instrument; hooks are registered on its leaf submodules
            (or on ``module`` itself if it has no children).
        full_name: Forwarded to ``flatten_dict_list``; controls activation naming.
        filter: Predicate on a leaf ``nn.Module``; only modules for which it returns
            True become activations. (Shadows the builtin ``filter``; kept for
            backward compatibility.)
    """

    def __init__(self, module: nn.Module, full_name=True, filter: Callable = lambda x: True) -> None:
        super().__init__()
        self.reset_values()
        # Handles returned by ``register_forward_hook``, kept so hooks can be removed.
        self.hook_handles = []
        self.module = module
        self.children_tree = named_children_deep(module)
        self.names, self.activations = zip(*flatten_dict_list(self.children_tree, full_name=full_name))
        filter_indices = [i for i, a in enumerate(self.activations) if filter(a)]
        self.names = get_all(self.names, filter_indices)
        self.activations = get_all(self.activations, filter_indices)
        self.register_hooks(self.activations)

    def register_hooks(self, activations: List[nn.Module]):
        """Register a forward hook on each module in ``activations``.

        The hook of the i-th module stores ``output.detach()`` in ``self.values[i]``.
        Handles are appended to ``self.hook_handles`` (see ``remove_hooks``).
        """
        def store_activation(index):
            # Factory binds ``index`` per hook, avoiding the late-binding closure pitfall.
            def hook(model, input, output):
                # NOTE(review): assumes each hooked module returns a single tensor.
                self.values[index] = output.detach()
            return hook

        for i, v in enumerate(activations):
            self.hook_handles.append(v.register_forward_hook(store_activation(i)))

    def remove_hooks(self):
        """Remove every forward hook registered by this instance."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

    def reset_values(self):
        """Discard activations stored by a previous forward pass."""
        self.values = {}

    def forward(self, *args, **kwargs):
        """Run the wrapped module; hooks refresh ``self.values`` as a side effect."""
        return self.module(*args, **kwargs)

    def forward_activations(self, args) -> List[torch.Tensor]:
        '''
        Run ``module`` on ``args`` and return the hooked outputs in ``activation_names`` order.

        This function is not thread safe: outputs are collected in ``self.values``,
        which is shared by every call on this instance.

        Raises:
            AssertionError: if some hooked module did not run during the forward pass.
        '''
        self.reset_values()
        # Call the module (not ``.forward``) so hooks registered on ``module`` itself
        # fire too -- required when ``module`` has no children and is its own activation.
        self.module(args)
        n_activations = len(self.activations)
        missing = [i for i in range(n_activations) if i not in self.values]
        if missing:
            missing_names = [self.names[i] for i in missing]
            raise AssertionError(
                f"Values output by network ({len(self.values)}) dont match original number of activations ({n_activations}). Missing: \n {missing_names}"
            )
        return [self.values[i] for i in range(n_activations)]

    def activation_names(self) -> List[str]:
        """Return the names of the hooked activations."""
        return self.names