diff --git a/labml_nn/transformers/feedback/__init__.py b/labml_nn/transformers/feedback/__init__.py
index 194b98bb..ad913c4a 100644
--- a/labml_nn/transformers/feedback/__init__.py
+++ b/labml_nn/transformers/feedback/__init__.py
@@ -22,23 +22,31 @@ class FeedbackAttention(Module):
     """
     ## Feedback Attention
 
-    This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
-    but with some modifications.
-    📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
-    or [Multi-Head Attention](../mha.html) to improve readability.
+    This module computes recurrent attention, similar to the attention in the original
+    transformer paper.
+
+    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
     """
 
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+        """
+        * `heads` is the number of attention heads
+        * `d_model` is the number of features in the transformer
+        * `dropout_prob` is the attention dropout probability
+        """
+
         super().__init__()
 
+        # Number of features per head
         self.d_k = d_model // heads
+        # Number of heads
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
-        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
+        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
 
         # Output layer
         self.output = nn.Linear(d_model, d_model)
@@ -55,8 +63,6 @@ class FeedbackAttention(Module):
 
         # Relative positional embeddings for key relative to the query.
         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
-        # Relative positional embedding bias for key relative to the query.
-        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
         # Positional embeddings for the query is independent of the position of the query
         self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
 
@@ -65,47 +71,27 @@ class FeedbackAttention(Module):
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
         """
-        ### Get relative attention scores
+        ### Get attention scores
 
         \begin{align}
-        A_{j} &= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
-              &= \color{cyan}{Q^T} \color{lightgreen}{K_j} +
-                 \color{cyan}{Q^T} \color{lightgreen}{U_j} +
-                 \color{cyan}{V^T} \color{lightgreen}{K_j} +
-                 \color{cyan}{V^T} \color{lightgreen}{U_j}
+        A_{j} &= Q^T K_j \\
+              &= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+              &= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})
         \end{align}
 
         where $\color{cyan}{Q}, \color{lightgreen}{K_j}$, are linear transformations of
         original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
         and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
         absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.
-
-        We replace $\color{cyan}{V^T} \color{lightgreen}{U_j}$ with
-        $S_j$.
-
-        \begin{align}
-        A^{rel}_{j} &= \underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
-                       \underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
-                       \underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}} +
-                       \underset{\mathbf{D}}{\color{orange}{S_j}}
-        \end{align}
         """
 
         # $\color{lightgreen}{U_j}$
         key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
-        # $\color{orange}{S_j}$
-        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
         # $\color{cyan}{V^T}$
         query_pos_bias = self.query_pos_bias[None, :, :]
 
-        # $\underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
-        # \underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}}$
-        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
-        # $\underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
-        # \underset{\mathbf{D}}{\color{orange}{S_j}}$
-        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
-
-        return ac + bd
+        # $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})$
+        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
 
     def __call__(self, *,
                  query: torch.Tensor,
@@ -147,22 +133,39 @@ class FeedbackAttention(Module):
 
 
 class FeedbackTransformerLayer(Module):
+    """
+    ## Feedback Transformer Layer
+
+    This implements a single transformer layer in the feedback transformer.
+    """
+
     def __init__(self, *,
                  d_model: int,
                  attn: FeedbackAttention,
                  feed_forward: FeedForward,
                  dropout_prob: float):
+        """
+        * `d_model` is the number of features in the transformer
+        * `attn` is the feedback attention module
+        * `feed_forward` is the position-wise feed-forward layer
+        * `dropout_prob` is the dropout probability for the dropout layers after attention and feed-forward
+        """
         super().__init__()
+        # Transformer size $d_{model}$
         self.size = d_model
+        # Feedback attention module
         self.attn = attn
         self.feed_forward = feed_forward
         self.dropout = nn.Dropout(dropout_prob)
+
+        # Normalization layers
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
     def __call__(self, *,
                  x: torch.Tensor,
                  mem: Optional[torch.Tensor]):
+        # If there is memory
         if mem is not None:
             # Normalize the vectors before doing self attention
             z = self.norm_self_attn(x)
@@ -178,49 +181,66 @@ class FeedbackTransformerLayer(Module):
         # Add the feed-forward results back
         x = x + self.dropout(ff)
+        #
         return x
 
 
 class FeedbackTransformer(Module):
     """
-    ## Transformer Encoder
+    ## Feedback Transformer Module
     """
 
     def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
+        """
+        * `layer` is the feedback transformer layer, which we clone for each layer
+        * `n_layers` is the number of layers in the transformer
+        """
+
         super().__init__()
         # Make copies of the transformer layer
         self.layers = clone_module_list(layer, n_layers)
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
-        #
+        # Memory vectors are computed as a weighted sum of the representations from each layer.
+        # This is the weights parameter for that weighted sum.
         self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
-        #
+        # Softmax for the weights before taking the weighted sum
         self.softmax = nn.Softmax(0)
 
     def __call__(self, x_seq: torch.Tensor):
+        """
+        * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
+        """
+
+        # Split the input into a list along the sequence axis
         x_seq = torch.unbind(x_seq, dim=0)
+        # List to store the outputs
         res = []
+        # List to store the memory vectors
         mem = []
         # For each input step
         for x in x_seq:
-            # List of embeddings from each layer
-            emb = [x]
+            # List to store layer outputs
+            layer_outputs = [x]
+            # If there is memory, stack the memory vectors into a tensor
             mem_tensor = torch.stack(mem) if mem else None
             # Run through each layer
             for layer in self.layers:
+                # Get layer output
                 x = layer(x=x, mem=mem_tensor)
-                emb.append(x)
+                # Append it to the list of layer outputs
+                layer_outputs.append(x)
 
-            # Stack embeddings
-            emb = torch.stack(emb)
-            # Weighted sum of embeddings
-            mem.append(torch.einsum('lbd,l->bd', emb, self.softmax(self.weights)))
+            # Stack the layer outputs into a tensor
+            layer_outputs = torch.stack(layer_outputs)
+            # Calculate the memory vector as a weighted sum of the layer outputs
+            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
             # Append the output to results
             res.append(x)
 
         # Stack the output tensors
         res = torch.stack(res)
-        # Normalize
+        # Normalize the output
        return self.norm(res)
diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index 76f60c26..30c766ac 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -91,9 +91,9 @@ class MultiHeadAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
-        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
 
         # Softmax for attention along the time dimension of `key`
         self.softmax = nn.Softmax(dim=1)
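
For reference, a minimal standalone check (not part of the PR) of the refactored `FeedbackAttention.get_scores`: the new single einsum over `query + query_pos_bias` and `key + key_pos_emb` should equal the four-term expansion $\color{cyan}{Q^T}\color{lightgreen}{K_j} + \color{cyan}{Q^T}\color{lightgreen}{U_j} + \color{cyan}{V^T}\color{lightgreen}{K_j} + \color{cyan}{V^T}\color{lightgreen}{U_j}$ from the docstring, i.e. the old `ac + bd` with the removed learned bias $S_j$ replaced by $\color{cyan}{V^T}\color{lightgreen}{U_j}$. The tensor shapes are arbitrary and chosen only for this sketch.

```python
import torch

# Hypothetical shapes, used only for this check (not taken from the PR)
seq_len, batch_size, heads, d_k = 5, 2, 4, 8

query = torch.randn(batch_size, heads, d_k)           # Q for a single step: [batch, heads, d_k]
key = torch.randn(seq_len, batch_size, heads, d_k)    # K_j over memory steps: [seq, batch, heads, d_k]
key_pos_emb = torch.randn(seq_len, heads, d_k)        # U_j
query_pos_bias = torch.randn(heads, d_k)              # V

# New form from the PR: a single einsum over (Q + V) and (K_j + U_j)
scores_new = torch.einsum('bhd,jbhd->jbh',
                          query + query_pos_bias[None, :, :],
                          key + key_pos_emb[:, None, :, :])

# The same scores expanded into the four terms of the docstring:
# Q^T K_j + Q^T U_j + V^T K_j + V^T U_j
scores_expanded = (
        torch.einsum('bhd,jbhd->jbh', query, key)
        + torch.einsum('bhd,jhd->jbh', query, key_pos_emb)
        + torch.einsum('hd,jbhd->jbh', query_pos_bias, key)
        + torch.einsum('hd,jhd->jh', query_pos_bias, key_pos_emb)[:, None, :]
)

assert torch.allclose(scores_new, scores_expanded, atol=1e-5)
```

The only behavioral difference from the old `ac + bd` path is that the deleted `key_pos_bias` parameter ($S_j$) no longer contributes a learned per-position bias; its slot in the expansion is now taken by the $\color{cyan}{V^T}\color{lightgreen}{U_j}$ term.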