memo: PyTorch | Hooks

Source video: PyTorch Hooks Explained - In-depth Tutorial - Elliot Waite

Hooks for tensors

a.register_hook(hook_func)

This adds a _backward_hooks property to the tensor a. Hooks on tensors take effect only during back-propagation (when the gradient is calculated).

hook_func can be a normal function or a lambda. It takes as input the gradient grad of tensor a coming from the previous node in the backward graph, and the gradient it returns is passed on to the rest of the backward graph.

import torch
from torch import Tensor

x = torch.ones(3, requires_grad=True)
a = x * 2  # an intermediate (non-leaf) node

def hook_func(grad: Tensor):
  print(grad)
  return grad + 1  # the grad passed to the next operation is increased by 1

a.register_hook(hook_func)

# Second hook function: returns None, so the gradient is left unchanged
a.register_hook(lambda grad: print(grad))

# Third hook function: changes the grad
a.register_hook(lambda grad: grad * 100)

# Fourth hook function: retain_grad() registers a hook that saves the
# gradient of this intermediate node into a.grad
a.retain_grad()
  • _backward_hooks is an OrderedDict, so it can hold multiple functions, and they are executed in the order in which they were registered.

  • Registering a hook on an intermediate (non-leaf) tensor also notifies the corresponding node in the backward graph.

    That node gains an additional property, a pre_hooks list, which calls the functions in the tensor's _backward_hooks before the gradient enters the node's backward method.

    So during back-propagation the hooks run before backward, and backward uses the gradients returned by the hooks.

  • For a leaf tensor, however, register_hook only adds the hook function to the _backward_hooks OrderedDict of that leaf.

    The leaf's associated AccumulateGrad node then checks whether the leaf has hook functions that need to run before the incoming gradient is accumulated into .grad.

  • register_hook returns a handle, e.g., h = c.register_hook(hook_func). The hook can later be removed via the handle: h.remove() (as sketched below).
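
A minimal sketch of the handle workflow (c here is just an example tensor):

import torch

c = torch.ones(2, requires_grad=True)
h = c.register_hook(lambda grad: grad * 2)

(c * 3).sum().backward()
print(c.grad)  # scaled by the hook: tensor([6., 6.])

h.remove()     # the hook no longer fires
c.grad = None
(c * 3).sum().backward()
print(c.grad)  # tensor([3., 3.])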

Caveat: modifying grad in place inside a hook function may affect other tensors' gradients, messing up the rest of the backward pass.

  • For example, in the grad_fn of the operation e = c + d, the same gradient tensor is passed to both c and d. If d's hook changes that gradient in place, c's gradient changes as well (a minimal demonstration follows).
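
A minimal sketch of this caveat, reusing the names c, d, and e from the bullet above (whether the leak is visible can depend on the autograd engine's execution order):

import torch

c = torch.ones(3, requires_grad=True)
d = torch.ones(3, requires_grad=True)
e = c + d

# The add operation passes the same gradient tensor to both c and d,
# so an in-place modification in d's hook can leak into c's gradient.
d.register_hook(lambda grad: grad.mul_(100))

e.sum().backward()
print(c.grad)  # may also be scaled by 100, although the hook was on d
print(d.grad)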


e.g. Gradient clipper ¹

Clamp the gradient of each tensor in a certain range by registering a hook for each parameter.

import torch
from torch import nn
from torchvision import models

def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
  for parameter in model.parameters():
    # clamp the gradient in place to [-val, val]
    parameter.register_hook(lambda grad: grad.clamp_(-val, val))

  return model

resnet_clipped = gradient_clipper(models.resnet50(), val=0.01)
dummy_input = torch.ones(1, 3, 224, 224)
pred = resnet_clipped(dummy_input)
loss = pred.log().mean()
loss.backward()

print(resnet_clipped.fc.bias.grad[:25])

Hooks for modules

Hooks registered on modules are triggered automatically before or after nn.Module.forward is called (including for individual layers), so a hook can modify the input and output tensors of an nn.Module.

  1. Hooks before forward: register_forward_pre_hook(hook_func), where hook_func receives the module and its positional inputs (packed as a tuple).

    def hook_func_pre_forward(module: nn.Module, inputs: tuple):
      # inputs is a tuple of the module's positional arguments
      a, b = inputs
      return a + 2, b
    
    myModel.register_forward_pre_hook(hook_func_pre_forward)
    
  2. Hooks after forward: register_forward_hook(hook_func), where hook_func receives 3 arguments: the module, its inputs (a tuple), and its output.

    def hook_func_forward(module: nn.Module, inputs: tuple, output: Tensor):
      return output + 10  # replaces the module's output
    
    myModel.register_forward_hook(hook_func_forward)
    
  3. Hooks after backward: register_backward_hook() has been deprecated in favor of register_full_backward_hook(), which receives the module and the gradients with respect to its inputs and outputs (see the sketch below).
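
    A minimal sketch (the layer and hook names are illustrative): grad_input and grad_output are tuples of gradients with respect to the module's inputs and outputs, and returning a tuple from the hook replaces grad_input.

    import torch
    from torch import nn

    def hook_func_backward(module: nn.Module, grad_input, grad_output):
      # grad_input / grad_output are tuples of gradients w.r.t. the
      # module's inputs and outputs; returning a tuple replaces grad_input
      print(f"grad_output norm: {grad_output[0].norm()}")

    myLayer = nn.Linear(4, 2)
    myLayer.register_full_backward_hook(hook_func_backward)

    out = myLayer(torch.ones(1, 4))
    out.sum().backward()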


e.g. Inspect a model ¹

Print the shape of the output tensor after each layer by registering a hook on each layer from an external wrapper, rather than adding print statements inside the model.

import torch
from torch import nn, Tensor
from torchvision import models

class VerboseExecution(nn.Module):
  def __init__(self, model: nn.Module):
    super().__init__()
    self.model = model

    # Register a hook for each layer
    for name, module in self.model.named_children():
      # conv1, bn1, relu, maxpool, layer1, ...
      module.__name__ = name
      module.register_forward_hook(self.print_shape())

  def print_shape(self):
    def hook_func(module, inputs, output):
      print(f"{module.__name__}: {output.shape}")
    return hook_func

  def forward(self, x: Tensor) -> Tensor:
    return self.model(x)

# Print intermediate shape in ResNet50
resnet_verbose = VerboseExecution(models.resnet50())
dummy_input = torch.ones(1, 3, 224, 224)
_ = resnet_verbose(dummy_input)

e.g. Extract feature maps ¹

Save the output feature map of each selected layer into a dict by registering a forward hook on each layer.

import torch
from torch import nn, Tensor
from torchvision import models
from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
  def __init__(self, model: nn.Module, layer_ls: Iterable[str]):
    super().__init__()
    self.model = model
    self.layers_ex = layer_ls
    # Define a dict to store feature maps
    self._features = {layer: torch.empty(0) for layer in layer_ls}

    for name in layer_ls:
      # Pick out the selected layers by their names from a dictionary
      layer = dict(self.model.named_modules())[name]

      # Register a hook for each layer
      layer.register_forward_hook(self.save_outputs(name))

  def save_outputs(self, layer_name: str) -> Callable:
    def hook_func(module, inputs, output):
      self._features[layer_name] = output
    return hook_func

  def forward(self, x: Tensor) -> Dict[str, Tensor]:
    _ = self.model(x)
    return self._features

# Extract feature maps at each level before "avgpool" and "fc"
resnet50 = models.resnet50()
resnet_features = FeatureExtractor(
    resnet50, layer_ls=list(resnet50._modules)[:-2])
dummy_input = torch.ones(1, 3, 224, 224)

feature_maps = resnet_features(dummy_input)

print({name: output.shape for name, output in feature_maps.items()})

Reference