📊 tracker debug attention

Varuna Jayasiri
2020-09-03 09:32:02 +05:30
parent 4e15860907
commit 22215f8d07
3 changed files with 8 additions and 3 deletions

View File

@@ -86,9 +86,6 @@ class TransformerLayer(Module):
        ff = self.feed_forward(z)
        x = x + self.dropout(ff)
        # guard(x.shape, attn_self.shape, attn_src.shape, ff.shape,
        #       '_batch_size', '_seq_len', 'd_model')
        return x
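The lines removed above are a commented-out shape guard on the layer's intermediate tensors. A minimal sketch of the same check with plain asserts is shown below; the `guard` helper itself is not part of this diff, so its exact behaviour is assumed, and the tensor layout is taken to be `[seq_len, batch_size, d_model]` as the commented names suggest:

    # Hypothetical stand-in for the commented-out guard(...) call: verify that the
    # residual stream and the attention / feed-forward outputs all share one shape.
    def check_shapes(x, attn_self, attn_src, ff, d_model):
        seq_len, batch_size, _ = x.shape
        for t in (x, attn_self, attn_src, ff):
            assert t.shape == (seq_len, batch_size, d_model)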

View File

@@ -2,6 +2,7 @@ import math
from typing import Optional
import torch
from labml import tracker
from torch import nn as nn
from torch.nn import functional as F
@@ -65,6 +66,7 @@ class MultiHeadAttention(Module):
            assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=1)
        tracker.debug('attn', attn)
        attn = self.dropout(attn)
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
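The new tracker.debug('attn', attn) call logs the attention distribution right after the softmax. The snippet below is a small, self-contained illustration of what the masking-plus-softmax step computes; it uses toy 2-D scores instead of the [query, key, batch, heads] layout used in this module, and leaves out the labml tracker call:

    import torch
    from torch.nn import functional as F

    seq_len = 4
    scores = torch.randn(seq_len, seq_len)           # toy scores: [query, key]
    mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 where attention is allowed
    scores = scores.masked_fill(mask == 0, -1e9)     # blocked positions get a large negative value
    attn = F.softmax(scores, dim=1)                  # normalise over keys; blocked entries become ~0
    print(attn)                                      # each row sums to 1, upper triangle is ~0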

View File

@@ -0,0 +1,6 @@
import torch


def subsequent_mask(seq_len):
    mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
    return mask
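A quick way to see what the new subsequent_mask helper produces (the function is copied inline below because its module path is not visible in this view):

    import torch

    def subsequent_mask(seq_len):
        mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
        return mask

    m = subsequent_mask(4)
    print(m.shape)     # torch.Size([4, 4, 1]) -- trailing dim is there for broadcasting against the scores
    print(m[:, :, 0])  # lower-triangular: query position i may only attend to key positions j <= i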