diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html
index 8f2b0f6f..6c19e8f1 100644
--- a/docs/transformers/mha.html
+++ b/docs/transformers/mha.html
@@ -164,8 +164,8 @@ This is used to transform key, query, and Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model].
-We apply the linear transformation of the last dimension and splits that into
-the heads
49 head_shape = x.shape[:-1]-
In simple terms, it finds keys that matches the query, and get the values of +
In simple terms, it finds keys that matches the query, and gets the values of those keys.
It uses dot-product of query and key as the indicator of how matching they are. Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
-Softmax is calculate along the axis of of the sequence (or time).
+Softmax is calculated along the axis of of the sequence (or time).
61class MultiHeadAttention(Module):We store attentions so that it can used for logging, or other computations if needed
+We store attentions so that it can be used for logging, or other computations if needed
109 self.attn = Nonequery, key and value are the tensors that store
-collection ofquery, key and value vectors.
+collection of query, key and value vectors.
They have shape [seq_len, batch_size, d_model].
mask has shape [seq_len, seq_len, batch_size] and indicates
+
mask has shape [seq_len, seq_len, batch_size] and
mask[i, j, b] indicates whether for batch b,
query at position i has access to key-value at position j.
i has access to key-value at position j
mask has shape [seq_len, seq_len, batch_size],
where first dimension is the query dimension.
-If the query dimension is equal to $1$ it will be broadcasted
+If the query dimension is equal to $1$ it will be broadcasted.
143 assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
@@ -426,8 +426,8 @@ If the query dimension is equal to $1$ it will be broadcasted
- Prepare query, key and value for attention computation
-These will then have shape [seq_len, batch_size, heads, d_k]
+ Prepare query, key and value for attention computation.
+These will then have shape [seq_len, batch_size, heads, d_k].
150 query = self.query(query)
@@ -440,8 +440,8 @@ These will then have shape [seq_len, batch_size, heads, d_k]
- Compute attention scores $Q K^\top$
-Results in a tensor of shape [seq_len, seq_len, batch_size, heads]
+ Compute attention scores $Q K^\top$.
+This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
156 scores = self.get_scores(query, key)
diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index f8c0a945..6e897181 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -44,8 +44,8 @@ class PrepareForMultiHeadAttention(Module):
def __call__(self, x: torch.Tensor):
# Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
- # We apply the linear transformation of the last dimension and splits that into
- # the heads
+ # We apply the linear transformation to the last dimension and split that into
+ # the heads.
head_shape = x.shape[:-1]
# Linear transform
@@ -66,7 +66,7 @@ class MultiHeadAttention(Module):
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
- In simple terms, it finds keys that matches the query, and get the values of
+ In simple terms, it finds keys that matches the query, and gets the values of
those keys.
It uses dot-product of query and key as the indicator of how matching they are.
@@ -74,7 +74,7 @@ class MultiHeadAttention(Module):
This is done to avoid large dot-product values causing softmax to
give very small gradients when $d_k$ is large.
- Softmax is calculate along the axis of of the sequence (or time).
+ Softmax is calculated along the axis of of the sequence (or time).
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
@@ -105,7 +105,7 @@ class MultiHeadAttention(Module):
# Scaling factor before the softmax
self.scale = 1 / math.sqrt(self.d_k)
- # We store attentions so that it can used for logging, or other computations if needed
+ # We store attentions so that it can be used for logging, or other computations if needed
self.attn = None
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -125,10 +125,10 @@ class MultiHeadAttention(Module):
mask: Optional[torch.Tensor] = None):
"""
`query`, `key` and `value` are the tensors that store
- collection of*query*, *key* and *value* vectors.
+ collection of *query*, *key* and *value* vectors.
They have shape `[seq_len, batch_size, d_model]`.
- `mask` has shape `[seq_len, seq_len, batch_size]` and indicates
+ `mask` has shape `[seq_len, seq_len, batch_size]` and
`mask[i, j, b]` indicates whether for batch `b`,
query at position `i` has access to key-value at position `j`.
"""
@@ -139,20 +139,20 @@ class MultiHeadAttention(Module):
if mask is not None:
# `mask` has shape `[seq_len, seq_len, batch_size]`,
# where first dimension is the query dimension.
- # If the query dimension is equal to $1$ it will be broadcasted
+ # If the query dimension is equal to $1$ it will be broadcasted.
assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
# Same mask applied to all heads.
mask = mask.unsqueeze(-1)
- # Prepare `query`, `key` and `value` for attention computation
- # These will then have shape `[seq_len, batch_size, heads, d_k]`
+ # Prepare `query`, `key` and `value` for attention computation.
+ # These will then have shape `[seq_len, batch_size, heads, d_k]`.
query = self.query(query)
key = self.key(key)
value = self.value(value)
- # Compute attention scores $Q K^\top$
- # Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
+ # Compute attention scores $Q K^\top$.
+ # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
scores = self.get_scores(query, key)
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$