📊 tracker debug attention

This commit is contained in:
Varuna Jayasiri
2020-09-03 09:32:02 +05:30
parent 4e15860907
commit 22215f8d07
3 changed files with 8 additions and 3 deletions

View File

@ -0,0 +1,6 @@
import torch
def subsequent_mask(seq_len):
mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
return mask