pytorch hook
아래 게시물을 참고했습니다.
hook 란?
pytorch를 사용할 때, 디버깅을 위해 패키지 중간에 자기가 원하는 코드 끼워넣을 수 있는 기능이다.
hook의 3가지 종류
-
포워드 프리 훅 : 포워드 패스 전에 실행, register_forward_pre_hook
-
포워드 훅 : 포워드 패스 후 실행, register_forward_hook
-
역방향 훅 : 역방향 패스 후 실행, register_full_backward_hook
사용 예시
아래와 같이 layer마다 hook을 사용해서 activation을 볼 수 있다.
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs = []
save_output = SaveOutput()
hook_handles = []
# NOTE : layer를 돌며 Conv2d인 layer에 hook 건다.
for layer in model.modules():
if isinstance(layer, torch.nn.modules.conv.Conv2d):
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
hooks = [
model.backbone[-2].register_forward_hook(
lambda self, input, output: conv_features.append(output)
),
model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
lambda self, input, output: enc_attn_weights.append(output[1])
),
model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
lambda self, input, output: dec_attn_weights.append(output[1])
),
]