Files
Varuna Jayasiri 22215f8d07 📊 tracker debug attention
2020-09-03 09:32:02 +05:30

7 lines
142 B
Python

import torch
def subsequent_mask(seq_len, *, device=None):
    """Build a lower-triangular boolean mask for ``seq_len`` positions.

    Entry ``[i, j, 0]`` is ``True`` when ``j <= i`` and ``False`` otherwise.
    Each row ``i`` therefore marks positions ``0..i`` as visible.

    Args:
        seq_len: Number of positions along each of the first two dimensions.
        device: Optional device to create the mask on. ``None`` (the default)
            keeps the original behaviour of using torch's default device.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len, 1)``.
    """
    # tril zeroes everything above the main diagonal; the cast turns 1/0 into
    # True/False.
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).to(torch.bool)
    # Trailing singleton dim, presumably so the mask broadcasts over an extra
    # (e.g. heads/batch) axis at the call site -- NOTE(review): confirm usage.
    return mask.unsqueeze(-1)