cleanup jax

Varuna Jayasiri
2025-08-21 12:17:54 +05:30
parent 96f7b5a8e1
commit f346824200


@@ -422,43 +422,6 @@ class LayerNorm(Module):
return x_norm
class PrepareForMultiHeadAttention(Module):
"""
<a id="PrepareMHA"></a>
## Prepare for multi-head attention
This module does a linear transformation and splits the vector into the given
number of heads for multi-head attention.
This is used to transform **key**, **query**, and **value** vectors.
"""
def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
super().__init__()
# Linear layer for linear transform
self.linear = Linear(rnd_key, d_model, heads * d_k)
# Number of heads
self.heads = heads
# Number of dimensions in vectors in each head
self.d_k = d_k
def __call__(self, x: jnp.ndarray):
# Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
# We apply the linear transformation to the last dimension and split that into
# the heads.
head_shape = x.shape[:-1]
# Linear transform
x = self.linear(x)
# Split last dimension into heads
x = x.reshape(*head_shape, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
return x
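For reference, a standalone sketch of the shape flow this removed module implemented, with a plain random matrix standing in for the `Linear` layer (the weight `w` and the toy sizes are illustrative assumptions, not part of the original code):

```python
import jax
import jax.numpy as jnp

seq_len, d_model, heads, d_k = 10, 64, 8, 8   # toy sizes; d_model == heads * d_k

key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
# A plain weight matrix standing in for the `Linear(rnd_key, d_model, heads * d_k)` layer
w = jax.random.normal(k_w, (d_model, heads * d_k))
x = jax.random.normal(k_x, (seq_len, d_model))

# Linear transform on the last dimension
head_shape = x.shape[:-1]
x = x @ w
# Split the last dimension into heads
x = x.reshape(*head_shape, heads, d_k)
print(x.shape)   # (10, 8, 8) -> [seq_len, heads, d_k]
```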
class MultiHeadAttention(Module):
r"""
<a id="MHA"></a>
@@ -503,9 +466,9 @@ class MultiHeadAttention(Module):
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
self.query = Linear(rnd_keys[0], d_model, d_model)
self.key = Linear(rnd_keys[1], d_model, d_model)
self.value = Linear(rnd_keys[2], d_model, d_model)
# Output layer
self.output = Linear(rnd_keys[3], d_model, d_model)
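The cleanup relies on `heads * d_k == d_model`, so the dedicated per-head projection module can be replaced by a plain `Linear(rnd_key, d_model, d_model)` with the head split done afterwards. A minimal sketch of that equivalence, using a random matrix in place of `Linear` and toy dimensions (names and sizes here are illustrative only):

```python
import jax
import jax.numpy as jnp

seq_len, d_model, heads = 10, 64, 8
d_k = d_model // heads

key = jax.random.PRNGKey(0)
k_w, k_q = jax.random.split(key)
# Random matrix standing in for a `Linear(rnd_key, d_model, d_model)` layer
w = jax.random.normal(k_w, (d_model, d_model))
q = jax.random.normal(k_q, (seq_len, d_model))   # [seq_len, d_model]

# Per-head view: the weight split into `heads` blocks, each projecting to `d_k` dimensions
w_heads = w.reshape(d_model, heads, d_k)
per_head = jnp.einsum('sd,dhk->shk', q, w_heads)
# Single-matrix view: one `[d_model, d_model]` matmul, reshaped afterwards
merged = (q @ w).reshape(seq_len, heads, d_k)

print(jnp.allclose(per_head, merged, atol=1e-4))  # True -- same computation, different grouping
```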
@@ -537,12 +500,18 @@ class MultiHeadAttention(Module):
# Same mask applied to all heads.
mask = mask[:, :, None]
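A small sketch of the broadcasting this relies on; the causal mask construction and the `-inf` fill below are illustrative assumptions, and the rest of the attention computation is outside this excerpt:

```python
import jax.numpy as jnp

seq_len, heads = 5, 8
# Example lower-triangular causal mask of shape [seq_len, seq_len]
mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
# Trailing axis so the same mask broadcasts over every head
mask = mask[:, :, None]                          # [seq_len, seq_len, 1]

scores = jnp.zeros((seq_len, seq_len, heads))    # stand-in for the attention scores
masked = jnp.where(mask, scores, float('-inf'))
print(masked.shape)                              # (5, 5, 8)
```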
# Prepare `query`, `key` and `value` for attention computation.
# These will then have shape `[seq_len, heads, d_k]`.
# Apply linear transformations
query = self.query(query)
key = self.key(key)
value = self.value(value)
# Reshape to split into heads
# Each of `query`, `key` and `value` has shape `[seq_len, d_model]`;
# we split the last dimension into `heads` and `d_k`.
query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
# Compute attention scores $Q K^\top$.
# This gives a tensor of shape `[seq_len, seq_len, heads]`.
# $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
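The sum in $S_{ijh} = \sum_d Q_{ihd} K_{jhd}$ maps directly onto an einsum over the head-split tensors. A minimal sketch with toy shapes (the usual $\frac{1}{\sqrt{d_k}}$ scaling and the softmax come after this step and are outside this excerpt):

```python
import jax
import jax.numpy as jnp

seq_len, heads, d_k = 10, 8, 8
rnd = jax.random.PRNGKey(0)
k_q, k_k = jax.random.split(rnd)
query = jax.random.normal(k_q, (seq_len, heads, d_k))   # [seq_len, heads, d_k]
key = jax.random.normal(k_k, (seq_len, heads, d_k))

# S_{ijh} = sum_d Q_{ihd} K_{jhd}
scores = jnp.einsum('ihd,jhd->ijh', query, key)
print(scores.shape)   # (10, 10, 8) -> [seq_len, seq_len, heads]
```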