Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 10:48:49 +08:00)
	🚧 layer norm in attention rec loss
@@ -169,12 +169,22 @@ class AttentionReconstructionLoss:
         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
         return torch.einsum("ijbh,jbhd->ibhd", attn, value)
 
+    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
+        weight = ln.weight.detach() if ln.weight is not None else None
+        bias = ln.bias.detach() if ln.bias is not None else None
+
+        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
+
     def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
         h = h.detach()
         mem = mem.detach()
 
         c_mem = layer.compress(mem)
 
+        h = self.norm(layer.norm_self_attn, h)
+        mem = self.norm(layer.norm_self_attn, mem)
+        c_mem = self.norm(layer.norm_self_attn, c_mem)
+
         return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                               self.attn(layer.self_attn, h, c_mem, c_mem))
 
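For reference (not part of the commit), a minimal sketch of what the new norm helper does: F.layer_norm with detached weight and bias produces the same output as calling the nn.LayerNorm module directly, but gradients from the reconstruction loss stop at the normalization parameters, so only the memory compression function is trained by this loss. Normalizing h, mem, and c_mem with the layer's own norm_self_attn presumably mirrors the pre-norm applied before self-attention in the layer's forward pass. Shapes below are illustrative assumptions.

# Standalone illustration (illustrative shapes, not repo code): the detached
# layer norm matches the module's output but keeps ln.weight / ln.bias out of
# the reconstruction loss's gradient graph.
import torch
import torch.nn as nn
import torch.nn.functional as F

ln = nn.LayerNorm(16)
x = torch.randn(4, 16, requires_grad=True)

y_module = ln(x)
y_detached = F.layer_norm(x, ln.normalized_shape,
                          ln.weight.detach(), ln.bias.detach(), ln.eps)

# Same forward result either way.
assert torch.allclose(y_module, y_detached)

# Backprop through the detached version: x gets a gradient,
# the layer norm parameters do not.
y_detached.sum().backward()
assert x.grad is not None
assert ln.weight.grad is None and ln.bias.grad is None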
Varuna Jayasiri