From 9636cfef0357e40e1c75e215dc4e4b75d72d02e8 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Wed, 17 Feb 2021 17:37:16 +0530
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20layer=20norm=20in=20attention=20?=
 =?UTF-8?q?rec=20loss?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/transformers/compressive/__init__.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/labml_nn/transformers/compressive/__init__.py b/labml_nn/transformers/compressive/__init__.py
index d0a22c83..c9a14a13 100644
--- a/labml_nn/transformers/compressive/__init__.py
+++ b/labml_nn/transformers/compressive/__init__.py
@@ -169,12 +169,22 @@ class AttentionReconstructionLoss:
         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
         return torch.einsum("ijbh,jbhd->ibhd", attn, value)
 
+    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
+        weight = ln.weight.detach() if ln.weight is not None else None
+        bias = ln.bias.detach() if ln.bias is not None else None
+
+        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
+
     def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
         h = h.detach()
         mem = mem.detach()
 
         c_mem = layer.compress(mem)
 
+        h = self.norm(layer.norm_self_attn, h)
+        mem = self.norm(layer.norm_self_attn, mem)
+        c_mem = self.norm(layer.norm_self_attn, c_mem)
+
         return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                               self.attn(layer.self_attn, h, c_mem, c_mem))
 
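
Note (not part of the patch): below is a minimal standalone sketch of what the added `norm` helper does, assuming only PyTorch. It applies a module's existing nn.LayerNorm through F.layer_norm with the affine parameters detached, so the attention reconstruction loss trains neither the layer-norm weight nor its bias while the normalized input still receives gradients. The `ln`, `x`, and shape values are illustrative, not taken from the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative layer norm and input; shapes are arbitrary for the sketch.
ln = nn.LayerNorm(16)
x = torch.randn(4, 16, requires_grad=True)

# Detach the affine parameters so no gradient flows into the LayerNorm itself.
weight = ln.weight.detach() if ln.weight is not None else None
bias = ln.bias.detach() if ln.bias is not None else None

# Same normalization the layer would apply, but with frozen parameters.
y = F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

y.sum().backward()
assert ln.weight.grad is None  # layer-norm parameters receive no gradient
assert x.grad is not None      # the input still gets gradients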