# Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
# Synced 2025-08-15 02:07:56 +08:00 (7 lines, 142 B, Python)
import torch
|
|
|
|
|
|
def subsequent_mask(seq_len, *, device=None):
    """Build a lower-triangular boolean mask of shape ``[seq_len, seq_len, 1]``.

    Entry ``[i, j, 0]`` is ``True`` when ``j <= i`` and ``False`` otherwise,
    so row ``i`` marks the positions up to and including ``i`` and excludes
    the subsequent ones.

    Args:
        seq_len: Side length of the square mask.
        device: Optional device on which to create the mask. ``None`` (the
            default) lets PyTorch pick its default device, which is what a
            call without this argument does.

    Returns:
        A ``torch.bool`` tensor of shape ``[seq_len, seq_len, 1]``.
    """
    ones = torch.ones(seq_len, seq_len, device=device)
    # `tril` zeroes everything strictly above the main diagonal.
    lower = torch.tril(ones).to(torch.bool)
    # Trailing singleton axis; presumably broadcast over a batch dimension by
    # the caller -- TODO confirm against usage.
    return lower.unsqueeze(-1)
|