diff --git a/labml_nn/transformers/positional_encoding.py b/labml_nn/transformers/positional_encoding.py
index d3f4f9f1..50679923 100644
--- a/labml_nn/transformers/positional_encoding.py
+++ b/labml_nn/transformers/positional_encoding.py
@@ -28,7 +28,7 @@ class PositionalEncoding(Module):
         super().__init__()
         self.dropout = nn.Dropout(dropout_prob)
 
-        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
+        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
 
     def __call__(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
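The third positional argument being added is `persistent=False` in PyTorch's `nn.Module.register_buffer`, which keeps the buffer out of the module's `state_dict`. The likely rationale is that the encodings are deterministic given `d_model` and `max_len`, so there is no need to save them in checkpoints. A minimal sketch of the behavior (the `Demo` class and buffer names are illustrative, not from the repository):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=True (the default): buffer is saved in state_dict
        self.register_buffer('kept', torch.zeros(2))
        # persistent=False: buffer moves with the module (e.g. .to(device))
        # but is excluded from state_dict, so checkpoints stay smaller
        self.register_buffer('skipped', torch.zeros(2), False)

print(Demo().state_dict().keys())  # odict_keys(['kept'])
```

A side effect worth noting: older checkpoints that still contain a `positional_encodings` entry would need `load_state_dict(..., strict=False)` or key filtering to load cleanly after this change.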