transpose with \top

This commit is contained in:
Varuna Jayasiri
2021-01-10 12:16:50 +05:30
parent 8ee70da198
commit 9102883e1f
3 changed files with 30 additions and 29 deletions

View File

@ -64,7 +64,7 @@ class MultiHeadAttention(Module):
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
In simple terms, it finds keys that matches the query, and get the values of
those keys.
@ -115,7 +115,7 @@ class MultiHeadAttention(Module):
This method can be overridden for other variations like relative attention.
"""
# Calculate $Q K^T$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
# Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
return torch.einsum('ibhd,jbhd->ijbh', query, key)
def __call__(self, *,
@ -151,11 +151,11 @@ class MultiHeadAttention(Module):
key = self.key(key)
value = self.value(value)
# Compute attention scores $Q K^T$
# Compute attention scores $Q K^\top$
# Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
scores = self.get_scores(query, key)
# Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
scores *= self.scale
# Apply mask
@ -163,7 +163,7 @@ class MultiHeadAttention(Module):
scores = scores.masked_fill(mask == 0, -1e9)
# $softmax$ attention along the key sequence dimension
# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = self.softmax(scores)
# Save attentions if debugging
@ -173,7 +173,7 @@ class MultiHeadAttention(Module):
attn = self.dropout(attn)
# Multiply by values
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Save attentions for any other calculations