	mha explanation
@@ -8,7 +8,10 @@ summary: >

# Transformers

## Transformer Building Blocks
This module contains [PyTorch](https://pytorch.org/)
implementations and explanations of the original transformer
from the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762),
and derivatives and enhancements of it.

* [Multi-head attention](mha.html)
* [Relative multi-head attention](relative_mha.html)

@@ -8,6 +8,9 @@ summary: >

# Multi-Headed Attention

This is a tutorial/implementation of multi-headed attention
from the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
in [PyTorch](https://pytorch.org/).
The implementation is inspired by the [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html).
"""

@@ -23,21 +26,29 @@ from torch.nn import functional as F

class PrepareForMultiHeadAttention(Module):
    """
    ## Prepare for multi-head attention

    This module does a linear transformation and splits the vector into the given
    number of heads for multi-head attention.
    This is used to transform the **key**, **query**, and **value** vectors.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for the linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = x.shape

        # Linear transform
        x = self.linear(x)
        # Split into heads
        x = x.view(seq_len, batch_size, self.heads, self.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]`
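As a quick aside, here is a minimal shape check for the transform-and-split step above, using hypothetical sizes (`d_model=512`, `heads=8`, so `d_k=64`); it only mirrors the linear-then-view logic, not the full module:

```python
import torch
import torch.nn as nn

seq_len, batch_size, d_model, heads = 10, 2, 512, 8
d_k = d_model // heads  # 64 features per head

x = torch.randn(seq_len, batch_size, d_model)
# Linear transform to `heads * d_k` features, then split the last dimension into heads
linear = nn.Linear(d_model, heads * d_k)
out = linear(x).view(seq_len, batch_size, heads, d_k)
assert out.shape == (seq_len, batch_size, heads, d_k)  # [10, 2, 8, 64]
```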
@@ -49,26 +60,38 @@ class MultiHeadAttention(Module):
        """
        ## Multi-Head Attention Module

        This computes multi-headed attention for given `query`, `key` and `value` vectors.
        `heads` is the number of heads.
        `d_model` is the number of features in the `query`, `key` and `value` vectors.
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.

        $$Attention(Q, K, V) = softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
        This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.

        $$Attention(Q, K, V) = \underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$

        In simple terms, it finds keys that match the query and gets the values of
        those keys.

        It uses the dot-product of query and key as the indicator of how well they match.
        Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
        This is done to avoid large dot-product values causing softmax to
        give very small gradients when $d_k$ is large.

        Softmax is calculated along the axis of the sequence (or time).
        """

        super().__init__()
        self.d_k = d_model // heads
        self.heads = heads

        # These transformer the `query`, `key` and `value` vectors for multi-headed attention/
        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that they can be used for logging, or other computations if needed
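To make the scaling argument above concrete, here is a small numeric sketch (not part of the module, with hypothetical sizes): unscaled dot products grow roughly with $\sqrt{d_k}$, which pushes the softmax towards a one-hot distribution and shrinks its gradients, while the $\frac{1}{\sqrt{d_k}}$ factor keeps it softer:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, n_keys = 512, 16
q = torch.randn(d_k)
k = torch.randn(n_keys, d_k)

scores = k @ q                    # unscaled dot products, standard deviation ~ sqrt(d_k)
scaled = scores / math.sqrt(d_k)  # scaled as in the module (`self.scale`)

print(F.softmax(scores, dim=0).max())  # close to 1.0: nearly one-hot, tiny gradients
print(F.softmax(scaled, dim=0).max())  # noticeably smaller: a softer distribution
```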
@@ -76,12 +99,12 @@ class MultiHeadAttention(Module):

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys.
        ### Calculate scores between queries and keys

        This method can be overriden for other variations like relative attention.
        This method can be overridden for other variations like relative attention.
        """

        # Calculate $Q K^T$
        # Calculate $Q K^T$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

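As a side note (not in the source), the einsum above can be checked against an explicit permute-and-matmul; the tensors and sizes here are hypothetical:

```python
import torch

seq_len, batch_size, heads, d_k = 5, 2, 4, 8
q = torch.randn(seq_len, batch_size, heads, d_k)
k = torch.randn(seq_len, batch_size, heads, d_k)

# S[i, j, b, h] = sum_d Q[i, b, h, d] * K[j, b, h, d]
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)

# Same result with matmul: [b, h, i, d] @ [b, h, d, j] -> [b, h, i, j], then back to [i, j, b, h]
ref = torch.matmul(q.permute(1, 2, 0, 3), k.permute(1, 2, 3, 0)).permute(2, 3, 0, 1)
assert torch.allclose(scores, ref, atol=1e-5)
```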
    def __call__(self, *,
@@ -89,6 +112,16 @@ class MultiHeadAttention(Module):
                 key: torch.Tensor,
                 value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collections of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]`, and
        `mask[i, j, b]` indicates whether, for batch `b`, the
        query at position `i` has access to the key-value at position `j`.
        """

        # `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

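For instance, a causal (autoregressive) mask with the shape described in the docstring could be built as in this sketch; the `causal_mask` helper is illustrative, not part of the module:

```python
import torch

def causal_mask(seq_len: int, batch_size: int) -> torch.Tensor:
    # mask[i, j, b] is True when query position i may attend to key position j (i.e. j <= i)
    m = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return m.unsqueeze(-1).expand(seq_len, seq_len, batch_size)

mask = causal_mask(seq_len=4, batch_size=2)
print(mask[:, :, 0].int())  # lower-triangular pattern for the first batch element
```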
@@ -118,7 +151,8 @@ class MultiHeadAttention(Module):
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # $softmax$ attention $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
        attn = F.softmax(scores, dim=1)

        # Save attentions if debugging
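A quick sanity check (hypothetical tensors) that `dim=1` is the key-sequence axis of `scores`, which has shape `[seq_len_q, seq_len_k, batch_size, heads]`, so each query's attention weights sum to one:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(5, 5, 2, 4)  # [seq_len_q, seq_len_k, batch_size, heads]
attn = F.softmax(scores, dim=1)   # normalize over the keys for each query
assert torch.allclose(attn.sum(dim=1), torch.ones(5, 2, 4))
```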
@@ -127,7 +161,8 @@ class MultiHeadAttention(Module):
        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$
        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations

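Finally, a hedged usage sketch of the module as it reads after this change; the import path and keyword arguments (`heads`, `d_model`, `dropout_prob`, `bias`) are inferred from the diff and may differ from the actual file:

```python
import torch
# Assumed import path for the module shown in this diff
from labml_nn.transformers.mha import MultiHeadAttention

seq_len, batch_size, d_model, heads = 10, 2, 512, 8
x = torch.randn(seq_len, batch_size, d_model)

# Self-attention: the same tensor supplies the query, key and value vectors
mha = MultiHeadAttention(heads=heads, d_model=d_model, dropout_prob=0.1, bias=True)
out = mha(query=x, key=x, value=x)
assert out.shape == (seq_len, batch_size, d_model)
```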