feedback transformer notes
@@ -22,23 +22,31 @@ class FeedbackAttention(Module):
    """
    ## Feedback Attention

    This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
    but with some modifications.

    📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
     or [Multi-Head Attention](../mha.html) to improve readability.

    This module computes recurrent attention, similar to the attention in the original
    transformer paper.

    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        """
        * `heads` is the number of attention heads
        * `d_model` is the number of features in the transformer
        * `dropout_prob` is the attention dropout probability
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        #
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
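
(A quick sketch, my addition: what `d_k = d_model // heads` means in practice. Roughly, `PrepareForMultiHeadAttention` applies a linear projection and splits the result into `heads` vectors of `d_k` features each; the reshape below shows only the split, with hypothetical sizes.)

import torch

d_model, heads = 512, 8
d_k = d_model // heads                      # 64 features per head

x = torch.randn(10, 4, d_model)             # [seq_len, batch_size, d_model]
x_heads = x.view(10, 4, heads, d_k)         # [seq_len, batch_size, heads, d_k]
assert x_heads.shape[-1] * heads == d_model
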
@@ -55,8 +63,6 @@ class FeedbackAttention(Module):

        # Relative positional embeddings for the key relative to the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for the key relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
        # Positional embeddings for the query are independent of the position of the query
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
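
(Standalone sketch, not from the original code: the shapes of the positional parameters above, and how `get_scores` slices the key embeddings so that only the positions actually present in memory are used. `P` stands in for `self.P`, the number of relative positions.)

import torch

P, heads, d_k = 32, 8, 64

key_pos_embeddings = torch.zeros(P, heads, d_k)    # U_j: one embedding per relative position
query_pos_bias = torch.zeros(heads, d_k)           # V: shared by every query position

seq_len = 10                                       # keys currently in memory
key_pos_emb = key_pos_embeddings[-seq_len:]        # keep only the last `seq_len` positions
assert key_pos_emb.shape == (seq_len, heads, d_k)
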

@@ -65,47 +71,27 @@ class FeedbackAttention(Module):

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Get relative attention scores
        ### Get attention scores

        \begin{align}
        A_{j} &= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
                      &= \color{cyan}{Q^T} \color{lightgreen}{K_j} +
                         \color{cyan}{Q^T} \color{lightgreen}{U_j} +
                         \color{cyan}{V^T} \color{lightgreen}{K_j} +
                         \color{cyan}{V^T} \color{lightgreen}{U_j}
        A_{j} &= Q^T K_j \\
            &= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
            &= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})
        \end{align}

        where $\color{cyan}{Q}, \color{lightgreen}{K_j}$ are linear transformations of
         original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
         and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
         absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.

        We replace $\color{cyan}{V^T} \color{lightgreen}{U_j}$ with
        $S_j$.

        \begin{align}
        A^{rel}_{j} &= \underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
                       \underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
                       \underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}} +
                       \underset{\mathbf{D}}{\color{orange}{S_j}}
        \end{align}
        """

        # $\color{lightgreen}{U_j}$
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
        # $\color{orange}{S_j}$
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
        # $\color{cyan}{V^T}$
        query_pos_bias = self.query_pos_bias[None, :, :]

        # $\underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
        # \underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}}$
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
        # $\underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
        # \underset{\mathbf{D}}{\color{orange}{S_j}}$
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

        return ac + bd
        # $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})$
        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
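
(A self-contained sketch with toy shapes, my own code: the combined score $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})$ computed by the `einsum` above equals the sum of the four individual terms from the expansion in the docstring.)

import torch

seq_len, batch_size, heads, d_k = 6, 2, 4, 8
query = torch.randn(batch_size, heads, d_k)               # Q
key = torch.randn(seq_len, batch_size, heads, d_k)        # K_j
key_pos_emb = torch.randn(seq_len, heads, d_k)            # U_j
query_pos_bias = torch.randn(1, heads, d_k)               # V (broadcast over the batch)

# Combined form: (Q + V)^T (K_j + U_j)
combined = torch.einsum('bhd,jbhd->jbh',
                        query + query_pos_bias,
                        key + key_pos_emb[:, None, :, :])

# The four terms, computed separately
v = query_pos_bias.expand(batch_size, -1, -1)
a = torch.einsum('bhd,jbhd->jbh', query, key)              # Q^T K_j
b = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)       # Q^T U_j
c = torch.einsum('bhd,jbhd->jbh', v, key)                  # V^T K_j
d = torch.einsum('bhd,jhd->jbh', v, key_pos_emb)           # V^T U_j
assert torch.allclose(combined, a + b + c + d, atol=1e-4)
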

    def __call__(self, *,
                 query: torch.Tensor,
@@ -147,22 +133,39 @@ class FeedbackAttention(Module):


class FeedbackTransformerLayer(Module):
    """
    ## Feedback Transformer Layer

    This implements a single transformer layer in the feedback transformer.
    """

    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        """
        * `d_model` is the number of features in the transformer
        * `attn` is the feedback attention module
        * `feed_forward` is the position-wise feed forward layer
        * `dropout_prob` is the dropout probability for dropout layers after attention and feed-forward
        """
        super().__init__()
        # Transformer size $d_{model}$
        self.size = d_model
        #
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 mem: Optional[torch.Tensor]):
        # If there is memory
        if mem is not None:
            # Normalize the vectors before doing self attention
            z = self.norm_self_attn(x)
@@ -178,49 +181,66 @@ class FeedbackTransformerLayer(Module):
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        #
        return x
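
(The layer uses the pre-norm residual pattern: normalize, apply the sub-layer, then add the result back through dropout, as in the feed-forward step above. A generic stand-in sketch of that pattern, my own code with a plain MLP in place of the module's attention/`FeedForward`:)

import torch
from torch import nn

d_model = 64
norm = nn.LayerNorm([d_model])
sub_layer = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
dropout = nn.Dropout(0.1)

x = torch.randn(3, d_model)        # [batch_size, d_model], a single step
z = norm(x)                        # pre-norm: normalize before the sub-layer
x = x + dropout(sub_layer(z))      # residual connection around the sub-layer
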


class FeedbackTransformer(Module):
    """
    ## Transformer Encoder
    ## Feedback Transformer Module
    """

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        """
        * `layer` is the feedback transformer layer, which we clone for each layer
        * `n_layers` is the number of layers in the transformer
        """

        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        #
        # Memory vectors are computed as a weighted sum of representations of each layer.
        # This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
        #
        # Softmax for weights before taking the weighted sum
        self.softmax = nn.Softmax(0)
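
(A small standalone sketch, not part of the commit: the memory vector for a step is a softmax-weighted mix of the step's input embedding and every layer's output, which is what `self.weights` and `self.softmax` are for.)

import torch

n_layers, batch_size, d_model = 3, 2, 16

# One representation per "layer"; the input embedding counts as layer 0
layer_outputs = torch.randn(n_layers + 1, batch_size, d_model)
weights = torch.ones(n_layers + 1)

# Softmax over the layer axis, then a weighted sum -> [batch_size, d_model]
mem_step = torch.einsum('lbd,l->bd', layer_outputs, torch.softmax(weights, dim=0))
assert mem_step.shape == (batch_size, d_model)
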

    def __call__(self, x_seq: torch.Tensor):
        """
        * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
        """

        # Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # List to store the memory vectors
        mem = []
        # For each input step
        for x in x_seq:
            # List of embeddings from each layer
            emb = [x]
            # List to store layer outputs
            layer_outputs = [x]

            # If there is memory, stack the memory vectors into a tensor
            mem_tensor = torch.stack(mem) if mem else None

            # Run through each layer
            for layer in self.layers:
                # Get layer output
                x = layer(x=x, mem=mem_tensor)
                emb.append(x)
                # Append it to the list of layer outputs
                layer_outputs.append(x)

            # Stack embeddings
            emb = torch.stack(emb)
            # Weighted sum of embeddings
            mem.append(torch.einsum('lbd,l->bd', emb, self.softmax(self.weights)))
            # Stack the layer outputs to a tensor
            layer_outputs = torch.stack(layer_outputs)
            # Calculate the memory vector as a weighted sum of layer outputs
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
            # Append the output to results
            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize
        # Normalize the output
        return self.norm(res)
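
(To make the stepping explicit, a toy sketch, my own code, with an identity function standing in for the cloned layers: the loop consumes one step at a time and the memory list grows by exactly one vector per step.)

import torch

seq_len, batch_size, d_model = 5, 2, 16
x_seq = torch.randn(seq_len, batch_size, d_model)

weights = torch.softmax(torch.ones(2), dim=0)          # input embedding + one "layer"
mem, res = [], []
for x in torch.unbind(x_seq, dim=0):                   # one step at a time
    # Memory from earlier steps, or `None` on the first step
    mem_tensor = torch.stack(mem) if mem else None
    y = x                                              # stand-in for layer(x=x, mem=mem_tensor)
    layer_outputs = torch.stack([x, y])                # input embedding + each layer's output
    mem.append(torch.einsum('lbd,l->bd', layer_outputs, weights))
    res.append(y)

out = torch.stack(res)                                 # [seq_len, batch_size, d_model]
assert out.shape == x_seq.shape and len(mem) == seq_len
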

@@ -91,9 +91,9 @@ class MultiHeadAttention(Module):
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=1)
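
(Context for `dim=1`, an illustration I added under the assumption that scores here have the `[query_seq, key_seq, batch_size, heads]` layout: the softmax has to normalize over the key positions, which sit on the second axis.)

import torch

query_seq, key_seq, batch_size, heads = 4, 6, 2, 8
scores = torch.randn(query_seq, key_seq, batch_size, heads)

attn = torch.softmax(scores, dim=1)        # normalize over the key positions
# For each query position, the attention over the keys sums to one
assert torch.allclose(attn.sum(dim=1), torch.ones(query_seq, batch_size, heads))
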