From f346824200499ceb21447f37e4d33e91e7b9b50f Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Thu, 21 Aug 2025 12:17:54 +0530
Subject: [PATCH] cleanup jax

---
 .../transformers/jax_transformer/__init__.py | 55 ++++---------------
 1 file changed, 12 insertions(+), 43 deletions(-)

diff --git a/labml_nn/transformers/jax_transformer/__init__.py b/labml_nn/transformers/jax_transformer/__init__.py
index 97aa9da5..fb9b1bf5 100644
--- a/labml_nn/transformers/jax_transformer/__init__.py
+++ b/labml_nn/transformers/jax_transformer/__init__.py
@@ -422,43 +422,6 @@ class LayerNorm(Module):
         return x_norm
 
 
-class PrepareForMultiHeadAttention(Module):
-    """
-
-
-    ## Prepare for multi-head attention
-
-    This module does a linear transformation and splits the vector into given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = Linear(rnd_key, d_model, heads * d_k)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: jnp.ndarray):
-        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation to the last dimension and split that into
-        # the heads.
-        head_shape = x.shape[:-1]
-
-        # Linear transform
-        x = self.linear(x)
-
-        # Split last dimension into heads
-
-        x = x.reshape(*head_shape, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
-        return x
-
-
 class MultiHeadAttention(Module):
     r"""
 
@@ -503,9 +466,9 @@ class MultiHeadAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
+        self.query = Linear(rnd_keys[0], d_model, d_model)
+        self.key = Linear(rnd_keys[1], d_model, d_model)
+        self.value = Linear(rnd_keys[2], d_model, d_model)
 
         # Output layer
         self.output = Linear(rnd_keys[3], d_model, d_model)
@@ -537,12 +500,18 @@ class MultiHeadAttention(Module):
             # Same mask applied to all heads.
             mask = mask[:, :, None]
 
-        # Prepare `query`, `key` and `value` for attention computation.
-        # These will then have shape `[seq_len, heads, d_k]`.
+        # Apply linear transformations
         query = self.query(query)
        key = self.key(key)
         value = self.value(value)
 
+        # Reshape to split into heads
+        # Input has shape `[seq_len, batch_size, d_model]`.
+        # We split the last dimension into `heads` and `d_k`.
+        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
+        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
+        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
+
         # Compute attention scores $Q K^\top$.
         # This gives a tensor of shape `[seq_len, seq_len, heads]`.
         # $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
@@ -1038,4 +1007,4 @@ def main():
 
 #
 if __name__ == '__main__':
-    main()
+    main()
\ No newline at end of file
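
Note on the change: the patch folds the removed PrepareForMultiHeadAttention module into MultiHeadAttention itself, replacing it with plain Linear projections followed by an explicit reshape that splits the last dimension into (heads, d_k). The sketch below is a minimal, self-contained illustration of that equivalence; it uses raw jax.numpy with a bias-free weight matrix and made-up dimensions rather than the repository's Linear class, so the names and shapes here are assumptions, not the project's API.

# Minimal sketch (assumed names and dimensions, not the repository's code):
# a single bias-free d_model -> heads * d_k projection followed by a reshape
# of the last dimension is the same head split the patch switches to.
import jax
import jax.numpy as jnp

seq_len, batch_size, d_model, heads = 5, 2, 16, 4
d_k = d_model // heads

k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(k_w, (d_model, heads * d_k))          # projection weight
x = jax.random.normal(k_x, (seq_len, batch_size, d_model))  # dummy input

projected = x @ w                                             # [seq_len, batch_size, heads * d_k]
split = projected.reshape(*projected.shape[:-1], heads, d_k)  # [seq_len, batch_size, heads, d_k]

# Head h of the split view is the h-th d_k-wide slice of the flat projection,
# which is what the removed PrepareForMultiHeadAttention produced per head.
h = 1
assert jnp.allclose(split[..., h, :], projected[..., h * d_k:(h + 1) * d_k])
print(split.shape)  # (5, 2, 4, 4)

Either path yields tensors of shape [seq_len, batch_size, heads, d_k], ready for the $S_{ijh} = \sum_d Q_{ihd} K_{jhd}$ attention-score computation that follows in the third hunk above.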