mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 14:29:43 +08:00 
			
		
		
		
	✨ feedback transformer update
This commit is contained in:
		@ -55,7 +55,8 @@ class FeedbackAttention(Module):
 | 
				
			|||||||
    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
 | 
					    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
 | 
					    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
 | 
				
			||||||
 | 
					                 is_kv_precomputed: bool = False):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        * 'heads' is the number of attention heads
 | 
					        * 'heads' is the number of attention heads
 | 
				
			||||||
        * `d_model` is the number of features in the transformer
 | 
					        * `d_model` is the number of features in the transformer
 | 
				
			||||||
@ -70,9 +71,13 @@ class FeedbackAttention(Module):
 | 
				
			|||||||
        self.heads = heads
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
 | 
					        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
 | 
				
			||||||
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=False)
 | 
					        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
 | 
				
			||||||
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=False)
 | 
					        if not is_kv_precomputed:
 | 
				
			||||||
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=True)
 | 
					            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
 | 
				
			||||||
 | 
					            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.key = None
 | 
				
			||||||
 | 
					            self.value = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Output layer
 | 
					        # Output layer
 | 
				
			||||||
        self.output = nn.Linear(d_model, d_model)
 | 
					        self.output = nn.Linear(d_model, d_model)
 | 
				
			||||||
