mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-26 19:46:20 +08:00 
			
		
		
		
	cleanup jax
This commit is contained in:
		| @ -422,43 +422,6 @@ class LayerNorm(Module): | ||||
|         return x_norm | ||||
|  | ||||
|  | ||||
| class PrepareForMultiHeadAttention(Module): | ||||
|     """ | ||||
|     <a id="PrepareMHA"></a> | ||||
|  | ||||
|     ## Prepare for multi-head attention | ||||
|  | ||||
|     This module does a linear transformation and splits the vector into given | ||||
|     number of heads for multi-head attention. | ||||
|     This is used to transform **key**, **query**, and **value** vectors. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int): | ||||
|         super().__init__() | ||||
|         # Linear layer for linear transform | ||||
|         self.linear = Linear(rnd_key, d_model, heads * d_k) | ||||
|         # Number of heads | ||||
|         self.heads = heads | ||||
|         # Number of dimensions in vectors in each head | ||||
|         self.d_k = d_k | ||||
|  | ||||
|     def __call__(self, x: jnp.ndarray): | ||||
|         # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`. | ||||
|         # We apply the linear transformation to the last dimension and split that into | ||||
|         # the heads. | ||||
|         head_shape = x.shape[:-1] | ||||
|  | ||||
|         # Linear transform | ||||
|         x = self.linear(x) | ||||
|  | ||||
|         # Split last dimension into heads | ||||
|  | ||||
|         x = x.reshape(*head_shape, self.heads, self.d_k) | ||||
|  | ||||
|         # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]` | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class MultiHeadAttention(Module): | ||||
|     r""" | ||||
|     <a id="MHA"></a> | ||||
| @ -503,9 +466,9 @@ class MultiHeadAttention(Module): | ||||
|         self.heads = heads | ||||
|  | ||||
|         # These transform the `query`, `key` and `value` vectors for multi-headed attention. | ||||
|         self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k) | ||||
|         self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k) | ||||
|         self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k) | ||||
|         self.query = Linear(rnd_keys[0], d_model, d_model) | ||||
|         self.key = Linear(rnd_keys[1], d_model, d_model) | ||||
|         self.value = Linear(rnd_keys[2], d_model, d_model) | ||||
|  | ||||
|         # Output layer | ||||
|         self.output = Linear(rnd_keys[3], d_model, d_model) | ||||
| @ -537,12 +500,18 @@ class MultiHeadAttention(Module): | ||||
|             # Same mask applied to all heads. | ||||
|             mask = mask[:, :, None] | ||||
|  | ||||
|         # Prepare `query`, `key` and `value` for attention computation. | ||||
|         # These will then have shape `[seq_len, heads, d_k]`. | ||||
|         # Apply linear transformations | ||||
|         query = self.query(query) | ||||
|         key = self.key(key) | ||||
|         value = self.value(value) | ||||
|  | ||||
|         # Reshape to split into heads | ||||
|         # Input has shape `[seq_len, batch_size, d_model]`. | ||||
|         # We split the last dimension into `heads` and `d_k`. | ||||
|         query = query.reshape(*query.shape[:-1], self.heads, self.d_k) | ||||
|         key = key.reshape(*key.shape[:-1], self.heads, self.d_k) | ||||
|         value = value.reshape(*value.shape[:-1], self.heads, self.d_k) | ||||
|  | ||||
|         # Compute attention scores $Q K^\top$. | ||||
|         # This gives a tensor of shape `[seq_len, seq_len, heads]`. | ||||
|         # $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$ | ||||
| @ -1038,4 +1007,4 @@ def main(): | ||||
|  | ||||
| # | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
|     main() | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri