# Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
# Synced 2025-08-15 18:27:20 +08:00 · 77 lines · 2.4 KiB · Python
import math
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from labml_helpers.module import Module
|
|
|
|
|
|
class PrepareForMultiHeadAttention(Module):
    """
    Project an input into per-head vectors for multi-head attention.

    A single linear layer maps the last dimension from ``d_model`` to
    ``heads * d_k``; the result is then split into ``heads`` vectors of
    size ``d_k``. ``MultiHeadAttention`` uses one instance each for its
    query, key and value inputs.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """
        :param d_model: size of the last dimension of the input
        :param heads: number of attention heads
        :param d_k: size of each head's vector
        :param bias: whether the linear projection has a bias term
        """
        super().__init__()
        # One projection produces the vectors for all heads at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        """
        Project ``x`` and split its last dimension into heads.

        :param x: tensor of shape ``[..., d_model]``; ``MultiHeadAttention``
            passes ``[seq_len, batch_size, d_model]``
        :return: tensor of shape ``[..., heads, d_k]`` -- every leading
            dimension of ``x`` is preserved
        """
        # Keep all leading dimensions instead of assuming exactly
        # [seq_len, batch_size], so inputs of any rank >= 1 work.
        head_shape = x.shape[:-1]
        x = self.linear(x)
        return x.view(*head_shape, self.heads, self.d_k)
|
|
|
|
|
|
class MultiHeadAttention(Module):
    """
    Multi-head scaled dot-product attention with sequence-first tensors.

    ``query`` is ``[seq_len_q, batch_size, d_model]`` and ``key``/``value``
    are ``[seq_len_k, batch_size, d_model]``. Each is projected into
    ``heads`` vectors of size ``d_k = d_model // heads``, attention is
    computed independently per head, and the concatenated heads are
    projected back to ``d_model``.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True):
        """
        :param heads: number of attention heads
        :param d_model: feature size of the inputs and of the output
        :param dropout_prob: dropout probability applied to the attention weights
        :param bias: whether the query/key/value projections have a bias term
            (the output projection always has one)
        """
        super().__init__()
        self.d_k = d_model // heads
        self.heads = heads

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)

        # The concatenated heads have ``heads * d_k`` features; this equals
        # ``d_model`` only when ``d_model`` is divisible by ``heads``.
        self.output = nn.Linear(heads * self.d_k, d_model)
        # Detached attention weights from the most recent call (for inspection).
        self.attn = None
        self.dropout = nn.Dropout(dropout_prob)
        # Dot products are scaled by 1 / sqrt(d_k).
        self.scale = 1 / math.sqrt(self.d_k)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        Compute unscaled dot-product scores between every query and every key.

        :param query: ``[seq_len_q, batch_size, heads, d_k]``
        :param key: ``[seq_len_k, batch_size, heads, d_k]``
        :return: ``[seq_len_q, seq_len_k, batch_size, heads]``
        """
        # S[i, j, b, h] = sum_d Q[i, b, h, d] * K[j, b, h, d]
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def _prepare_mask(self, mask: torch.Tensor, query_shape, key_shape) -> torch.Tensor:
        """
        Validate ``mask`` against the inputs and add a trailing head axis.

        :param mask: ``[seq_len_q, seq_len_k, batch_size]``; any dimension may
            be 1 to broadcast. Positions where the mask is 0 are not attended to.
        :param query_shape: shape of the un-projected query
        :param key_shape: shape of the un-projected key
        :return: ``mask`` with shape ``[..., 1]`` so it broadcasts over heads
        """
        assert mask.dim() == 3, \
            f"mask must be [seq_len_q, seq_len_k, batch_size], got {tuple(mask.shape)}"
        # Query and key lengths may differ (cross-attention), so each mask
        # dimension is checked against its own counterpart.
        assert mask.shape[0] in (1, query_shape[0]), "mask dim 0 must be 1 or seq_len_q"
        assert mask.shape[1] in (1, key_shape[0]), "mask dim 1 must be 1 or seq_len_k"
        assert mask.shape[2] in (1, query_shape[1]), "mask dim 2 must be 1 or batch_size"
        # Same mask applied to all heads.
        return mask.unsqueeze(-1)

    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Attend from ``query`` to ``key``/``value``.

        :param query: ``[seq_len_q, batch_size, d_model]``
        :param key: ``[seq_len_k, batch_size, d_model]``
        :param value: ``[seq_len_k, batch_size, d_model]``
        :param mask: optional ``[seq_len_q, seq_len_k, batch_size]`` tensor
            (dims may be 1); entries equal to 0 block attention
        :return: ``[seq_len_q, batch_size, d_model]``
        """
        seq_len, batch_size, *_ = query.shape

        if mask is not None:
            mask = self._prepare_mask(mask, query.shape, key.shape)

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        scores = self.get_scores(query, key) * self.scale

        if mask is not None:
            # The most negative finite value of the score dtype: a literal
            # -1e9 is outside float16's range, so masked_fill would fail on
            # half-precision scores, and -inf would make fully-masked rows
            # NaN after the softmax (here they get uniform weights instead).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalise over the key dimension ``j``.
        attn = F.softmax(scores, dim=1)
        attn = self.dropout(attn)

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Keep the (post-dropout) weights for inspection, outside the graph.
        self.attn = attn.detach()

        # Concatenate heads: [seq_len_q, batch_size, heads * d_k].
        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x)
|