mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-16 02:41:38 +08:00
48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
import math
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from labml_helpers.module import Module
|
|
|
|
|
|
class PositionalEncoding(Module):
    """
    Add precomputed positional encodings to an input tensor, then apply dropout.

    The encoding table is built once in the constructor by
    `get_positional_encoding(d_model, max_len)` and stored as a module buffer
    named `positional_encodings`, so it is not a trainable parameter.
    """

    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
        """
        :param d_model: size of the encoding (last) dimension
        :param dropout_prob: probability passed to `nn.Dropout`
        :param max_len: number of positions to precompute encodings for
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)

        # Build the table up front; it has shape `(max_len, 1, d_model)`.
        encoding_table = get_positional_encoding(d_model, max_len)
        self.register_buffer('positional_encodings', encoding_table)

    # NOTE(review): the forward computation is defined as `__call__` rather than
    # `forward`; presumably the labml `Module` base class supports this - confirm.
    def __call__(self, x: torch.Tensor):
        """
        Return `dropout(x + encodings[:x.shape[0]])`.

        The first dimension of `x` selects how many positions are used; the
        singleton middle dimension of the table broadcasts against `x`.
        Assumes `x` is laid out as `(seq_len, batch_size, d_model)` - TODO confirm.
        """
        num_positions = x.shape[0]

        # Take only the needed positions and cut the slice off from autograd.
        encodings = self.positional_encodings[:num_positions].detach()
        encodings.requires_grad_(False)

        return self.dropout(x + encodings)
|
|
|
|
|
|
def get_positional_encoding(d_model: int, max_len: int = 5000):
    """
    Build a table of sinusoidal positional encodings.

    For position `p` and even feature index `2i`:

        PE[p, 2i]     = sin(p / 10000^(2i / d_model))
        PE[p, 2i + 1] = cos(p / 10000^(2i / d_model))

    Both even and odd `d_model` are supported. When `d_model` is odd the last
    feature is a sine column with no matching cosine column.

    :param d_model: number of features per position
    :param max_len: number of positions (rows) to generate
    :return: a float tensor of shape `(max_len, 1, d_model)` with
        `requires_grad` set to `False`
    """
    # Empty table to fill in
    encodings = torch.zeros(max_len, d_model)
    # Position indexes as a column vector, shape `(max_len, 1)`
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    # Even feature indexes `2i`; there are `ceil(d_model / 2)` of them
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
    # `10000^(-2i / d_model)`, computed in log space
    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
    # Angles for every (position, frequency) pair, shape `(max_len, ceil(d_model / 2))`
    angles = position * div_term

    # Sine fills the even columns
    encodings[:, 0::2] = torch.sin(angles)
    # Cosine fills the odd columns. There are only `d_model // 2` of those, one
    # fewer than the number of frequencies when `d_model` is odd, so trim the
    # angles to match. For even `d_model` the slice is a no-op.
    encodings[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    # Insert a singleton dimension at index 1 and make sure no gradients are tracked
    encodings = encodings.unsqueeze(1).requires_grad_(False)

    return encodings
|
|
|
|
|
|
def _test_positional_encoding():
    """
    Plot features 4 to 7 of the positional encodings for the first 100 positions.

    Uses a 20-feature encoding table and displays the figure with matplotlib.
    """
    num_positions = 100
    plotted_dims = [4, 5, 6, 7]

    plt.figure(figsize=(15, 5))

    # Table has shape `(100, 1, 20)`; drop the singleton dimension and keep
    # only the feature columns being plotted.
    table = get_positional_encoding(20, num_positions)
    curves = table[:, 0, 4:8].numpy()
    positions = np.arange(num_positions)

    plt.plot(positions, curves)
    plt.legend(["dim %d" % dim for dim in plotted_dims])
    plt.title("Positional encoding")
    plt.show()
|
|
|
|
|
|
# Script entry point: show the demo plot when this file is run directly.
if __name__ == '__main__':
    _test_positional_encoding()
|