Utilities for Transformer

import torch

Subsequent mask: masks out data from future (subsequent) time steps, so that each step can see only itself and earlier steps.

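def subsequent_mask(seq_len):
    # Lower-triangular boolean mask of shape [seq_len, seq_len, 1]:
    # mask[i, j, 0] is True when j <= i, i.e. step i may see step j.
    mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
    return mask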
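
As a minimal sketch of how such a mask might be applied, the example below fills the disallowed positions of a score tensor with negative infinity before a softmax. The `[query, key, batch]` layout of `scores`, the variable names, and the sizes are assumptions made for illustration; the source does not specify how the mask is consumed.

```python
import torch


def subsequent_mask(seq_len):
    # Lower-triangular boolean mask of shape [seq_len, seq_len, 1]
    return torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)


seq_len, batch_size = 4, 2  # illustrative sizes
mask = subsequent_mask(seq_len)

# Hypothetical attention scores, assumed to be laid out as [query, key, batch]
scores = torch.zeros(seq_len, seq_len, batch_size)

# Positions where the mask is False (future steps) get -inf,
# so they receive zero weight after the softmax over the key dimension.
# The trailing dimension of size 1 broadcasts across the batch.
masked = scores.masked_fill(~mask, float('-inf'))
attn = torch.softmax(masked, dim=1)
```

With uniform scores, query step `i` ends up spreading its weight evenly over key steps `0..i` and assigning none to later steps.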