import torch


def subsequent_mask(seq_len: int, *, device: "torch.device | str | None" = None) -> torch.Tensor:
    """Return a lower-triangular boolean mask of shape ``(seq_len, seq_len, 1)``.

    Entry ``[i, j, 0]`` is ``True`` when ``j <= i`` and ``False`` otherwise.
    In other words, position ``i`` may attend to itself and to earlier
    positions, but not to later ones.

    Args:
        seq_len: Number of positions along each of the first two axes.
        device: Optional device on which to allocate the mask. The default
            ``None`` uses PyTorch's default device, as the original
            implementation did.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len, 1)``.
    """
    # Build the mask as bool from the start instead of creating a float
    # matrix and converting it afterwards. The values are identical.
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
    # NOTE(review): the singleton axis is appended last, giving
    # (seq_len, seq_len, 1). Attention scores shaped (..., seq_len, seq_len)
    # usually need a leading axis (unsqueeze(0)) to broadcast. The trailing
    # axis is kept here because callers may depend on it; confirm against
    # how the mask is consumed.
    return mask.unsqueeze(-1)