Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-30 18:27:03 +08:00)
🚧 layer norm in attention rec loss
@@ -169,12 +169,22 @@ class AttentionReconstructionLoss:
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
        # Use the layer norm parameters, with gradients detached
        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None

        # Apply layer normalization with the detached parameters
        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
        # Detach the hidden states and the memories
        h = h.detach()
        mem = mem.detach()

        # Compress the memory
        c_mem = layer.compress(mem)

        # Normalize the hidden states, the memories, and the compressed memories
        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)

        # Loss between attention over the memories and attention over the compressed memories
        return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                              self.attn(layer.self_attn, h, c_mem, c_mem))
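For context, a minimal runnable sketch (not part of the commit; the module size and tensor shapes are hypothetical) of what the norm() helper above computes: it reuses the layer's own LayerNorm parameters but detaches them, so this loss does not push gradients into the model's normalization weights.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes: d_model = 16, sequence of 10 steps, batch of 4
ln = nn.LayerNorm(16)
x = torch.randn(10, 4, 16)

# Same computation as the norm() helper above: reuse the layer's affine
# parameters but detach them so no gradient flows back into the model
weight = ln.weight.detach() if ln.weight is not None else None
bias = ln.bias.detach() if ln.bias is not None else None
out = F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

# Numerically identical to the module's own forward pass
assert torch.allclose(out, ln(x))

Detaching the parameters here mirrors detaching h and mem in calc_loss: with everything else detached, only layer.compress receives gradients from the reconstruction loss.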