知了传课DRF+Vue实现APl自动化测试平台
特征提取
通常,我们希望从一个预先训练好的网络中生成特性,然后用它们来完成另一个任务(例如分类、相似度搜索等)。使用 hook,我们可以提取特征,而不需要重新创建现有模型或以任何方式修改它。
from typing import Dict, Iterable, Callable class FeatureExtractor(nn.Module): def __init__(self, model: nn.Module, layers: Iterable[str]): super().__init__() self.model = model self.layers = layers self._features = {layer: torch.empty(0) for layer in layers} for layer_id in layers: layer = dict([*self.model.named_modules()])[layer_id] layer.register_forward_hook(self.save_outputs_hook(layer_id)) def save_outputs_hook(self, layer_id: str) -> Callable: def fn(_, __, output): self._features[layer_id] = output return fn def forward(self, x: Tensor) -> Dict[str, Tensor]: _ = self.model(x) return self._features