Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 10:48:49 +08:00)
	🚧 layer norm in attention rec loss
```diff
@@ -169,12 +169,22 @@ class AttentionReconstructionLoss:
         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
         return torch.einsum("ijbh,jbhd->ibhd", attn, value)
 
+    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
+        weight = ln.weight.detach() if ln.weight is not None else None
+        bias = ln.bias.detach() if ln.bias is not None else None
+
+        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
+
     def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
         h = h.detach()
         mem = mem.detach()
 
         c_mem = layer.compress(mem)
 
+        h = self.norm(layer.norm_self_attn, h)
+        mem = self.norm(layer.norm_self_attn, mem)
+        c_mem = self.norm(layer.norm_self_attn, c_mem)
+
         return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                               self.attn(layer.self_attn, h, c_mem, c_mem))
 
 
```
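The new `norm` helper runs the layer's `nn.LayerNorm` functionally with `weight` and `bias` detached, so the reconstruction loss normalizes its inputs exactly as the layer does without pushing gradients into the layer-norm parameters. A minimal standalone sketch of that behaviour (not from the commit; `frozen_layer_norm` is a hypothetical name):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def frozen_layer_norm(ln: nn.LayerNorm, x: torch.Tensor) -> torch.Tensor:
    # Same forward computation as ln(x), but the affine parameters are
    # detached, so backward() never reaches ln.weight / ln.bias.
    weight = ln.weight.detach() if ln.weight is not None else None
    bias = ln.bias.detach() if ln.bias is not None else None
    return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)


ln = nn.LayerNorm(8)
x = torch.randn(4, 8, requires_grad=True)

out = frozen_layer_norm(ln, x)
out.sum().backward()

print(torch.allclose(out, ln(x)))  # True: identical output values
print(ln.weight.grad is None)      # True: layer-norm parameters stay frozen
print(x.grad is not None)          # True: gradients still reach the input
```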
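For context, `calc_loss` detaches `h` and `mem`, compresses the memories, normalizes all three tensors with the frozen pre-attention layer norm, and penalizes the difference between attention over the original and the compressed memories, so only the compression function is trained. A hedged, self-contained sketch of that idea (simplified single-head attention and a `Conv1d` compressor; names such as `toy_attention` are illustrative, not the repository's API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def toy_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    # Simplified single-head scaled dot-product attention: q attends over kv.
    scores = q @ kv.transpose(-2, -1) / kv.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ kv


d_model, seq_len, mem_len, ratio = 16, 8, 12, 4
compress = nn.Conv1d(d_model, d_model, kernel_size=ratio, stride=ratio)

h = torch.randn(seq_len, d_model)    # current hidden states (detached in the real loss)
mem = torch.randn(mem_len, d_model)  # memories to be compressed (also detached)

# Compress the memories along the sequence dimension: (12, 16) -> (3, 16).
c_mem = compress(mem.t().unsqueeze(0)).squeeze(0).t()

# MSE between attention over the original and the compressed memories;
# only `compress` receives gradients, mirroring the detach() calls above.
loss = F.mse_loss(toy_attention(h, mem), toy_attention(h, c_mem))
loss.backward()
print(loss.item(), compress.weight.grad is not None)
```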
Author: Varuna Jayasiri