feedback transformer update

Varuna Jayasiri
2021-01-29 10:28:58 +05:30
parent 322a64e313
commit 2f1918f6db
3 changed files with 299 additions and 147 deletions

View File: labml_nn/transformers/feedback/__init__.py

@@ -55,7 +55,8 @@ class FeedbackAttention(Module):
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
is_kv_precomputed: bool = False):
"""
        * `heads` is the number of attention heads
        * `d_model` is the number of features in the transformer
        * `is_kv_precomputed` is whether the key and value vectors are already multi-head projected, in which case this module skips its own key/value transformations
@@ -70,9 +71,13 @@ class FeedbackAttention(Module):
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
if not is_kv_precomputed:
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
else:
self.key = None
self.value = None
# Output layer
self.output = nn.Linear(d_model, d_model)
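A quick sketch of the two modes with illustrative sizes (not from the diff; `d_k` is `d_model // heads` as above):

# Projects raw `d_model` keys and values itself
attn = FeedbackAttention(heads=8, d_model=512)
# Expects keys and values already projected to `[seq_len, batch_size, heads, d_k]`,
# e.g. by the shared projections in `FeedbackTransformerKV` below
attn_kv = FeedbackAttention(heads=8, d_model=512, is_kv_precomputed=True)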
@@ -117,7 +122,10 @@ class FeedbackAttention(Module):
query_pos_bias = self.query_pos_bias[None, :, :]
# $(Q + U^Q)^\top(K_j + U^K_j)$
return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
        # $(Q + U^Q)^\top K_j$, the content-based term with the query positional bias
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
        # $Q^\top U^K_j$, the position-based term; `key_pos_emb` is shared across the batch
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)
        return ac + bd
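To make the shapes concrete, a small self-contained check of the two score terms (illustrative sizes, positional bias omitted):

import torch

seq_len, batch_size, heads, d_k = 5, 2, 8, 64
query = torch.randn(batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)
key_pos_emb = torch.randn(seq_len, heads, d_k)

# Content term and position term; the embeddings have no batch dimension
ac = torch.einsum('bhd,jbhd->jbh', query, key)
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)
assert (ac + bd).shape == (seq_len, batch_size, heads)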
def __call__(self, *,
query: torch.Tensor,
@@ -132,8 +140,10 @@ class FeedbackAttention(Module):
# `key` and `value` will then have shape `[seq_len, batch_size, heads, d_k]`
# and `query` will have shape `[batch_size, heads, d_k]`
query = self.query(query)
key = self.key(key)
value = self.value(value)
        # Apply the key/value projections unless they were precomputed
        if self.key is not None:
            key = self.key(key)
        if self.value is not None:
            value = self.value(value)
# Compute attention scores
# Results in a tensor of shape `[seq_len, batch_size, heads]`
@@ -190,13 +200,14 @@ class FeedbackTransformerLayer(Module):
def __call__(self, *,
x: torch.Tensor,
mem: Optional[torch.Tensor]):
key: Optional[torch.Tensor],
value: Optional[torch.Tensor]):
# If there is memory
if mem is not None:
if key is not None:
# Normalize the vectors before doing self attention
z = self.norm_self_attn(x)
# Run through self attention, i.e. keys and values are from self
self_attn = self.attn(query=z, key=mem, value=mem)
self_attn = self.attn(query=z, key=key, value=value)
# Add the self attention results
x = x + self.dropout(self_attn)
@@ -255,7 +266,7 @@ class FeedbackTransformer(Module):
# Run through each layer
for layer in self.layers:
# Get layer output
x = layer(x=x, mem=mem_tensor)
x = layer(x=x, key=mem_tensor, value=mem_tensor)
# Append them to the list of layer outputs
layer_outputs.append(x)
@@ -270,3 +281,121 @@ class FeedbackTransformer(Module):
res = torch.stack(res)
# Normalize the output
return self.norm(res)
class StackFunction(torch.autograd.Function):
    """
    Custom autograd function that returns a view of the pre-allocated memory
    stack and replays gradients to the stored vectors through `last`,
    the most recently appended (still graph-attached) vector.
    """
    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):
        # Save the shared gradient buffer and current length for the backward pass
        ctx._mem_grad = memory_grad
        ctx._n = n
        # Return the memory up to and including step `n`
        return memory[:n + 1]

    @staticmethod
    def backward(ctx, grad_output):
        n = ctx._n
        memory_grad = ctx._mem_grad
        # Accumulate the incoming gradients for steps `0..n` in the shared buffer
        memory_grad[:n + 1] += grad_output
        # Only `last` is attached to the graph; since backward runs in reverse step
        # order, `memory_grad[n]` already includes contributions from later steps
        return None, None, memory_grad[n], None
class Stack:
    """
    Pre-allocated stack of memory vectors, kept outside the autograd graph.
    Gradients are routed back to the stored vectors by `StackFunction`.
    """
    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1

    def append(self, n: int, vector: torch.Tensor):
        # Steps must be appended in order, after the previous step has been read
        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
        with torch.no_grad():
            # Allocate the memory and gradient buffers on first use, or if the shape changed
            if self.memory is None or self.memory.shape[1:] != vector.shape:
                assert n == 0
                self.memory = vector.new_zeros(self.max_len, *vector.shape, requires_grad=False)
                self.memory_grad = vector.new_zeros(self.memory.shape, requires_grad=False)
            # Clear accumulated gradients at the start of a new sequence
            elif n == 0:
                self.memory_grad.fill_(0.)
            # Store the vector detached from the graph; i.e. `memory[n] = vector.detach()`
            self.memory.data[n] = vector.detach()
        self.n = n
        # Keep the graph-attached vector so `StackFunction` can route gradients through it
        self.last = vector

    def get(self):
        self.last_get_n = self.n
        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
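A standalone sanity check of the gradient replay (a sketch, assuming the two classes above are in scope; it relies on a single `backward` per sequence, where autograd runs the later `StackFunction` nodes first):

import torch

stack = Stack(max_len=4)
xs = [torch.randn(3, requires_grad=True) for _ in range(3)]

loss = 0.
for step, x in enumerate(xs):
    # Store a graph-attached vector, then read the whole stack
    stack.append(step, x * 2.)
    mem = stack.get()  # shape `[step + 1, 3]`
    loss = loss + mem.sum()
loss.backward()

# `xs[0]` is read at steps 0, 1 and 2, so each element gets gradient 2 * 3 = 6
print(xs[0].grad)  # tensor([6., 6., 6.])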
class FeedbackTransformerKV(Module):
"""
## Feedback Transformer Module
"""
def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
"""
* `layer` is the feedback transformer layer, which we clone for each layer
        * `n_layers` is the number of layers in the transformer
        * `d_model` is the number of features in the transformer
        * `heads` is the number of attention heads for the shared key and value projections of the memory
"""
super().__init__()
# Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
# Memory vectors are computed as a weighted sum of representations of each layer.
# This is the weights parameter for that.
self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
# Softmax for weights before taking the weighted sum
self.softmax = nn.Softmax(0)
        d_k = d_model // heads
        # Shared key and value projections applied to the memory vectors
        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
        # Pre-allocated stacks for the projected memory keys and values (up to 512 steps)
        self.mem_key = Stack(512)
        self.mem_value = Stack(512)
def __call__(self, x_seq: torch.Tensor):
"""
* `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
"""
        # Split the input into a list along the sequence axis
x_seq = torch.unbind(x_seq, dim=0)
# List to store the outputs
res = []
# For each input step
for step, x in enumerate(x_seq):
# List to store layer outputs
layer_outputs = [x]
            # If there is memory, get the stacked keys and values
key_tensor = None
value_tensor = None
if step > 0:
key_tensor = self.mem_key.get()
value_tensor = self.mem_value.get()
# Run through each layer
for layer in self.layers:
# Get layer output
x = layer(x=x, key=key_tensor, value=value_tensor)
# Append them to the list of layer outputs
layer_outputs.append(x)
# Stack the layer outputs to a tensor
layer_outputs = torch.stack(layer_outputs)
            # Calculate the memory vector as a weighted sum of layer outputs
            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
            # Project the memory to keys and values once, and push onto the stacks
            self.mem_key.append(step, self.key(mem))
            self.mem_value.append(step, self.value(mem))
# Append the output to results
res.append(x)
# Stack the output tensors
res = torch.stack(res)
# Normalize the output
return self.norm(res)
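For orientation, a hypothetical end-to-end call; the constructor arguments mirror the `feedback_transformer_kv` option in the experiment file, and the sequence length must stay within the `Stack(512)` limit:

layer = FeedbackTransformerLayer(d_model=512,
                                 attn=FeedbackAttention(8, 512, 0.1, is_kv_precomputed=True),
                                 feed_forward=FeedForward(512, 2048, 0.1),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers=6, d_model=512, heads=8)

x = torch.randn(16, 4, 512)  # `[seq_len, batch_size, d_model]`
out = model(x)               # `[seq_len, batch_size, d_model]`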

View File: labml_nn/transformers/feedback/experiment.ipynb

@@ -39,33 +39,33 @@
"output_type": "stream",
"text": [
"Collecting labml-nn\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n",
"\u001b[K |████████████████████████████████| 112kB 13.7MB/s \n",
"\u001b[?25hCollecting labml>=0.4.86\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n",
"\u001b[K |████████████████████████████████| 102kB 11.1MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n",
"\u001B[?25l Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n",
"\u001B[K |████████████████████████████████| 112kB 13.7MB/s \n",
"\u001B[?25hCollecting labml>=0.4.86\n",
"\u001B[?25l Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n",
"\u001B[K |████████████████████████████████| 102kB 11.1MB/s \n",
"\u001B[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n",
"Collecting einops\n",
" Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
"Collecting labml-helpers>=0.4.72\n",
" Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n",
"Collecting pyyaml>=5.3.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n",
"\u001b[K |████████████████████████████████| 276kB 40.5MB/s \n",
"\u001b[?25hCollecting gitpython\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n",
"\u001b[K |████████████████████████████████| 163kB 50.4MB/s \n",
"\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
"\u001B[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n",
"\u001B[K |████████████████████████████████| 276kB 40.5MB/s \n",
"\u001B[?25hCollecting gitpython\n",
"\u001B[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n",
"\u001B[K |████████████████████████████████| 163kB 50.4MB/s \n",
"\u001B[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n",
"Collecting gitdb<5,>=4.0.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
"\u001b[K |████████████████████████████████| 71kB 12.7MB/s \n",
"\u001b[?25hCollecting smmap<4,>=3.0.1\n",
"\u001B[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
"\u001B[K |████████████████████████████████| 71kB 12.7MB/s \n",
"\u001B[?25hCollecting smmap<4,>=3.0.1\n",
" Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
"Building wheels for collected packages: pyyaml\n",
" Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for pyyaml (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
" Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=132f39d02b291cdc60b9eff7c14b051ee0f520f790ac3e3bdaf6823e9bf7fda3\n",
" Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n",
"Successfully built pyyaml\n",
@@ -98,114 +98,8 @@
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from labml import experiment\n",
"from labml.configs import option\n",
"from labml_helpers.module import Module\n",
"from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pO86KIJS52eR"
},
"source": [
"## Autoregressive model that uses the transformer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WQ8VGpMGwZuj"
},
"outputs": [],
"source": [
"class AutoregressiveModel(Module):\n",
" \"\"\"\n",
" ## Auto regressive model\n",
" \"\"\"\n",
"\n",
" def __init__(self, n_vocab: int, d_model: int, transformer: Module):\n",
" super().__init__()\n",
" # Token embedding module\n",
" self.src_embed = nn.Embedding(n_vocab, d_model)\n",
" self.transformer = transformer\n",
" self.generator = nn.Linear(d_model, n_vocab)\n",
"\n",
" def __call__(self, x: torch.Tensor):\n",
" x = self.src_embed(x)\n",
" # Embed the tokens (`src`) and run it through the the transformer\n",
" res = self.transformer(x)\n",
" # Generate logits of the next token\n",
" return self.generator(res), None"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JkoWbOdI58jg"
},
"source": [
"## Configs\n",
"\n",
"We extend from [`NLPAutoRegressionConfigs`](https://github.com/lab-ml/nn/blob/master/labml_nn/experiments/nlp_autoregression.py) that defines base configurations, including datasets and dataloaders.\n",
"\n",
"The values we set here are the defaults. These can be overridden with a configs dictionary when starting the experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f07vAOaHwumr"
},
"outputs": [],
"source": [
"class Configs(NLPAutoRegressionConfigs):\n",
" model: AutoregressiveModel\n",
"\n",
" d_model: int = 512\n",
" heads: int = 8\n",
" dropout: float = 0.0\n",
" d_ff: int = 2048\n",
" n_layers: int = 6"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IgX3Au_p6Z36"
},
"source": [
"Set the function to initialze `AutoregressiveModel`. This will be called when\n",
"`Configs.model` is accessed. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "crH6MzKmw-SY"
},
"outputs": [],
"source": [
"@option(Configs.model)\n",
"def autoregressive_model(c: Configs):\n",
" from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \\\n",
" FeedbackAttention, FeedForward\n",
"\n",
" return AutoregressiveModel(\n",
" c.n_tokens, c.d_model,\n",
" FeedbackTransformer(\n",
" FeedbackTransformerLayer(d_model=c.d_model,\n",
" attn=FeedbackAttention(c.heads, c.d_model, c.dropout),\n",
" feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),\n",
" dropout_prob=c.dropout),\n",
" c.n_layers)).to(c.device)"
"from labml_nn.transformers.feedback.experiment import Configs"
]
},
{
@@ -293,6 +187,8 @@
" 'prompt': 'It is',\n",
" 'prompt_separator': '',\n",
"\n",
" 'model': 'feedback_transformer',\n",
"\n",
" 'train_loader': 'shuffled_train_loader',\n",
" 'valid_loader': 'shuffled_valid_loader',\n",
"\n",
@@ -554,18 +450,18 @@
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-11-fb962b998049>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Start the experiment\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mexperiment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mconf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_loop\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 240\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36mrun_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtracker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 228\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 229\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtracker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0msm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_grad_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iterate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompleted\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001b[0m in \u001b[0;36m__iterate\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iterable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_nn/experiments/nlp_autoregression.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_train\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 130\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 131\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/training_loop.py\u001b[0m in \u001b[0;36m__handler\u001b[0;34m(self, sig, frame)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__finish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Killing loop...'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mText\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdanger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mold_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__str__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"\u001B[0;32m<ipython-input-11-fb962b998049>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m()\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;31m# Start the experiment\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 2\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mexperiment\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstart\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 3\u001B[0;31m \u001B[0mconf\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36mrun\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 238\u001B[0m \u001B[0m_\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 239\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0m_\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtraining_loop\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 240\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun_step\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 241\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 242\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0msample\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36mrun_step\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 226\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 227\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mtracker\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamespace\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'train'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 228\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 229\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mvalidator\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 230\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mtracker\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamespace\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'valid'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36m__call__\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 130\u001B[0m \u001B[0msm\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mon_epoch_start\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 131\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mset_grad_enabled\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 132\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__iterate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 133\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 134\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcompleted\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/train_valid.py\u001B[0m in \u001B[0;36m__iterate\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 143\u001B[0m \u001B[0mbatch\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnext\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__iterable\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 144\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 145\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mbatch\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 146\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 147\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_batch_index\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_nn/experiments/nlp_autoregression.py\u001B[0m in \u001B[0;36mstep\u001B[0;34m(self, batch, batch_idx)\u001B[0m\n\u001B[1;32m 70\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 71\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mis_train\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 72\u001B[0;31m \u001B[0mloss\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 73\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnn\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mutils\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mclip_grad_norm_\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmax_norm\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m1.\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 74\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0moptimizer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001B[0m in \u001B[0;36mbackward\u001B[0;34m(self, gradient, retain_graph, create_graph)\u001B[0m\n\u001B[1;32m 219\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 220\u001B[0m create_graph=create_graph)\n\u001B[0;32m--> 221\u001B[0;31m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mautograd\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgradient\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 222\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 223\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mregister_hook\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mhook\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001B[0m in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001B[0m\n\u001B[1;32m 130\u001B[0m Variable._execution_engine.run_backward(\n\u001B[1;32m 131\u001B[0m \u001B[0mtensors\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgrad_tensors_\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mretain_graph\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcreate_graph\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 132\u001B[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001B[0m\u001B[1;32m 133\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 134\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/usr/local/lib/python3.6/dist-packages/labml_helpers/training_loop.py\u001B[0m in \u001B[0;36m__handler\u001B[0;34m(self, sig, frame)\u001B[0m\n\u001B[1;32m 162\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__finish\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 163\u001B[0m \u001B[0mlogger\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'Killing loop...'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mText\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdanger\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 164\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mold_handler\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0msig\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mframe\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 165\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 166\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0m__str__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
]
}
],
@@ -612,4 +508,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File: labml_nn/transformers/feedback/experiment.py

@@ -0,0 +1,127 @@
"""
---
title: Train Feedback Transformer
summary: This is training code with notes for a feedback transformer.
---
# Train Feedback Transformer
This trains a [feedback transformer](index.html) model for auto-regression.
"""
import torch
from torch import nn
from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class AutoregressiveModel(Module):
"""
## Autoregressive model
"""
def __init__(self, n_vocab: int, d_model: int, transformer: Module):
super().__init__()
# Token embedding module
self.src_embed = nn.Embedding(n_vocab, d_model)
self.transformer = transformer
self.generator = nn.Linear(d_model, n_vocab)
def __call__(self, x: torch.Tensor):
x = self.src_embed(x)
# Embed the tokens (`src`) and run them through the transformer
res = self.transformer(x)
# Generate logits of the next token
return self.generator(res), None
class Configs(NLPAutoRegressionConfigs):
"""
## Configurations
The default configurations can be overridden when we start the experiment.
"""
model: AutoregressiveModel
d_model: int = 512
heads: int = 8
dropout: float = 0.0
d_ff: int = 2048
n_layers: int = 6
@option(Configs.model)
def feedback_transformer(c: Configs):
from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
FeedbackAttention, FeedForward
return AutoregressiveModel(
c.n_tokens, c.d_model,
FeedbackTransformer(
FeedbackTransformerLayer(d_model=c.d_model,
attn=FeedbackAttention(c.heads, c.d_model, c.dropout),
feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
dropout_prob=c.dropout),
c.n_layers)).to(c.device)
@option(Configs.model)
def feedback_transformer_kv(c: Configs):
from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \
FeedbackAttention, FeedForward
return AutoregressiveModel(
c.n_tokens, c.d_model,
FeedbackTransformerKV(
FeedbackTransformerLayer(d_model=c.d_model,
attn=FeedbackAttention(c.heads, c.d_model, c.dropout,
is_kv_precomputed=True),
feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
dropout_prob=c.dropout),
c.n_layers, c.d_model, c.heads)).to(c.device)
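Either variant can then be picked by name when loading configurations, as `main()` below does; for example:

experiment.configs(conf, {'model': 'feedback_transformer'})     # keys/values recomputed by each layer
experiment.configs(conf, {'model': 'feedback_transformer_kv'})  # shared, precomputed key/value memory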
def main():
# Create experiment
experiment.create(name="feedback_transformer")
# Create configs
conf = Configs()
# Load configurations
experiment.configs(conf,
# A dictionary of configurations to override
{'tokenizer': 'character',
'text': 'tiny_shakespeare',
'optimizer.learning_rate': 1.0,
'optimizer.optimizer': 'Noam',
'prompt': 'It is',
'prompt_separator': '',
'model': 'feedback_transformer_kv',
'train_loader': 'shuffled_train_loader',
'valid_loader': 'shuffled_valid_loader',
'seq_len': 128,
'epochs': 128,
'batch_size': 64,
'inner_iterations': 25})
# Set models for saving and loading
experiment.add_pytorch_models(get_modules(conf))
# Start the experiment
with experiment.start():
# `TrainValidConfigs.run`
conf.run()
if __name__ == '__main__':
main()