home transformers
10import torch
def subsequent_mask(seq_len):
    """Build a causal ("no peeking ahead") boolean mask.

    Args:
        seq_len: Number of positions along each of the two square axes.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len, 1)``.
        Entry ``[i, j, 0]`` is ``True`` exactly when ``j <= i``, so position
        ``i`` may attend to itself and to every earlier position.
    """
    # Start from a dense square of ones and zero out everything strictly
    # above the main diagonal.
    square = torch.ones(seq_len, seq_len)
    lower_triangle = square.tril()
    # NOTE(review): the trailing size-1 axis presumably lets the mask
    # broadcast over a batch (or similar) dimension in the caller's layout.
    # Confirm this against the call sites.
    return lower_triangle.to(torch.bool)[:, :, None]