diff --git a/labml_nn/transformers/jax_transformer/__init__.py b/labml_nn/transformers/jax_transformer/__init__.py
index 97aa9da5..fb9b1bf5 100644
--- a/labml_nn/transformers/jax_transformer/__init__.py
+++ b/labml_nn/transformers/jax_transformer/__init__.py
@@ -422,43 +422,6 @@ class LayerNorm(Module):
return x_norm
-class PrepareForMultiHeadAttention(Module):
- """
-
-
- ## Prepare for multi-head attention
-
- This module does a linear transformation and splits the vector into given
- number of heads for multi-head attention.
- This is used to transform **key**, **query**, and **value** vectors.
- """
-
- def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
- super().__init__()
- # Linear layer for linear transform
- self.linear = Linear(rnd_key, d_model, heads * d_k)
- # Number of heads
- self.heads = heads
- # Number of dimensions in vectors in each head
- self.d_k = d_k
-
- def __call__(self, x: jnp.ndarray):
- # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
- # We apply the linear transformation to the last dimension and split that into
- # the heads.
- head_shape = x.shape[:-1]
-
- # Linear transform
- x = self.linear(x)
-
- # Split last dimension into heads
-
- x = x.reshape(*head_shape, self.heads, self.d_k)
-
- # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
- return x
-
-
class MultiHeadAttention(Module):
r"""
@@ -503,9 +466,9 @@ class MultiHeadAttention(Module):
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
- self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
- self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
- self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
+ self.query = Linear(rnd_keys[0], d_model, d_model)
+ self.key = Linear(rnd_keys[1], d_model, d_model)
+ self.value = Linear(rnd_keys[2], d_model, d_model)
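+ # Each is a `d_model` to `d_model` linear map (here `heads * d_k == d_model`);
+ # the split into heads now happens in `__call__` below.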
# Output layer
self.output = Linear(rnd_keys[3], d_model, d_model)
@@ -537,12 +500,18 @@ class MultiHeadAttention(Module):
# Same mask applied to all heads.
mask = mask[:, :, None]
- # Prepare `query`, `key` and `value` for attention computation.
- # These will then have shape `[seq_len, heads, d_k]`.
+ # Apply the linear transformations to `query`, `key` and `value`
query = self.query(query)
key = self.key(key)
value = self.value(value)
+ # Reshape to split the vectors into heads
+ # `query`, `key` and `value` have shape `[seq_len, d_model]` here.
+ # We split the last dimension into `heads` and `d_k`, giving `[seq_len, heads, d_k]`.
+ query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
+ key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
+ value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
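+ # For example, with `d_model = 512` and `heads = 8` (values chosen purely for
+ # illustration), `d_k = 64` and each vector goes from `[seq_len, 512]` to `[seq_len, 8, 64]`.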
+
# Compute attention scores $Q K^\top$.
# This gives a tensor of shape `[seq_len, seq_len, heads]`.
# $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$