From 06e68d012c3e93a6a6fb64eaabfa0144317cec87 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sun, 27 Sep 2020 16:40:22 +0530
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20positional=20encoding=20buffer?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/transformers/positional_encoding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/labml_nn/transformers/positional_encoding.py b/labml_nn/transformers/positional_encoding.py
index d3f4f9f1..50679923 100644
--- a/labml_nn/transformers/positional_encoding.py
+++ b/labml_nn/transformers/positional_encoding.py
@@ -28,7 +28,7 @@ class PositionalEncoding(Module):
         super().__init__()
         self.dropout = nn.Dropout(dropout_prob)
 
-        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
+        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
 
     def __call__(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
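
The third positional argument to `register_buffer` is `persistent`; passing `False` keeps the
positional-encoding table out of the module's `state_dict`, so checkpoints no longer carry a
tensor that is recomputed from `d_model` and `max_len` on construction anyway. A minimal sketch
of that effect, using a hypothetical `Demo` module with a small zero tensor standing in for
`get_positional_encoding(d_model, max_len)`:

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # Stand-in for get_positional_encoding(d_model, max_len).
            enc = torch.zeros(10, 1, 8)
            # Old behaviour: persistent buffer, saved in state_dict.
            self.register_buffer('persistent_pe', enc)
            # Patched behaviour: non-persistent buffer, skipped by state_dict.
            self.register_buffer('non_persistent_pe', enc, persistent=False)

    demo = Demo()
    print('persistent_pe' in demo.state_dict())      # True
    print('non_persistent_pe' in demo.state_dict())  # False

A non-persistent buffer otherwise behaves like any other buffer: it still moves with the module
on `.to(device)` calls and is still visible via `named_buffers()`.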