🚧 layer norm in attention rec loss

Varuna Jayasiri
2021-02-17 17:37:16 +05:30
parent c1ab9d8589
commit 9636cfef03


@@ -169,12 +169,22 @@ class AttentionReconstructionLoss:
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
        # Apply the layer norm with detached parameters, so the
        # reconstruction loss does not train the layer norm weights
        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None
        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
        # Detach the hidden states and memories; only the compression
        # function should receive gradients from this loss
        h = h.detach()
        mem = mem.detach()
        # Compress the memories
        c_mem = layer.compress(mem)
        # Normalize queries, memories and compressed memories with the
        # layer's self-attention layer norm (parameters detached)
        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)
        # Compare attention over the original memories with attention
        # over the compressed memories
        return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                              self.attn(layer.self_attn, h, c_mem, c_mem))
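
The new `norm` helper applies the layer's self-attention layer norm with detached parameters, so the only gradient path left in `calc_loss` is through `layer.compress`. A minimal standalone sketch of that detached-parameter pattern (the tensor shape and variable names here are illustrative assumptions, not part of this commit):

import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical input of shape [seq_len, batch_size, d_model]
x = torch.randn(10, 4, 16)
ln = nn.LayerNorm(16)

# Detach the affine parameters before calling the functional layer norm,
# so no gradients flow into ln.weight or ln.bias
weight = ln.weight.detach() if ln.weight is not None else None
bias = ln.bias.detach() if ln.bias is not None else None
out = F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

# Numerically identical to ln(x); only the gradient path differs
assert torch.allclose(out, ln(x))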