diff --git a/labml_nn/transformers/feedback/__init__.py b/labml_nn/transformers/feedback/__init__.py
index a5ecc282..6db5e377 100644
--- a/labml_nn/transformers/feedback/__init__.py
+++ b/labml_nn/transformers/feedback/__init__.py
@@ -55,7 +55,8 @@ class FeedbackAttention(Module):
     $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
     """

-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
+                 is_kv_precomputed: bool = False):
         """
         * `heads` is the number of attention heads
         * `d_model` is the number of features in the transformer
+        * `is_kv_precomputed` is whether the key and value embeddings are already computed
@@ -70,9 +71,13 @@ class FeedbackAttention(Module):
         self.heads = heads

         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
-        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
+        # Skip the key and value projections when the memory already stores
+        # precomputed key and value embeddings (see `FeedbackTransformerKV` below)
+        if not is_kv_precomputed:
+            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
+            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
+        else:
+            self.key = None
+            self.value = None

         # Output layer
         self.output = nn.Linear(d_model, d_model)
@@ -117,7 +122,10 @@ class FeedbackAttention(Module):
         query_pos_bias = self.query_pos_bias[None, :, :]

-        # $(Q + U^Q)^\top(K_j + U^K_j)$
-        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
+        # $(Q + U^Q)^\top K_j$
+        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
+        # $Q^\top U^K_j$
+        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)
+
+        return ac + bd

     def __call__(self, *,
                  query: torch.Tensor,
@@ -132,8 +140,10 @@ class FeedbackAttention(Module):
         # `key` and `value` will then have shape `[seq_len, batch_size, heads, d_k]`
         # and `query` will have shape `[batch_size, heads, d_k]`
         query = self.query(query)
-        key = self.key(key)
-        value = self.value(value)
+        if self.key is not None:
+            key = self.key(key)
+        if self.value is not None:
+            value = self.value(value)

         # Compute attention scores
         # Results in a tensor of shape `[seq_len, batch_size, heads]`
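Compared to the single einsum it replaces, the two-term form above omits the $U^{Q\top} U^K_j$ term and avoids materializing `key + key_pos_emb`. As a sanity check, here is a minimal standalone sketch (arbitrary sizes; tensor names mirror the diff) that verifies the two terms broadcast to the expected `[seq_len, batch_size, heads]` score shape:

import torch

seq_len, batch_size, heads, d_k = 5, 2, 4, 16

query = torch.randn(batch_size, heads, d_k)          # query for the current step
key = torch.randn(seq_len, batch_size, heads, d_k)   # keys of the memory steps
key_pos_emb = torch.randn(seq_len, heads, d_k)       # relative position embeddings $U^K_j$
query_pos_bias = torch.randn(heads, d_k)             # query position bias $U^Q$

# $(Q + U^Q)^\top K_j$
ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
# $Q^\top U^K_j$
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)

assert (ac + bd).shape == (seq_len, batch_size, heads)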
@@ -190,13 +200,14 @@ class FeedbackTransformerLayer(Module):

     def __call__(self, *,
                  x: torch.Tensor,
-                 mem: Optional[torch.Tensor]):
+                 key: Optional[torch.Tensor],
+                 value: Optional[torch.Tensor]):
         # If there is memory
-        if mem is not None:
+        if key is not None:
             # Normalize the vectors before doing self attention
             z = self.norm_self_attn(x)
             # Run through self attention, where the keys and values come from the memory
-            self_attn = self.attn(query=z, key=mem, value=mem)
+            self_attn = self.attn(query=z, key=key, value=value)
             # Add the self attention results
             x = x + self.dropout(self_attn)
@@ -255,7 +266,7 @@ class FeedbackTransformer(Module):
         # Run through each layer
         for layer in self.layers:
             # Get layer output
-            x = layer(x=x, mem=mem_tensor)
+            x = layer(x=x, key=mem_tensor, value=mem_tensor)

             # Append it to the list of layer outputs
             layer_outputs.append(x)
@@ -270,3 +281,121 @@ class FeedbackTransformer(Module):
         res = torch.stack(res)
         # Normalize the output
         return self.norm(res)
+
+
+class StackFunction(torch.autograd.Function):
+    """
+    A custom autograd function for reading the memory stack.
+    The forward pass returns the first `n + 1` stored vectors; the backward
+    pass accumulates the incoming gradients in a shared buffer and hands the
+    accumulated gradient for step `n` to the live `last` vector.
+    """
+
+    @staticmethod
+    def forward(ctx, memory, memory_grad, last, n):
+        # Keep references to the shared gradient buffer and the current step
+        ctx._mem_grad = memory_grad
+        ctx._n = n
+        # Return the memory of steps `0..n`
+        return memory[:n + 1]
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        n = ctx._n
+        memory_grad = ctx._mem_grad
+        # Accumulate the gradients for steps `0..n`
+        memory_grad[:n + 1] += grad_output
+        # Only `last`, the vector appended at step `n`, gets a gradient here;
+        # earlier steps receive theirs when their own `get` outputs are back-propagated.
+        return None, None, memory_grad[n], None
+
+
+class Stack:
+    """
+    Stores one memory vector per step in a preallocated buffer, and uses
+    `StackFunction` so that gradients still flow back to the appended vectors.
+    """
+
+    def __init__(self, max_len: int):
+        self.max_len = max_len
+        self.memory = None
+        self.memory_grad = None
+        self.last = None
+        self.n = -1
+        self.last_get_n = -1
+
+    def append(self, n: int, vector: torch.Tensor):
+        # Steps must be appended in order, after the previous step has been read
+        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
+
+        with torch.no_grad():
+            # Lazily allocate the memory and gradient buffers
+            if self.memory is None or self.memory.shape[1:] != vector.shape:
+                assert n == 0
+                self.memory = vector.new_zeros(self.max_len, *vector.shape, requires_grad=False)
+                self.memory_grad = vector.new_zeros(self.memory.shape, requires_grad=False)
+            # Clear the accumulated gradients at the start of a new sequence
+            elif n == 0:
+                self.memory_grad.fill_(0.)

+        # Store a detached copy in the buffer; the live vector is kept in `self.last`
+        self.memory.data[n] = vector.detach()
+        self.n = n
+
+        self.last = vector
+
+    def get(self):
+        self.last_get_n = self.n
+        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
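To make the gradient routing concrete, here is a minimal sketch of driving `Stack` for a single step (hypothetical sizes; it assumes `Stack` is importable from this module once the diff is applied):

import torch

from labml_nn.transformers.feedback import Stack

stack = Stack(max_len=8)

v = torch.randn(2, 4, requires_grad=True)
stack.append(0, v)      # a detached copy of `v` goes into the buffer
mem = stack.get()       # steps `0..0`, shape `[1, 2, 4]`

mem.sum().backward()    # `StackFunction` routes the gradient back to `v`
assert torch.allclose(v.grad, torch.ones(2, 4))

Because only detached copies are stored, the stacked memory never has to be rebuilt with `torch.stack` at every step, while `memory_grad` still delivers each step's accumulated gradient through that step's own `get` output.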
+
+
+class FeedbackTransformerKV(Module):
+    """
+    ## Feedback Transformer (precomputed key-value) Module
+
+    This variant projects the shared memory vector to keys and values once per step
+    and stores the results on stacks, so attention does not recompute them.
+    """
+
+    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
+        """
+        * `layer` is the feedback transformer layer, which we clone for each layer
+        * `n_layers` is the number of layers in the transformer
+        * `d_model` is the number of features in the transformer
+        * `heads` is the number of attention heads
+        """
+
+        super().__init__()
+        # Make copies of the transformer layer
+        self.layers = clone_module_list(layer, n_layers)
+        # Final normalization layer
+        self.norm = nn.LayerNorm([layer.size])
+        # Memory vectors are computed as a weighted sum of representations of each layer.
+        # This is the weights parameter for that.
+        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
+        # Softmax for weights before taking the weighted sum
+        self.softmax = nn.Softmax(0)
+
+        # Projections that turn a memory vector into per-head keys and values
+        d_k = d_model // heads
+        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
+        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
+
+        # Stacks that store the precomputed keys and values for up to 512 steps
+        self.mem_key = Stack(512)
+        self.mem_value = Stack(512)
+
+    def __call__(self, x_seq: torch.Tensor):
+        """
+        * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
+        """
+
+        # Split the input to a list along the sequence axis
+        x_seq = torch.unbind(x_seq, dim=0)
+        # List to store the outputs
+        res = []
+        # For each input step
+        for step, x in enumerate(x_seq):
+            # List to store layer outputs
+            layer_outputs = [x]
+
+            # If there is memory, get the stacked keys and values
+            key_tensor = None
+            value_tensor = None
+            if step > 0:
+                key_tensor = self.mem_key.get()
+                value_tensor = self.mem_value.get()
+
+            # Run through each layer
+            for layer in self.layers:
+                # Get layer output
+                x = layer(x=x, key=key_tensor, value=value_tensor)
+                # Append it to the list of layer outputs
+                layer_outputs.append(x)
+
+            # Stack the layer outputs to a tensor
+            layer_outputs = torch.stack(layer_outputs)
+            # Calculate the memory vector as a weighted sum of layer outputs
+            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
+            # Project the memory vector to keys and values and push them onto the stacks
+            self.mem_key.append(step, self.key(mem))
+            self.mem_value.append(step, self.value(mem))
+            # Append the output to results
+            res.append(x)
+
+        # Stack the output tensors
+        res = torch.stack(res)
+        # Normalize the output
+        return self.norm(res)
diff --git a/labml_nn/transformers/feedback/experiment.ipynb b/labml_nn/transformers/feedback/experiment.ipynb
index 2beef73c..f606ff1d 100644
--- a/labml_nn/transformers/feedback/experiment.ipynb
+++ b/labml_nn/transformers/feedback/experiment.ipynb
@@ -39,33 +39,33 @@
     "output_type": "stream",
     "text": [
      "Collecting labml-nn\n",
-     "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n",
-     "\u001b[K     |████████████████████████████████| 112kB 13.7MB/s \n",
-     "\u001b[?25hCollecting labml>=0.4.86\n",
-     "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n",
-     "\u001b[K     |████████████████████████████████| 102kB 11.1MB/s \n",
-     "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n",
+     "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n",
+     "\u001B[K     |████████████████████████████████| 112kB 13.7MB/s \n",
+     "\u001B[?25hCollecting labml>=0.4.86\n",
+     "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n",
+     "\u001B[K     |████████████████████████████████| 102kB 11.1MB/s \n",
+     "\u001B[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n",
     "Collecting einops\n",
     "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
     "Collecting labml-helpers>=0.4.72\n",
     "  Downloading
https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n", "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n", "Collecting pyyaml>=5.3.1\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", - "\u001b[K |████████████████████████████████| 276kB 40.5MB/s \n", - "\u001b[?25hCollecting gitpython\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n", - "\u001b[K |████████████████████████████████| 163kB 50.4MB/s \n", - "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n", + "\u001B[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", + "\u001B[K |████████████████████████████████| 276kB 40.5MB/s \n", + "\u001B[?25hCollecting gitpython\n", + "\u001B[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n", + "\u001B[K |████████████████████████████████| 163kB 50.4MB/s \n", + "\u001B[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n", "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n", "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n", "Collecting gitdb<5,>=4.0.1\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", - "\u001b[K |████████████████████████████████| 71kB 12.7MB/s \n", - "\u001b[?25hCollecting smmap<4,>=3.0.1\n", + "\u001B[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", + "\u001B[K |████████████████████████████████| 71kB 12.7MB/s \n", + "\u001B[?25hCollecting smmap<4,>=3.0.1\n", " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", "Building wheels for collected packages: pyyaml\n", - " Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for pyyaml (setup.py) ... 
\u001B[?25l\u001B[?25hdone\n", " Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=132f39d02b291cdc60b9eff7c14b051ee0f520f790ac3e3bdaf6823e9bf7fda3\n", " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n", "Successfully built pyyaml\n", @@ -98,114 +98,8 @@ }, "outputs": [], "source": [ - "import torch\n", - "import torch.nn as nn\n", - "\n", "from labml import experiment\n", - "from labml.configs import option\n", - "from labml_helpers.module import Module\n", - "from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pO86KIJS52eR" - }, - "source": [ - "## Autoregressive model that uses the transformer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WQ8VGpMGwZuj" - }, - "outputs": [], - "source": [ - "class AutoregressiveModel(Module):\n", - " \"\"\"\n", - " ## Auto regressive model\n", - " \"\"\"\n", - "\n", - " def __init__(self, n_vocab: int, d_model: int, transformer: Module):\n", - " super().__init__()\n", - " # Token embedding module\n", - " self.src_embed = nn.Embedding(n_vocab, d_model)\n", - " self.transformer = transformer\n", - " self.generator = nn.Linear(d_model, n_vocab)\n", - "\n", - " def __call__(self, x: torch.Tensor):\n", - " x = self.src_embed(x)\n", - " # Embed the tokens (`src`) and run it through the the transformer\n", - " res = self.transformer(x)\n", - " # Generate logits of the next token\n", - " return self.generator(res), None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JkoWbOdI58jg" - }, - "source": [ - "## Configs\n", - "\n", - "We extend from [`NLPAutoRegressionConfigs`](https://github.com/lab-ml/nn/blob/master/labml_nn/experiments/nlp_autoregression.py) that defines base configurations, including datasets and dataloaders.\n", - "\n", - "The values we set here are the defaults. These can be overridden with a configs dictionary when starting the experiment." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f07vAOaHwumr" - }, - "outputs": [], - "source": [ - "class Configs(NLPAutoRegressionConfigs):\n", - " model: AutoregressiveModel\n", - "\n", - " d_model: int = 512\n", - " heads: int = 8\n", - " dropout: float = 0.0\n", - " d_ff: int = 2048\n", - " n_layers: int = 6" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IgX3Au_p6Z36" - }, - "source": [ - "Set the function to initialze `AutoregressiveModel`. This will be called when\n", - "`Configs.model` is accessed. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "crH6MzKmw-SY" - }, - "outputs": [], - "source": [ - "@option(Configs.model)\n", - "def autoregressive_model(c: Configs):\n", - " from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \\\n", - " FeedbackAttention, FeedForward\n", - "\n", - " return AutoregressiveModel(\n", - " c.n_tokens, c.d_model,\n", - " FeedbackTransformer(\n", - " FeedbackTransformerLayer(d_model=c.d_model,\n", - " attn=FeedbackAttention(c.heads, c.d_model, c.dropout),\n", - " feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),\n", - " dropout_prob=c.dropout),\n", - " c.n_layers)).to(c.device)" + "from labml_nn.transformers.feedback.experiment import Configs" ] }, { @@ -293,6 +187,8 @@ " 'prompt': 'It is',\n", " 'prompt_separator': '',\n", "\n", + " 'model': 'feedback_transformer',\n", + "\n", " 'train_loader': 'shuffled_train_loader',\n", " 'valid_loader': 'shuffled_valid_loader',\n", "\n", @@ -554,18 +450,18 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Start the experiment\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mexperiment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mconf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_loop\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 240\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36mrun_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0;32mwith\u001b[0m 
\u001b[0mtracker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 228\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 229\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtracker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0msm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_grad_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iterate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompleted\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36m__iterate\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iterable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_nn/experiments/nlp_autoregression.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 130\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 131\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/training_loop.py\u001b[0m in \u001b[0;36m__handler\u001b[0;34m(self, sig, frame)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__finish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m 
\u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Killing loop...'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mText\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdanger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mold_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__str__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", + "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m()\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;31m# Start the experiment\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 2\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mexperiment\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstart\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 3\u001B[0;31m \u001B[0mconf\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36mrun\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 238\u001B[0m \u001B[0m_\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 239\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0m_\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtraining_loop\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 240\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun_step\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 241\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 242\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0msample\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36mrun_step\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 226\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 227\u001B[0m \u001B[0;32mwith\u001B[0m 
\u001B[0mtracker\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamespace\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'train'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 228\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 229\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mvalidator\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 230\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mtracker\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamespace\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'valid'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36m__call__\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 130\u001B[0m \u001B[0msm\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mon_epoch_start\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 131\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mset_grad_enabled\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 132\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__iterate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 133\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 134\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcompleted\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36m__iterate\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 143\u001B[0m \u001B[0mbatch\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnext\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__iterable\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 144\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 145\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mbatch\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 146\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 147\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_nn/experiments/nlp_autoregression.py\u001B[0m in \u001B[0;36mstep\u001B[0;34m(self, batch, batch_idx)\u001B[0m\n\u001B[1;32m 70\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 71\u001B[0m \u001B[0;32mif\u001B[0m 
\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 72\u001B[0;31m \u001B[0mloss\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 73\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnn\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mutils\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mclip_grad_norm_\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmax_norm\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m1.\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 74\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0moptimizer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001B[0m in \u001B[0;36mbackward\u001B[0;34m(self, gradient, retain_graph, create_graph)\u001B[0m\n\u001B[1;32m 219\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 220\u001B[0m create_graph=create_graph)\n\u001B[0;32m--> 221\u001B[0;31m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mautograd\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgradient\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 222\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 223\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mregister_hook\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mhook\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001B[0m in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001B[0m\n\u001B[1;32m 130\u001B[0m Variable._execution_engine.run_backward(\n\u001B[1;32m 131\u001B[0m \u001B[0mtensors\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgrad_tensors_\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 132\u001B[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001B[0m\u001B[1;32m 133\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 134\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/training_loop.py\u001B[0m in \u001B[0;36m__handler\u001B[0;34m(self, sig, frame)\u001B[0m\n\u001B[1;32m 162\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__finish\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 163\u001B[0m 
\u001B[0mlogger\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'Killing loop...'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mText\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdanger\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 164\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mold_handler\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0msig\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mframe\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 165\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 166\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0m__str__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;31mKeyboardInterrupt\u001B[0m: " ] } ], @@ -612,4 +508,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/labml_nn/transformers/feedback/experiment.py b/labml_nn/transformers/feedback/experiment.py new file mode 100644 index 00000000..4499fb17 --- /dev/null +++ b/labml_nn/transformers/feedback/experiment.py @@ -0,0 +1,127 @@ +""" +--- +title: Train Feedback Transformer +summary: This is training code with notes for a feedback transformer. +--- + +# Train Feedback Transformer + +This trains a [feedback transformer](index.html) model for auto-regression. +""" + +import torch +from torch import nn + +from labml import experiment +from labml.configs import option +from labml.utils.pytorch import get_modules +from labml_helpers.module import Module + +from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs +from labml_nn.transformers import Encoder, Generator, TransformerConfigs +from labml_nn.transformers.utils import subsequent_mask + + +class AutoregressiveModel(Module): + """ + ## Auto regressive model + """ + + def __init__(self, n_vocab: int, d_model: int, transformer: Module): + super().__init__() + # Token embedding module + self.src_embed = nn.Embedding(n_vocab, d_model) + self.transformer = transformer + self.generator = nn.Linear(d_model, n_vocab) + + def __call__(self, x: torch.Tensor): + x = self.src_embed(x) + # Embed the tokens (`src`) and run it through the the transformer + res = self.transformer(x) + # Generate logits of the next token + return self.generator(res), None + + +class Configs(NLPAutoRegressionConfigs): + """ + ## Configurations + + The default configs can and will be over-ridden when we start the experiment + """ + + model: AutoregressiveModel + + d_model: int = 512 + heads: int = 8 + dropout: float = 0.0 + d_ff: int = 2048 + n_layers: int = 6 + + +@option(Configs.model) +def feedback_transformer(c: Configs): + from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \ + FeedbackAttention, FeedForward + + return AutoregressiveModel( + c.n_tokens, c.d_model, + FeedbackTransformer( + FeedbackTransformerLayer(d_model=c.d_model, + attn=FeedbackAttention(c.heads, c.d_model, c.dropout), + feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout), + dropout_prob=c.dropout), + c.n_layers)).to(c.device) + + +@option(Configs.model) +def feedback_transformer_kv(c: Configs): + from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \ + FeedbackAttention, FeedForward + + return AutoregressiveModel( + c.n_tokens, c.d_model, + FeedbackTransformerKV( + FeedbackTransformerLayer(d_model=c.d_model, + 
+@option(Configs.model)
+def feedback_transformer_kv(c: Configs):
+    # The updated feedback transformer, with precomputed keys and values
+    from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \
+        FeedbackAttention, FeedForward
+
+    return AutoregressiveModel(
+        c.n_tokens, c.d_model,
+        FeedbackTransformerKV(
+            FeedbackTransformerLayer(d_model=c.d_model,
+                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout,
+                                                            is_kv_precomputed=True),
+                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+                                     dropout_prob=c.dropout),
+            c.n_layers, c.d_model, c.heads)).to(c.device)
+
+
+def main():
+    # Create experiment
+    experiment.create(name="feedback_transformer")
+    # Create configs
+    conf = Configs()
+    # Load configurations
+    experiment.configs(conf,
+                       # A dictionary of configurations to override
+                       {'tokenizer': 'character',
+                        'text': 'tiny_shakespeare',
+                        'optimizer.learning_rate': 1.0,
+                        'optimizer.optimizer': 'Noam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'model': 'feedback_transformer_kv',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
+
+                        'seq_len': 128,
+                        'epochs': 128,
+                        'batch_size': 64,
+                        'inner_iterations': 25})
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models(get_modules(conf))
+
+    # Start the experiment and run the training loop (`TrainValidConfigs.run`)
+    with experiment.start():
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
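And a hypothetical smoke test (tiny sizes, not part of the PR) for the autoregressive wrapper with the original feedback transformer; it assumes both modules are importable once this diff is applied:

import torch
import torch.nn.functional as F

from labml_nn.transformers.feedback import (FeedbackAttention, FeedbackTransformer,
                                            FeedbackTransformerLayer, FeedForward)
from labml_nn.transformers.feedback.experiment import AutoregressiveModel

n_vocab, d_model, heads, d_ff = 32, 64, 4, 128

layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, 0.0),
                                 feed_forward=FeedForward(d_model, d_ff, 0.0),
                                 dropout_prob=0.0)
model = AutoregressiveModel(n_vocab, d_model, FeedbackTransformer(layer, 2))

x = torch.randint(0, n_vocab, (8, 2))     # `[seq_len, batch_size]` token ids
logits, _ = model(x)                      # logits: `[seq_len, batch_size, n_vocab]`

# Next-token loss: predict token `t + 1` from the output at `t`
loss = F.cross_entropy(logits[:-1].reshape(-1, n_vocab), x[1:].reshape(-1))
loss.backward()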