📊 tracker debug attention

Varuna Jayasiri
2020-09-03 09:32:02 +05:30
parent 4e15860907
commit 22215f8d07
3 changed files with 8 additions and 3 deletions


@@ -86,9 +86,6 @@ class TransformerLayer(Module):
         ff = self.feed_forward(z)
         x = x + self.dropout(ff)
-        # guard(x.shape, attn_self.shape, attn_src.shape, ff.shape,
-        #       '_batch_size', '_seq_len', 'd_model')
         return x
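The deleted lines were a commented-out guard(...) shape check on the tensors flowing through the layer. Purely as an illustration (not part of this commit, and not labml's guard API), the same intent can be written with plain asserts:

    # Illustrative only: verify that the residual-stream tensors named in the
    # removed comment all share one 3-D shape (seq_len, batch_size and d_model
    # dimensions, in whatever order the model uses).
    def check_shapes(x, attn_self, attn_src, ff):
        assert x.dim() == 3
        assert attn_self.shape == x.shape
        assert attn_src.shape == x.shape
        assert ff.shape == x.shape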


@@ -2,6 +2,7 @@ import math
 from typing import Optional
 import torch
+from labml import tracker
 from torch import nn as nn
 from torch.nn import functional as F
@@ -65,6 +66,7 @@ class MultiHeadAttention(Module):
         assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
         scores = scores.masked_fill(mask == 0, -1e9)
         attn = F.softmax(scores, dim=1)
+        tracker.debug('attn', attn)
         attn = self.dropout(attn)
         x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
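The added tracker.debug('attn', attn) call records the post-softmax attention weights with labml's tracker. As a standalone sanity check (not from the repo) of what that tensor looks like, here is a PyTorch-only sketch of the same masking-and-softmax step, using the [query_seq, key_seq, batch, heads] layout implied by the einsum above:

    import torch
    from torch.nn import functional as F

    seq_len, batch, heads = 4, 2, 3
    scores = torch.randn(seq_len, seq_len, batch, heads)

    # Lower-triangular causal mask, broadcast over batch and heads.
    causal = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool)
    scores = scores.masked_fill(~causal[:, :, None, None], -1e9)

    # Softmax over the key dimension (dim=1), as in the diff above.
    attn = F.softmax(scores, dim=1)

    # Each query position's weights over the keys sum to 1, and masked
    # (future) positions get essentially zero weight.
    assert torch.allclose(attn.sum(dim=1), torch.ones(seq_len, batch, heads))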


@@ -0,0 +1,6 @@
+import torch
+
+
+def subsequent_mask(seq_len):
+    mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
+    return mask
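A quick look (not part of the commit) at what the new subsequent_mask helper returns:

    import torch

    # Copy of the new helper so this snippet is self-contained.
    def subsequent_mask(seq_len):
        mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
        return mask

    m = subsequent_mask(3)
    print(m.shape)    # torch.Size([3, 3, 1])
    print(m[:, :, 0])
    # tensor([[ True, False, False],
    #         [ True,  True, False],
    #         [ True,  True,  True]])
    # Row i is True only for columns j <= i, so a query position can attend to
    # itself and earlier positions but not later ones; the trailing size-1
    # dimension presumably lets the same mask broadcast over the batch.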