✍️ mha english

This commit is contained in:
Varuna Jayasiri
2021-02-01 07:28:33 +05:30
parent 53128f5679
commit 5cd2b8701b
2 changed files with 24 additions and 24 deletions


@@ -44,8 +44,8 @@ class PrepareForMultiHeadAttention(Module):
     def __call__(self, x: torch.Tensor):
         # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation of the last dimension and splits that into
-        # the heads
+        # We apply the linear transformation to the last dimension and split that into
+        # the heads.
         head_shape = x.shape[:-1]
         # Linear transform
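The comment fixed above describes the head-splitting step: project the last dimension with a linear layer, then split it into `heads` vectors of size `d_k`. A minimal sketch of that transform (the class name `PrepareHeads` and the plain `nn.Linear` projection are illustrative assumptions, not the repository's exact code):

    import torch
    import torch.nn as nn

    class PrepareHeads(nn.Module):
        def __init__(self, d_model: int, heads: int, d_k: int, bias: bool = True):
            super().__init__()
            # Linear projection of the last dimension to `heads * d_k` features
            self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
            self.heads = heads
            self.d_k = d_k

        def forward(self, x: torch.Tensor):
            # `x` is `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`
            head_shape = x.shape[:-1]
            x = self.linear(x)
            # Split the projected features into `heads` vectors of size `d_k`
            return x.view(*head_shape, self.heads, self.d_k)

For example, `PrepareHeads(512, 8, 64)(torch.randn(10, 2, 512))` gives a tensor of shape `[10, 2, 8, 64]`.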
@@ -66,7 +66,7 @@ class MultiHeadAttention(Module):
     $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
-    In simple terms, it finds keys that matches the query, and get the values of
+    In simple terms, it finds keys that matches the query, and gets the values of
     those keys.
     It uses dot-product of query and key as the indicator of how matching they are.
@@ -74,7 +74,7 @@ class MultiHeadAttention(Module):
     This is done to avoid large dot-product values causing softmax to
     give very small gradients when $d_k$ is large.
-    Softmax is calculate along the axis of of the sequence (or time).
+    Softmax is calculated along the axis of of the sequence (or time).
     """
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
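As a worked illustration of the docstring being edited here, a single-head version of $\underset{seq}{\mathop{softmax}}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big)V$ with tensors shaped `[seq_len, batch_size, d_k]` and the softmax taken along the key/sequence axis (a standalone sketch, not the class's own forward pass):

    import math
    import torch

    def scaled_dot_product_attention(q, k, v):
        # q, k, v: `[seq_len, batch_size, d_k]`
        d_k = q.shape[-1]
        # Dot-product of every query with every key: `[seq_q, seq_k, batch_size]`
        scores = torch.einsum('ibd,jbd->ijb', q, k) / math.sqrt(d_k)
        # Softmax along the key/sequence axis, so each query's weights sum to 1
        attn = scores.softmax(dim=1)
        # Weighted sum of value vectors: `[seq_q, batch_size, d_k]`
        return torch.einsum('ijb,jbd->ibd', attn, v)

Dividing by $\sqrt{d_k}$ keeps the dot-products from growing with the key size, which would otherwise push the softmax into regions with very small gradients, as the surrounding comments note.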
@@ -105,7 +105,7 @@ class MultiHeadAttention(Module):
         # Scaling factor before the softmax
         self.scale = 1 / math.sqrt(self.d_k)
-        # We store attentions so that it can used for logging, or other computations if needed
+        # We store attentions so that it can be used for logging, or other computations if needed
         self.attn = None
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
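`get_scores` computes the $Q K^\top$ term for every head. Given `query` and `key` of shape `[seq_len, batch_size, heads, d_k]`, one natural way to get a `[seq_len, seq_len, batch_size, heads]` score tensor is an einsum over the `d_k` dimension; this is a sketch of what the method needs to do, not necessarily its exact body:

    import torch

    def get_scores(query: torch.Tensor, key: torch.Tensor):
        # query, key: `[seq_len, batch_size, heads, d_k]`
        # result:     `[seq_len_q, seq_len_k, batch_size, heads]`
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

The scaling by `self.scale` (that is, $1/\sqrt{d_k}$) is applied to these scores before the softmax, as the comment above notes.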
@@ -125,10 +125,10 @@ class MultiHeadAttention(Module):
                  mask: Optional[torch.Tensor] = None):
         """
         `query`, `key` and `value` are the tensors that store
-        collection of*query*, *key* and *value* vectors.
+        collection of *query*, *key* and *value* vectors.
         They have shape `[seq_len, batch_size, d_model]`.
-        `mask` has shape `[seq_len, seq_len, batch_size]` and indicates
+        `mask` has shape `[seq_len, seq_len, batch_size]` and
+        `mask[i, j, b]` indicates whether for batch `b`,
         query at position `i` has access to key-value at position `j`.
         """
@@ -139,20 +139,20 @@ class MultiHeadAttention(Module):
         if mask is not None:
             # `mask` has shape `[seq_len, seq_len, batch_size]`,
             # where first dimension is the query dimension.
-            # If the query dimension is equal to $1$ it will be broadcasted
+            # If the query dimension is equal to $1$ it will be broadcasted.
             assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
             # Same mask applied to all heads.
             mask = mask.unsqueeze(-1)
-        # Prepare `query`, `key` and `value` for attention computation
-        # These will then have shape `[seq_len, batch_size, heads, d_k]`
+        # Prepare `query`, `key` and `value` for attention computation.
+        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
         query = self.query(query)
         key = self.key(key)
         value = self.value(value)
-        # Compute attention scores $Q K^\top$
-        # Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
+        # Compute attention scores $Q K^\top$.
+        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
         scores = self.get_scores(query, key)
         # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
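The remaining steps of the forward pass follow the same pattern the comments above describe: scale, mask, softmax along the key axis, then a weighted sum of the value vectors. A self-contained sketch of those steps under the shapes documented in this diff (it omits dropout and the final output projection, and the einsum strings are an assumption consistent with the stated shapes rather than a verbatim copy of the module):

    import math
    import torch

    def attention_core(query, key, value, mask=None):
        # query, key, value: `[seq_len, batch_size, heads, d_k]` (already projected and split)
        d_k = query.shape[-1]
        # $Q K^\top / \sqrt{d_k}$: `[seq_len_q, seq_len_k, batch_size, heads]`
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key) * (1 / math.sqrt(d_k))
        if mask is not None:
            # `mask` is `[seq_len_q, seq_len_k, batch_size]`; add a head dimension so the
            # same mask applies to every head, then block the masked-out positions
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
        # Softmax along the key/sequence axis
        attn = scores.softmax(dim=1)
        # Weighted sum of values: `[seq_len_q, batch_size, heads, d_k]`
        x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
        # Concatenate the heads back into `heads * d_k` features
        return x.reshape(*x.shape[:2], -1)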