mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import math
 | |
| from typing import Optional
 | |
| 
 | |
| import torch
 | |
| from torch import nn as nn
 | |
| from torch.nn import functional as F
 | |
| 
 | |
| from labml_helpers.module import Module
 | |
| 
 | |
| 
 | |
class PrepareForMultiHeadAttention(Module):
    """Linearly project vectors and split the result into attention heads.

    A single linear layer maps the last dimension of the input from
    ``d_model`` to ``heads * d_k``; that dimension is then reshaped into
    two dimensions ``(heads, d_k)``.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """Create the projection layer.

        Args:
            d_model: Size of the last dimension of the input.
            heads: Number of attention heads.
            d_k: Number of features per head.
            bias: Whether the linear projection has a bias term.
        """
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and split its last dimension into heads.

        Args:
            x: Tensor of shape ``[..., d_model]``; the attention module in
                this file passes ``[seq_len, batch_size, d_model]``.

        Returns:
            Tensor of shape ``[..., heads, d_k]`` with the leading
            dimensions of ``x`` unchanged.
        """
        # Remember every leading dimension rather than assuming exactly
        # (seq_len, batch_size), so inputs of any rank >= 1 are accepted.
        leading_shape = x.shape[:-1]

        x = self.linear(x)

        # Split the projected feature dimension into (heads, d_k).
        return x.view(*leading_shape, self.heads, self.d_k)
 | |
| 
 | |
| 
 | |
class MultiHeadAttention(Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected and split into ``heads``
    heads of ``d_k = d_model // heads`` features. Attention weights are the
    softmax (over the key axis) of the scaled query-key dot products, and
    the weighted values of all heads are concatenated and passed through a
    final linear layer.

    The most recent attention weights (after dropout, detached from the
    graph) are kept in ``self.attn``.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True):
        """Build the projections.

        Args:
            heads: Number of attention heads.
            d_model: Number of features in the query/key/value vectors.
                The output layer takes ``d_model`` inputs, so the
                concatenated heads (``heads * (d_model // heads)``) only
                match it when ``d_model`` is divisible by ``heads``.
            dropout_prob: Dropout probability applied to attention weights.
            bias: Whether the query/key/value projections have a bias.
        """
        super().__init__()
        # Features per head.
        self.d_k = d_model // heads
        self.heads = heads
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.output = nn.Linear(d_model, d_model)
        # Attention weights of the latest call; None until first use.
        self.attn = None
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor 1 / sqrt(d_k) applied to the dot products.
        self.scale = 1 / math.sqrt(self.d_k)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor, ):
        """Return query-key dot products.

        Args:
            query: Tensor indexed as ``[i, b, h, d]``.
            key: Tensor indexed as ``[j, b, h, d]``.

        Returns:
            Tensor indexed as ``[i, j, b, h]``, summed over ``d``.
        """
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """Compute attention.

        Args:
            query: Tensor of shape ``[query_len, batch_size, d_model]``.
            key: Tensor of shape ``[key_len, batch_size, d_model]``.
            value: Tensor of shape ``[key_len, batch_size, d_model]``.
            mask: Optional tensor of shape
                ``[query_len, key_len, batch_size]``; any of the three
                dimensions may be 1, in which case it is broadcast.
                Positions where the mask equals 0 are excluded from
                attention. The same mask is applied to every head.

        Returns:
            Tensor of shape ``[query_len, batch_size, d_model]``.

        Raises:
            AssertionError: If ``mask`` is not 3-dimensional or cannot be
                broadcast to ``[query_len, key_len, batch_size]``.
        """
        seq_len, batch_size, *_ = query.shape

        if mask is not None:
            key_len = key.shape[0]
            # Validate the mask against the actual query/key/batch sizes
            # (not against itself), so rectangular masks used when
            # query_len != key_len are accepted and mismatched masks are
            # rejected here instead of failing later during broadcasting.
            assert mask.dim() == 3, \
                "mask must have shape [query_len, key_len, batch_size]"
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len, \
                "mask dim 0 must be 1 or the query length"
            assert mask.shape[1] == 1 or mask.shape[1] == key_len, \
                "mask dim 1 must be 1 or the key length"
            assert mask.shape[2] == 1 or mask.shape[2] == batch_size, \
                "mask dim 2 must be 1 or the batch size"
            # Add a trailing head axis: [i, j, b] -> [i, j, b, 1], so the
            # same mask is broadcast across all heads.
            mask = mask.unsqueeze(-1)

        # Project and split into heads: [len, batch, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Scaled dot products, indexed [i, j, b, h].
        scores = self.get_scores(query, key)
        scores *= self.scale

        if mask is not None:
            # A large negative score gives masked positions ~0 weight
            # after the softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        # Normalise over the key axis j.
        attn = F.softmax(scores, dim=1)
        attn = self.dropout(attn)

        # Weighted sum of values per head: [i, b, h, d].
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Keep the weights for inspection without holding the graph.
        self.attn = attn.detach()

        # Concatenate heads: [i, b, h * d].
        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x)
 | 
