mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 17:57:14 +08:00
🐛 positional encoding buffer
This commit is contained in:
@ -28,7 +28,7 @@ class PositionalEncoding(Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dropout = nn.Dropout(dropout_prob)
|
self.dropout = nn.Dropout(dropout_prob)
|
||||||
|
|
||||||
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
|
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor):
|
def __call__(self, x: torch.Tensor):
|
||||||
pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
|
pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
|
||||||
|
|||||||
Reference in New Issue
Block a user