Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-10-27 03:56:48 +08:00

Commit: cleanup jax
@@ -422,43 +422,6 @@ class LayerNorm(Module):
         return x_norm
 
 
-class PrepareForMultiHeadAttention(Module):
-    """
-    <a id="PrepareMHA"></a>
-
-    ## Prepare for multi-head attention
-
-    This module does a linear transformation and splits the vector into given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = Linear(rnd_key, d_model, heads * d_k)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: jnp.ndarray):
-        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation to the last dimension and split that into
-        # the heads.
-        head_shape = x.shape[:-1]
-
-        # Linear transform
-        x = self.linear(x)
-
-        # Split last dimension into heads
-
-        x = x.reshape(*head_shape, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
-        return x
-
-
 class MultiHeadAttention(Module):
     r"""
     <a id="MHA"></a>
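For reference, the transform that the removed `PrepareForMultiHeadAttention` performed, a linear projection of the last dimension followed by a reshape into heads, can be sketched as a standalone JAX function. This is a minimal sketch using a hypothetical name `prepare_heads` and a raw weight matrix; it does not use the repository's `Linear` or `Module` classes.

import jax
import jax.numpy as jnp


def prepare_heads(x: jnp.ndarray, w: jnp.ndarray, heads: int, d_k: int) -> jnp.ndarray:
    # x: [seq_len, batch_size, d_model] or [batch_size, d_model]; w: [d_model, heads * d_k]
    head_shape = x.shape[:-1]
    # Linear transform of the last dimension
    x = jnp.matmul(x, w)
    # Split the last dimension into (heads, d_k)
    return x.reshape(*head_shape, heads, d_k)


# Example shapes: seq_len=5, batch_size=2, d_model=16, heads=4, d_k=4
x = jax.random.normal(jax.random.PRNGKey(0), (5, 2, 16))
w = jax.random.normal(jax.random.PRNGKey(1), (16, 16))
print(prepare_heads(x, w, heads=4, d_k=4).shape)  # (5, 2, 4, 4)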
@@ -503,9 +466,9 @@ class MultiHeadAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
+        self.query = Linear(rnd_keys[0], d_model, d_model)
+        self.key = Linear(rnd_keys[1], d_model, d_model)
+        self.value = Linear(rnd_keys[2], d_model, d_model)
 
         # Output layer
         self.output = Linear(rnd_keys[3], d_model, d_model)
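The swap from `PrepareForMultiHeadAttention` to a plain `d_model -> d_model` `Linear` relies on `heads * d_k == d_model` (the usual `d_k = d_model // heads`), so one fused projection plus a later reshape computes the same thing as `heads` separate `d_model -> d_k` projections. A small sketch of that equivalence in plain jnp, with made-up shapes and without the repository's `Linear` class:

import jax
import jax.numpy as jnp

d_model, heads = 16, 4
d_k = d_model // heads                                            # heads * d_k == d_model
x = jax.random.normal(jax.random.PRNGKey(0), (5, d_model))        # [seq_len, d_model]
w = jax.random.normal(jax.random.PRNGKey(1), (d_model, d_model))  # fused projection weights

# One fused projection, then split the last dimension into heads (what __call__ now does)
fused = (x @ w).reshape(*x.shape[:-1], heads, d_k)
# Equivalent view: `heads` separate d_model -> d_k projections, one per column block of w
per_head = jnp.stack([x @ w[:, h * d_k:(h + 1) * d_k] for h in range(heads)], axis=-2)

print(jnp.allclose(fused, per_head))  # True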
@@ -537,12 +500,18 @@ class MultiHeadAttention(Module):
             # Same mask applied to all heads.
             mask = mask[:, :, None]
 
-        # Prepare `query`, `key` and `value` for attention computation.
-        # These will then have shape `[seq_len, heads, d_k]`.
+        # Apply linear transformations
         query = self.query(query)
         key = self.key(key)
         value = self.value(value)
 
+        # Reshape to split into heads
+        # Input has shape `[seq_len, batch_size, d_model]`.
+        # We split the last dimension into `heads` and `d_k`.
+        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
+        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
+        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
+
         # Compute attention scores $Q K^\top$.
         # This gives a tensor of shape `[seq_len, seq_len, heads]`.
         # $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
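The score formula in the context lines above, $S_{ijh} = \sum_d Q_{ihd} K_{jhd}$, maps directly onto an einsum over the head-split tensors. A minimal sketch with made-up shapes and no batch dimension, matching the `[seq_len, heads, d_k]` layout:

import jax
import jax.numpy as jnp

seq_len, heads, d_k = 5, 4, 8
q = jax.random.normal(jax.random.PRNGKey(0), (seq_len, heads, d_k))
k = jax.random.normal(jax.random.PRNGKey(1), (seq_len, heads, d_k))

# S_{ijh} = sum_d Q_{ihd} K_{jhd}  ->  scores has shape [seq_len, seq_len, heads]
scores = jnp.einsum('ihd,jhd->ijh', q, k)
print(scores.shape)  # (5, 5, 4)

# Scaling by 1/sqrt(d_k) and a softmax over the key axis then gives the attention weights
attn = jax.nn.softmax(scores / jnp.sqrt(d_k), axis=1)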
@@ -1038,4 +1007,4 @@ def main():
 
 #
 if __name__ == '__main__':
     main()