diff --git a/labml_nn/transformers/compressive/__init__.py b/labml_nn/transformers/compressive/__init__.py
index d0a22c83..c9a14a13 100644
--- a/labml_nn/transformers/compressive/__init__.py
+++ b/labml_nn/transformers/compressive/__init__.py
@@ -169,12 +169,22 @@ class AttentionReconstructionLoss:
         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
         return torch.einsum("ijbh,jbhd->ibhd", attn, value)
 
+    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
+        weight = ln.weight.detach() if ln.weight is not None else None
+        bias = ln.bias.detach() if ln.bias is not None else None
+
+        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
+
     def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
         h = h.detach()
         mem = mem.detach()
 
         c_mem = layer.compress(mem)
 
+        h = self.norm(layer.norm_self_attn, h)
+        mem = self.norm(layer.norm_self_attn, mem)
+        c_mem = self.norm(layer.norm_self_attn, c_mem)
+
         return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                               self.attn(layer.self_attn, h, c_mem, c_mem))
 
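
For context (not part of the patch): the new `norm` helper re-applies the layer's `nn.LayerNorm` through `F.layer_norm` with detached `weight` and `bias`, so the attention-reconstruction loss only produces gradients for the compression function and never updates the layer-norm parameters themselves. Below is a minimal, self-contained sketch of that behaviour; the tensor shapes and variable names are illustrative assumptions, not taken from the patch.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    ln = nn.LayerNorm(8)
    x = torch.randn(4, 8, requires_grad=True)

    # Layer norm with detached parameters: gradients can still flow
    # through x, but none reach ln.weight / ln.bias.
    out = F.layer_norm(x, ln.normalized_shape, ln.weight.detach(), ln.bias.detach(), ln.eps)
    out.sum().backward()

    print(x.grad is not None)   # True  - input still receives gradients
    print(ln.weight.grad)       # None  - LayerNorm parameters are untouched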