Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-10-26 19:46:20 +08:00
cleanup jax
@@ -422,43 +422,6 @@ class LayerNorm(Module):
         return x_norm


-class PrepareForMultiHeadAttention(Module):
-    """
-    <a id="PrepareMHA"></a>
-
-    ## Prepare for multi-head attention
-
-    This module does a linear transformation and splits the vector into given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = Linear(rnd_key, d_model, heads * d_k)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: jnp.ndarray):
-        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation to the last dimension and split that into
-        # the heads.
-        head_shape = x.shape[:-1]
-
-        # Linear transform
-        x = self.linear(x)
-
-        # Split last dimension into heads
-
-        x = x.reshape(*head_shape, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
-        return x
-
-
 class MultiHeadAttention(Module):
     r"""
     <a id="MHA"></a>
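The removed `PrepareForMultiHeadAttention` module only did two things: a linear projection to `heads * d_k` features and a reshape that splits the projected vector into heads. A minimal sketch of that head split, assuming plain `jax.numpy`, example sizes, and a `jnp.ones` tensor standing in for the output of the linear layer:

import jax.numpy as jnp

seq_len, heads, d_k = 10, 4, 16

# Stand-in for the output of the linear transform: shape `[seq_len, heads * d_k]`
x = jnp.ones((seq_len, heads * d_k))

# Split the last dimension into separate heads, keeping the leading dimensions
head_shape = x.shape[:-1]
x = x.reshape(*head_shape, heads, d_k)

# Result has shape `[seq_len, heads, d_k]`
assert x.shape == (seq_len, heads, d_k)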
@@ -503,9 +466,9 @@ class MultiHeadAttention(Module):
         self.heads = heads

         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
+        self.query = Linear(rnd_keys[0], d_model, d_model)
+        self.key = Linear(rnd_keys[1], d_model, d_model)
+        self.value = Linear(rnd_keys[2], d_model, d_model)

         # Output layer
         self.output = Linear(rnd_keys[3], d_model, d_model)
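This swap works because `d_k` is presumably `d_model // heads` (the diff does not show where `self.d_k` is set), so `Linear(rnd_key, d_model, heads * d_k)` and `Linear(rnd_key, d_model, d_model)` have identical weight shapes; only the head-splitting reshape moves into `__call__`. A small sketch of that equivalence with hypothetical sizes, using a raw weight matrix in place of the repo's `Linear` module:

import jax
import jax.numpy as jnp

d_model, heads = 128, 4
d_k = d_model // heads  # assumption: d_k is derived from d_model and heads

rnd_key = jax.random.PRNGKey(0)
# Old projection: d_model -> heads * d_k; new projection: d_model -> d_model
w_old = jax.random.normal(rnd_key, (d_model, heads * d_k))
w_new = jax.random.normal(rnd_key, (d_model, d_model))
assert w_old.shape == w_new.shape  # (128, 128) either way

# Project and then split into heads, as the new `__call__` does
x = jnp.ones((10, d_model))                          # `[seq_len, d_model]`
q = (x @ w_new).reshape(*x.shape[:-1], heads, d_k)   # `[seq_len, heads, d_k]`
assert q.shape == (10, heads, d_k)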
@ -537,12 +500,18 @@ class MultiHeadAttention(Module):
|
||||
# Same mask applied to all heads.
|
||||
mask = mask[:, :, None]
|
||||
|
||||
# Prepare `query`, `key` and `value` for attention computation.
|
||||
# These will then have shape `[seq_len, heads, d_k]`.
|
||||
# Apply linear transformations
|
||||
query = self.query(query)
|
||||
key = self.key(key)
|
||||
value = self.value(value)
|
||||
|
||||
# Reshape to split into heads
|
||||
# Input has shape `[seq_len, batch_size, d_model]`.
|
||||
# We split the last dimension into `heads` and `d_k`.
|
||||
query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
|
||||
key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
|
||||
value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
|
||||
|
||||
# Compute attention scores $Q K^\top$.
|
||||
# This gives a tensor of shape `[seq_len, seq_len, heads]`.
|
||||
# $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
|
||||
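The context lines above describe the score computation $S_{ijh} = \sum_d Q_{ihd} K_{jhd}$, but the diff does not show the line that implements it. A hedged sketch of that contraction with `jnp.einsum` (the actual call in the repo may differ; the $1/\sqrt{d_k}$ scaling is the standard scaled dot-product convention, assumed here):

import jax
import jax.numpy as jnp

seq_len, heads, d_k = 10, 4, 16
kq, kk = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(kq, (seq_len, heads, d_k))  # `[seq_len, heads, d_k]`
key = jax.random.normal(kk, (seq_len, heads, d_k))    # `[seq_len, heads, d_k]`

# S_{ijh} = sum_d Q_{ihd} K_{jhd}: contract over the d_k dimension
scores = jnp.einsum('ihd,jhd->ijh', query, key)
assert scores.shape == (seq_len, seq_len, heads)      # `[seq_len, seq_len, heads]`

# Scale by 1/sqrt(d_k) (assumed, standard scaled dot-product attention)
scores = scores / jnp.sqrt(d_k)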
@@ -1038,4 +1007,4 @@ def main():

 #
 if __name__ == '__main__':
-    main()
+    main()