@ -117,7 +122,10 @@ class FeedbackAttention(Module):
 | 
				
			|||||||
        query_pos_bias = self.query_pos_bias[None, :, :]
 | 
					        query_pos_bias = self.query_pos_bias[None, :, :]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # $(Q + U^Q)^\top(K_j + U^K_j)$
 | 
					        # $(Q + U^Q)^\top(K_j + U^K_j)$
 | 
				
			||||||
        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
 | 
					        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
 | 
				
			||||||
 | 
					        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ac + bd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __call__(self, *,
 | 
					    def __call__(self, *,
 | 
				
			||||||
                 query: torch.Tensor,
 | 
					                 query: torch.Tensor,
 | 
				
			||||||
@ -132,8 +140,10 @@ class FeedbackAttention(Module):
 | 
				
			|||||||
        # `key` and `value`  will then have shape `[seq_len, batch_size, heads, d_k]`
 | 
					        # `key` and `value`  will then have shape `[seq_len, batch_size, heads, d_k]`
 | 
				
			||||||
        # and `query` will have shape `[batch_size, heads, d_k]`
 | 
					        # and `query` will have shape `[batch_size, heads, d_k]`
 | 
				
			||||||
        query = self.query(query)
 | 
					        query = self.query(query)
 | 
				
			||||||
        key = self.key(key)
 | 
					        if self.key:
 | 
				
			||||||
        value = self.value(value)
 | 
					            key = self.key(key)
 | 
				
			||||||
 | 
					        if self.value:
 | 
				
			||||||
 | 
					            value = self.value(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Compute attention scores
 | 
					        # Compute attention scores
 | 
				
			||||||
        # Results in a tensor of shape `[seq_len, batch_size, heads]`
 | 
					        # Results in a tensor of shape `[seq_len, batch_size, heads]`
 | 
				
			||||||
@ -190,13 +200,14 @@ class FeedbackTransformerLayer(Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def __call__(self, *,
 | 
					    def __call__(self, *,
 | 
				
			||||||
                 x: torch.Tensor,
 | 
					                 x: torch.Tensor,
 | 
				
			||||||
                 mem: Optional[torch.Tensor]):
 | 
					                 key: Optional[torch.Tensor],
 | 
				
			||||||
 | 
					                 value: Optional[torch.Tensor]):
 | 
				
			||||||
        # If there is memory
 | 
					        # If there is memory
 | 
				
			||||||
        if mem is not None:
 | 
					        if key is not None:
 | 
				
			||||||
            # Normalize the vectors before doing self attention
 | 
					            # Normalize the vectors before doing self attention
 | 
				
			||||||
            z = self.norm_self_attn(x)
 | 
					            z = self.norm_self_attn(x)
 | 
				
			||||||
            # Run through self attention, i.e. keys and values are from self
 | 
					            # Run through self attention, i.e. keys and values are from self
 | 
				
			||||||
            self_attn = self.attn(query=z, key=mem, value=mem)
 | 
					            self_attn = self.attn(query=z, key=key, value=value)
 | 
				
			||||||
            # Add the self attention results
 | 
					            # Add the self attention results
 | 
				
			||||||
            x = x + self.dropout(self_attn)
 | 
					            x = x + self.dropout(self_attn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -255,7 +266,7 @@ class FeedbackTransformer(Module):
 | 
				
			|||||||
            # Run through each layer
 | 
					            # Run through each layer
 | 
				
			||||||
            for layer in self.layers:
 | 
					            for layer in self.layers:
 | 
				
			||||||
                # Get layer output
 | 
					                # Get layer output
 | 
				
			||||||
                x = layer(x=x, mem=mem_tensor)
 | 
					                x = layer(x=x, key=mem_tensor, value=mem_tensor)
 | 
				
			||||||
                # Append them to the list of layer outputs
 | 
					                # Append them to the list of layer outputs
 | 
				
			||||||
                layer_outputs.append(x)
 | 
					                layer_outputs.append(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -270,3 +281,121 @@ class FeedbackTransformer(Module):
 | 
				
			|||||||
        res = torch.stack(res)
 | 
					        res = torch.stack(res)
 | 
				
			||||||
        # Normalize the output
 | 
					        # Normalize the output
 | 
				
			||||||
        return self.norm(res)
 | 
					        return self.norm(res)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StackFunction(torch.autograd.Function):
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def forward(ctx, memory, memory_grad, last, n):
 | 
				
			||||||
 | 
					        ctx._mem_grad = memory_grad
 | 
				
			||||||
 | 
					        ctx._n = n
 | 
				
			||||||
 | 
					        return memory[:n + 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def backward(ctx, grad_output):
 | 
				
			||||||
 | 
					        n = ctx._n
 | 
				
			||||||
 | 
					        memory_grad = ctx._mem_grad
 | 
				
			||||||
 | 
					        memory_grad[:n + 1] += grad_output
 | 
				
			||||||
 | 
					        return None, None, memory_grad[n], None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Stack:
 | 
				
			||||||
 | 
					    def __init__(self, max_len: int):
 | 
				
			||||||
 | 
					        self.max_len = max_len
 | 
				
			||||||
 | 
					        self.memory = None
 | 
				
			||||||
 | 
					        self.memory_grad = None
 | 
				
			||||||
 | 
					        self.last = None
 | 
				
			||||||
 | 
					        self.n = -1
 | 
				
			||||||
 | 
					        self.last_get_n = -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def append(self, n: int, vector: torch.Tensor):
 | 
				
			||||||
 | 
					        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            if self.memory is None or self.memory.shape[1:] != vector.shape:
 | 
				
			||||||
 | 
					                assert n == 0
 | 
				
			||||||
 | 
					                self.memory = vector.new_zeros(self.max_len, *vector.shape, requires_grad=False)
 | 
				
			||||||
 | 
					                self.memory_grad = vector.new_zeros(self.memory.shape, requires_grad=False)
 | 
				
			||||||
 | 
					            elif n == 0:
 | 
				
			||||||
 | 
					                self.memory_grad.fill_(0.)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # memory[n] = vector.detach()
 | 
				
			||||||
 | 
					            self.memory.data[n] = vector.detach()
 | 
				
			||||||
 | 
					            self.n = n
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.last = vector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self):
 | 
				
			||||||
 | 
					        self.last_get_n = self.n
 | 
				
			||||||
 | 
					        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FeedbackTransformerKV(Module):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    ## Feedback Transformer Module
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        * `layer` is the feedback transformer layer, which we clone for each layer
 | 
				
			||||||
 | 
					        * `n_layers` is the number of layers in the transformer
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        # Make copies of the transformer layer
 | 
				
			||||||
 | 
					        self.layers = clone_module_list(layer, n_layers)
 | 
				
			||||||
 | 
					        # Final normalization layer
 | 
				
			||||||
 | 
					        self.norm = nn.LayerNorm([layer.size])
 | 
				
			||||||
 | 
					        # Memory vectors are computed as a weighted sum of representations of each layer.
 | 
				
			||||||
 | 
					        # This is the weights parameter for that.
 | 
				
			||||||
 | 
					        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
 | 
				
			||||||
 | 
					        # Softmax for weights before taking the weighted sum
 | 
				
			||||||
 | 
					        self.softmax = nn.Softmax(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        d_k = d_model // heads
 | 
				
			||||||
 | 
					        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
 | 
				
			||||||
 | 
					        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.mem_key = Stack(512)
 | 
				
			||||||
 | 
					        self.mem_value = Stack(512)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x_seq: torch.Tensor):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Split the input to a list along the sequence axis
 | 
				
			||||||
 | 
					        x_seq = torch.unbind(x_seq, dim=0)
 | 
				
			||||||
 | 
					        # List to store the outputs
 | 
				
			||||||
 | 
					        res = []
 | 
				
			||||||
 | 
					        # For each input step
 | 
				
			||||||
 | 
					        for step, x in enumerate(x_seq):
 | 
				
			||||||
 | 
					            # List to store layer outputs
 | 
				
			||||||
 | 
					            layer_outputs = [x]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # If there is memory, stack them into a vector
 | 
				
			||||||
 | 
					            key_tensor = None
 | 
				
			||||||
 | 
					            value_tensor = None
 | 
				
			||||||
 | 
					            if step > 0:
 | 
				
			||||||
 | 
					                key_tensor = self.mem_key.get()
 | 
				
			||||||
 | 
					                value_tensor = self.mem_value.get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Run through each layer
 | 
				
			||||||
 | 
					            for layer in self.layers:
 | 
				
			||||||
 | 
					                # Get layer output
 | 
				
			||||||
 | 
					                x = layer(x=x, key=key_tensor, value=value_tensor)
 | 
				
			||||||
 | 
					                # Append them to the list of layer outputs
 | 
				
			||||||
 | 
					                layer_outputs.append(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Stack the layer outputs to a tensor
 | 
				
			||||||
 | 
					            layer_outputs = torch.stack(layer_outputs)
 | 
				
			||||||
 | 
					            # Calculate the memory vector as a weighted sum of layer outputs
 | 
				
			||||||
 | 
					            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
 | 
				
			||||||
 | 
					            self.mem_key.append(step, self.key(mem))
 | 
				
			||||||
 | 
					            self.mem_value.append(step, self.value(mem))
 | 
				
			||||||
 | 
					            # Append the output to results
 | 
				
			||||||
 | 
					            res.append(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Stack the output tensors
 | 
				
			||||||
 | 
					        res = torch.stack(res)
 | 
				
			||||||
 | 
					        # Normalize the output
 | 
				
			||||||
 | 
					        return self.norm(res)
 | 
				
			||||||
 | 
				
			|||||||
@ -39,33 +39,33 @@
 | 
				
			|||||||
     "output_type": "stream",
 | 
					     "output_type": "stream",
 | 
				
			||||||
     "text": [
 | 
					     "text": [
 | 
				
			||||||
      "Collecting labml-nn\n",
 | 
					      "Collecting labml-nn\n",
 | 
				
			||||||
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n",
 | 
					      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n",
 | 
				
			||||||
      "\u001b[K     |████████████████████████████████| 112kB 13.7MB/s \n",
 | 
					      "\u001B[K     |████████████████████████████████| 112kB 13.7MB/s \n",
 | 
				
			||||||
      "\u001b[?25hCollecting labml>=0.4.86\n",
 | 
					      "\u001B[?25hCollecting labml>=0.4.86\n",
 | 
				
			||||||
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n",
 | 
					      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n",
 | 
				
			||||||
      "\u001b[K     |████████████████████████████████| 102kB 11.1MB/s \n",
 | 
					      "\u001B[K     |████████████████████████████████| 102kB 11.1MB/s \n",
 | 
				
			||||||
      "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n",
 | 
					      "\u001B[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n",
 | 
				
			||||||
      "Collecting einops\n",
 | 
					      "Collecting einops\n",
 | 
				
			||||||
      "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
 | 
					      "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
 | 
				
			||||||
      "Collecting labml-helpers>=0.4.72\n",
 | 
					      "Collecting labml-helpers>=0.4.72\n",
 | 
				
			||||||
      "  Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n",
 | 
					      "  Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n",
 | 
				
			||||||
      "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n",
 | 
					      "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n",
 | 
				
			||||||
      "Collecting pyyaml>=5.3.1\n",
 | 
					      "Collecting pyyaml>=5.3.1\n",
 | 
				
			||||||
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n",
 | 
					      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n",
 | 
				
			||||||
      "\u001b[K     |████████████████████████████████| 276kB 40.5MB/s \n",
 | 
					      "\u001B[K     |████████████████████████████████| 276kB 40.5MB/s \n",
 | 
				
			||||||
      "\u001b[?25hCollecting gitpython\n",
 | 
					      "\u001B[?25hCollecting gitpython\n",
 | 
				
			||||||
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n",
 | 
					      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n",
 | 
				
			||||||
      "\u001b[K     |████████████████████████████████| 163kB 50.4MB/s \n",
 | 
					      "\u001B[K     |████████████████████████████████| 163kB 50.4MB/s \n",
 | 
				
			||||||
      "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
 | 
					      "\u001B[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
 | 
				
			||||||
      "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n",
 | 
					      "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n",
 | 
				
			||||||
      "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n",
 | 
					      "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n",
 | 
				
			||||||
      "Collecting gitdb<5,>=4.0.1\n",
 | 
					      "Collecting gitdb<5,>=4.0.1\n",
 | 
				
			||||||
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
 | 
					      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
 | 
				
			||||||
      "\u001b[K     |████████████████████████████████| 71kB 12.7MB/s \n",
 | 
					      "\u001B[K     |████████████████████████████████| 71kB 12.7MB/s \n",
 | 
				
			||||||
      "\u001b[?25hCollecting smmap<4,>=3.0.1\n",
 | 
					      "\u001B[?25hCollecting smmap<4,>=3.0.1\n",
 | 
				
			||||||
      "  Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
 | 
					      "  Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
 | 
				
			||||||
      "Building wheels for collected packages: pyyaml\n",
 | 
					      "Building wheels for collected packages: pyyaml\n",
 | 
				
			||||||
      "  Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
 | 
					      "  Building wheel for pyyaml (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
 | 
				
			||||||
      "  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=132f39d02b291cdc60b9eff7c14b051ee0f520f790ac3e3bdaf6823e9bf7fda3\n",
 | 
					      "  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=132f39d02b291cdc60b9eff7c14b051ee0f520f790ac3e3bdaf6823e9bf7fda3\n",
 | 
				
			||||||
      "  Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n",
 | 
					      "  Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n",
 | 
				
			||||||
      "Successfully built pyyaml\n",
 | 
					      "Successfully built pyyaml\n",
 | 
				
			||||||
@ -98,114 +98,8 @@
 | 
				
			|||||||
   },
 | 
					   },
 | 
				
			||||||
   "outputs": [],
 | 
					   "outputs": [],
 | 
				
			||||||
   "source": [
 | 
					   "source": [
 | 
				
			||||||
    "import torch\n",
 | 
					 | 
				
			||||||
    "import torch.nn as nn\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "from labml import experiment\n",
 | 
					    "from labml import experiment\n",
 | 
				
			||||||
    "from labml.configs import option\n",
 | 
					    "from labml_nn.transformers.feedback.experiment import Configs"
 | 
				
			||||||
    "from labml_helpers.module import Module\n",
 | 
					 | 
				
			||||||
    "from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs"
 | 
					 | 
				
			||||||
   ]
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
   "cell_type": "markdown",
 | 
					 | 
				
			||||||
   "metadata": {
 | 
					 | 
				
			||||||
    "id": "pO86KIJS52eR"
 | 
					 | 
				
			||||||
   },
 | 
					 | 
				
			||||||
   "source": [
 | 
					 | 
				
			||||||
    "## Autoregressive model that uses the transformer"
 | 
					 | 
				
			||||||
   ]
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
   "cell_type": "code",
 | 
					 | 
				
			||||||
   "execution_count": null,
 | 
					 | 
				
			||||||
   "metadata": {
 | 
					 | 
				
			||||||
    "id": "WQ8VGpMGwZuj"
 | 
					 | 
				
			||||||
   },
 | 
					 | 
				
			||||||
   "outputs": [],
 | 
					 | 
				
			||||||
   "source": [
 | 
					 | 
				
			||||||
    "class AutoregressiveModel(Module):\n",
 | 
					 | 
				
			||||||
    "    \"\"\"\n",
 | 
					 | 
				
			||||||
    "    ## Auto regressive model\n",
 | 
					 | 
				
			||||||
    "    \"\"\"\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "    def __init__(self, n_vocab: int, d_model: int, transformer: Module):\n",
 | 
					 | 
				
			||||||
    "        super().__init__()\n",
 | 
					 | 
				
			||||||
    "        # Token embedding module\n",
 | 
					 | 
				
			||||||
    "        self.src_embed = nn.Embedding(n_vocab, d_model)\n",
 | 
					 | 
				
			||||||
    "        self.transformer = transformer\n",
 | 
					 | 
				
			||||||
    "        self.generator = nn.Linear(d_model, n_vocab)\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "    def __call__(self, x: torch.Tensor):\n",
 | 
					 | 
				
			||||||
    "        x = self.src_embed(x)\n",
 | 
					 | 
				
			||||||
    "        # Embed the tokens (`src`) and run it through the the transformer\n",
 | 
					 | 
				
			||||||
    "        res = self.transformer(x)\n",
 | 
					 | 
				
			||||||
    "        # Generate logits of the next token\n",
 | 
					 | 
				
			||||||
    "        return self.generator(res), None"
 | 
					 | 
				
			||||||
   ]
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
   "cell_type": "markdown",
 | 
					 | 
				
			||||||
   "metadata": {
 | 
					 | 
				
			||||||
    "id": "JkoWbOdI58jg"
 | 
					 | 
				
			||||||
   },
 | 
					 | 
				
			||||||
   "source": [
 | 
					 | 
				
			||||||
    "## Configs\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "We extend from [`NLPAutoRegressionConfigs`](https://github.com/lab-ml/nn/blob/master/labml_nn/experiments/nlp_autoregression.py) that defines base configurations, including datasets and dataloaders.\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "The values we set here are the defaults. These can be overridden with a configs dictionary when starting the experiment."
 | 
					 | 
				
			||||||
   ]
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
   "cell_type": "code",
 | 
					 | 
				
			||||||
   "execution_count": null,
 | 
					 | 
				
			||||||
   "metadata": {
 | 
					 | 
				
			||||||
    "id": "f07vAOaHwumr"
 | 
					 | 
				
			||||||
   },
 | 
					 | 
				
			||||||
   "outputs": [],
 | 
					 | 
				
			||||||
   "source": [
 | 
					 | 
				
			||||||
    "class Configs(NLPAutoRegressionConfigs):\n",
 | 
					 | 
				
			||||||
    "    model: AutoregressiveModel\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "    d_model: int = 512\n",
 | 
					 | 
				
			||||||
    "    heads: int = 8\n",
 | 
					 | 
				
			||||||
    "    dropout: float = 0.0\n",
 | 
					 | 
				
			||||||
    "    d_ff: int = 2048\n",
 | 
					 | 
				
			||||||
    "    n_layers: int = 6"
 | 
					 | 
				
			||||||
   ]
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
   "cell_type": "markdown",
 | 
					 | 
				
			||||||
   "metadata": {
 | 
					 | 
				
			||||||
    "id": "IgX3Au_p6Z36"
 | 
					 | 
				
			||||||
   },
 | 
					 | 
				
			||||||
   "source": [
 | 
					 | 
				
			||||||
    "Set the function to initialze `AutoregressiveModel`. This will be called when\n",
 | 
					 | 
				
			||||||
    "`Configs.model` is accessed. "
 | 
					 | 
				
			||||||
   ]
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
   "cell_type": "code",
 | 
					 | 
				
			||||||
   "execution_count": null,
 | 
					 | 
				
			||||||
   "metadata": {
 | 
					 | 
				
			||||||
    "id": "crH6MzKmw-SY"
 | 
					 | 
				
			||||||
   },
 | 
					 | 
				
			||||||
   "outputs": [],
 | 
					 | 
				
			||||||
   "source": [
 | 
					 | 
				
			||||||
    "@option(Configs.model)\n",
 | 
					 | 
				
			||||||
    "def autoregressive_model(c: Configs):\n",
 | 
					 | 
				
			||||||
    "    from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \\\n",
 | 
					 | 
				
			||||||
    "        FeedbackAttention, FeedForward\n",
 | 
					 | 
				
			||||||
    "\n",
 | 
					 | 
				
			||||||
    "    return AutoregressiveModel(\n",
 | 
					 | 
				
			||||||
    "        c.n_tokens, c.d_model,\n",
 | 
					 | 
				
			||||||
    "        FeedbackTransformer(\n",
 | 
					 | 
				
			||||||
    "            FeedbackTransformerLayer(d_model=c.d_model,\n",
 | 
					 | 
				
			||||||
    "                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout),\n",
 | 
					 | 
				
			||||||
    "                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),\n",
 | 
					 | 
				
			||||||
    "                                     dropout_prob=c.dropout),\n",
 | 
					 | 
				
			||||||
    "            c.n_layers)).to(c.device)"
 | 
					 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
@ -293,6 +187,8 @@
 | 
				
			|||||||
    "                    'prompt': 'It is',\n",
 | 
					    "                    'prompt': 'It is',\n",
 | 
				
			||||||
    "                    'prompt_separator': '',\n",
 | 
					    "                    'prompt_separator': '',\n",
 | 
				
			||||||
    "\n",
 | 
					    "\n",
 | 
				
			||||||
 | 
					    "                    'model': 'feedback_transformer',\n",
 | 
				
			||||||
 | 
					    "\n",
 | 
				
			||||||
    "                    'train_loader': 'shuffled_train_loader',\n",
 | 
					    "                    'train_loader': 'shuffled_train_loader',\n",
 | 
				
			||||||
    "                    'valid_loader': 'shuffled_valid_loader',\n",
 | 
					    "                    'valid_loader': 'shuffled_valid_loader',\n",
 | 
				
			||||||
    "\n",
 | 
					    "\n",
 | 
				
			||||||
@ -554,18 +450,18 @@
 | 
				
			|||||||
     "evalue": "ignored",
 | 
					     "evalue": "ignored",
 | 
				
			||||||
     "output_type": "error",
 | 
					     "output_type": "error",
 | 
				
			||||||
     "traceback": [
 | 
					     "traceback": [
 | 
				
			||||||
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 | 
					      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
 | 
				
			||||||
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
 | 
					      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
 | 
				
			||||||
      "\u001b[0;32m<ipython-input-11-fb962b998049>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# Start the experiment\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mexperiment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mconf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
 | 
					      "\u001B[0;32m<ipython-input-11-fb962b998049>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m()\u001B[0m\n\u001B[1;32m      1\u001B[0m \u001B[0;31m# Start the experiment\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      2\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mexperiment\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstart\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 3\u001B[0;31m     \u001B[0mconf\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    238\u001b[0m         \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    239\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_loop\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 240\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    242\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36mrun\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    238\u001B[0m         \u001B[0m_\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    239\u001B[0m         \u001B[0;32mfor\u001B[0m \u001B[0m_\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtraining_loop\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 240\u001B[0;31m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun_step\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    241\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    242\u001B[0m     \u001B[0;32mdef\u001B[0m \u001B[0msample\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36mrun_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    226\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    227\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtracker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 228\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    229\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    230\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtracker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36mrun_step\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    226\u001B[0m             \u001B[0;32mwith\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    227\u001B[0m                 \u001B[0;32mwith\u001B[0m \u001B[0mtracker\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamespace\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'train'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 228\u001B[0;31m                     \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    229\u001B[0m             \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mvalidator\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    230\u001B[0m                 \u001B[0;32mwith\u001B[0m \u001B[0mtracker\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamespace\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'valid'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    130\u001b[0m                 \u001b[0msm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    131\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_grad_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iterate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    134\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompleted\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36m__call__\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    130\u001B[0m                 \u001B[0msm\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mon_epoch_start\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    131\u001B[0m         \u001B[0;32mwith\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mset_grad_enabled\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 132\u001B[0;31m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__iterate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    133\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    134\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcompleted\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36m__iterate\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    143\u001b[0m                 \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iterable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    147\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36m__iterate\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    143\u001B[0m                 \u001B[0mbatch\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnext\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__iterable\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    144\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 145\u001B[0;31m                 \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mbatch\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    146\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    147\u001B[0m                 \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_nn/experiments/nlp_autoregression.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m     70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m             \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_nn/experiments/nlp_autoregression.py\u001B[0m in \u001B[0;36mstep\u001B[0;34m(self, batch, batch_idx)\u001B[0m\n\u001B[1;32m     70\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     71\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 72\u001B[0;31m             \u001B[0mloss\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m     73\u001B[0m             \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnn\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mutils\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mclip_grad_norm_\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmax_norm\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m1.\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     74\u001B[0m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0moptimizer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    219\u001b[0m                 \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m                 create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001B[0m in \u001B[0;36mbackward\u001B[0;34m(self, gradient, retain_graph, create_graph)\u001B[0m\n\u001B[1;32m    219\u001B[0m                 \u001B[0mretain_graph\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    220\u001B[0m                 create_graph=create_graph)\n\u001B[0;32m--> 221\u001B[0;31m         \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mautograd\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgradient\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    222\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    223\u001B[0m     \u001B[0;32mdef\u001B[0m \u001B[0mregister_hook\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mhook\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m    130\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m    131\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001B[0m in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001B[0m\n\u001B[1;32m    130\u001B[0m     Variable._execution_engine.run_backward(\n\u001B[1;32m    131\u001B[0m         \u001B[0mtensors\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgrad_tensors_\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 132\u001B[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001B[0m\u001B[1;32m    133\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    134\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/training_loop.py\u001b[0m in \u001b[0;36m__handler\u001b[0;34m(self, sig, frame)\u001b[0m\n\u001b[1;32m    162\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__finish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    163\u001b[0m             \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Killing loop...'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mText\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdanger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mold_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__str__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
					      "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/training_loop.py\u001B[0m in \u001B[0;36m__handler\u001B[0;34m(self, sig, frame)\u001B[0m\n\u001B[1;32m    162\u001B[0m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__finish\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    163\u001B[0m             \u001B[0mlogger\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'Killing loop...'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mText\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdanger\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 164\u001B[0;31m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mold_handler\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0msig\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mframe\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    165\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    166\u001B[0m     \u001B[0;32mdef\u001B[0m \u001B[0m__str__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
 | 
				
			||||||
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
 | 
					      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
 | 
				
			||||||
     ]
 | 
					     ]
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
   ],
 | 
					   ],
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										127
									
								
								labml_nn/transformers/feedback/experiment.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								labml_nn/transformers/feedback/experiment.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,127 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					title: Train Feedback Transformer
 | 
				
			||||||
 | 
					summary: This is training code with notes for a feedback transformer.
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Train Feedback Transformer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This trains a [feedback transformer](index.html) model for auto-regression.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from labml import experiment
 | 
				
			||||||
 | 
					from labml.configs import option
 | 
				
			||||||
 | 
					from labml.utils.pytorch import get_modules
 | 
				
			||||||
 | 
					from labml_helpers.module import Module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 | 
				
			||||||
 | 
					from labml_nn.transformers import Encoder, Generator, TransformerConfigs
 | 
				
			||||||
 | 
					from labml_nn.transformers.utils import subsequent_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoregressiveModel(Module):
					    """
					    ## Auto regressive model

					    Puts a token embedding in front of the given transformer module and a
					    linear projection to vocabulary-sized logits behind it.
					    """

					    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
					        """
					        * `n_vocab` is the number of tokens in the vocabulary
					        * `d_model` is the embedding size
					        * `transformer` is the module applied to the embedded tokens
					        """
					        super().__init__()
					        # Token embedding module
					        self.src_embed = nn.Embedding(n_vocab, d_model)
					        # Transformer that processes the embeddings
					        self.transformer = transformer
					        # Final layer that maps transformer outputs to token logits
					        self.generator = nn.Linear(d_model, n_vocab)

					    def __call__(self, x: torch.Tensor):
					        # Embed the tokens
					        embeddings = self.src_embed(x)
					        # Run the embeddings through the transformer
					        hidden = self.transformer(embeddings)
					        # Generate logits of the next token
					        logits = self.generator(hidden)
					        # The second element is always `None`
					        # NOTE(review): presumably a placeholder the training loop expects - confirm
					        return logits, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Configs(NLPAutoRegressionConfigs):
					    """
					    ## Configurations

					    These are default values; they can and will be over-ridden
					    when the experiment is started.
					    """

					    # The model; built by one of the `@option(Configs.model)` functions
					    model: AutoregressiveModel

					    # Number of features in the transformer (also the embedding size)
					    d_model: int = 512
					    # Number of attention heads
					    heads: int = 8
					    # Dropout probability passed to the attention, feed-forward and layer modules
					    dropout: float = 0.0
					    # Second argument of `FeedForward`
					    # NOTE(review): presumably the hidden size of the feed-forward network - confirm
					    d_ff: int = 2048
					    # Number of transformer layers
					    n_layers: int = 6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@option(Configs.model)
					def feedback_transformer(c: Configs):
					    """
					    Build an `AutoregressiveModel` around a `FeedbackTransformer`
					    and move it to the configured device.
					    """
					    from labml_nn.transformers.feedback import (
					        FeedbackTransformer, FeedbackTransformerLayer,
					        FeedbackAttention, FeedForward,
					    )

					    # Attention module for the layer
					    attn = FeedbackAttention(c.heads, c.d_model, c.dropout)
					    # Feed-forward module for the layer
					    feed_forward = FeedForward(c.d_model, c.d_ff, c.dropout)
					    # The layer handed to the transformer together with the layer count
					    layer = FeedbackTransformerLayer(d_model=c.d_model,
					                                     attn=attn,
					                                     feed_forward=feed_forward,
					                                     dropout_prob=c.dropout)
					    transformer = FeedbackTransformer(layer, c.n_layers)

					    # Wrap with embedding and logits layers, then move to the device
					    model = AutoregressiveModel(c.n_tokens, c.d_model, transformer)
					    return model.to(c.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@option(Configs.model)
					def feedback_transformer_kv(c: Configs):
					    """
					    Build an `AutoregressiveModel` around a `FeedbackTransformerKV`
					    and move it to the configured device.

					    The attention module is created with `is_kv_precomputed=True`.
					    """
					    from labml_nn.transformers.feedback import (
					        FeedbackTransformerKV, FeedbackTransformerLayer,
					        FeedbackAttention, FeedForward,
					    )

					    # Attention module flagged as receiving precomputed keys and values
					    attn = FeedbackAttention(c.heads, c.d_model, c.dropout,
					                             is_kv_precomputed=True)
					    # Feed-forward module for the layer
					    feed_forward = FeedForward(c.d_model, c.d_ff, c.dropout)
					    # The layer handed to the transformer
					    layer = FeedbackTransformerLayer(d_model=c.d_model,
					                                     attn=attn,
					                                     feed_forward=feed_forward,
					                                     dropout_prob=c.dropout)
					    # This variant additionally takes the model size and the number of heads
					    transformer = FeedbackTransformerKV(layer, c.n_layers, c.d_model, c.heads)

					    # Wrap with embedding and logits layers, then move to the device
					    model = AutoregressiveModel(c.n_tokens, c.d_model, transformer)
					    return model.to(c.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main():
					    """
					    Create the experiment, apply the configuration overrides,
					    register the models and run training.
					    """
					    # Create experiment
					    experiment.create(name="feedback_transformer")
					    # Create configs
					    conf = Configs()

					    # A dictionary of configurations to override
					    overrides = {
					        'tokenizer': 'character',
					        'text': 'tiny_shakespeare',
					        'optimizer.learning_rate': 1.0,
					        'optimizer.optimizer': 'Noam',
					        'prompt': 'It is',
					        'prompt_separator': '',

					        'model': 'feedback_transformer_kv',

					        'train_loader': 'shuffled_train_loader',
					        'valid_loader': 'shuffled_valid_loader',

					        'seq_len': 128,
					        'epochs': 128,
					        'batch_size': 64,
					        'inner_iterations': 25,
					    }
					    # Load configurations
					    experiment.configs(conf, overrides)

					    # Set models for saving and loading
					    experiment.add_pytorch_models(get_modules(conf))

					    # Start the experiment
					    with experiment.start():
					        # `TrainValidConfigs.run`
					        conf.run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Run the experiment only when this file is executed as a script
					if __name__ == '__main__':
					    main()
 | 
				
			||||||
		Reference in New Issue
	
	Block a user