diff --git a/labml_nn/transformers/compressive/__init__.py b/labml_nn/transformers/compressive/__init__.py
index c9a14a13..3d8edde2 100644
--- a/labml_nn/transformers/compressive/__init__.py
+++ b/labml_nn/transformers/compressive/__init__.py
@@ -182,7 +182,7 @@ class AttentionReconstructionLoss:
         c_mem = layer.compress(mem)
 
         h = self.norm(layer.norm_self_attn, h)
-        mem = self.norm(layer.norm_self_attn, h)
+        mem = self.norm(layer.norm_self_attn, mem)
         c_mem = self.norm(layer.norm_self_attn, c_mem)
 
         return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
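
The one-line fix above corrects a copy-paste bug in `AttentionReconstructionLoss`: the
uncompressed memory `mem` was being overwritten with the normalized hidden state `h`,
so the keys and values passed to `self.attn(layer.self_attn, h, mem, mem)` were just a
second (doubly normalized) copy of `h`, and the actual memory contents never entered
the reconstruction loss. A minimal standalone sketch (hypothetical shapes and tensors,
not the repository's actual helpers) illustrating the difference:

    import torch
    from torch import nn

    ln = nn.LayerNorm(16)
    h = torch.randn(4, 16)      # hidden state
    mem = torch.randn(8, 16)    # uncompressed memory
    c_mem = torch.randn(2, 16)  # compressed memory

    h = ln(h)
    # Buggy:  mem = ln(h)   -- `mem` is discarded and replaced by the
    #                          already-normalized hidden state.
    mem = ln(mem)               # Fixed: normalize the memory tensor itself
    c_mem = ln(c_mem)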