diff --git a/labml_nn/transformers/feedback/experiment.ipynb b/labml_nn/transformers/feedback/experiment.ipynb new file mode 100644 index 00000000..d3f53ad3 --- /dev/null +++ b/labml_nn/transformers/feedback/experiment.ipynb @@ -0,0 +1,645 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Feedback Transformer", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AYV_dMVDxyc2" + }, + "source": [ + "[](https://github.com/lab-ml/nn)\n", + "[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb) \n", + "\n", + "## Feedback Transformer\n", + "\n", + "This is an experiment training Shakespeare dataset with Feedback Transformer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AahG_i2y5tY9" + }, + "source": [ + "Install the `labml-nn` package" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZCzmCrAIVg0L", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "aa1fe63d-1755-4394-dcdf-9897ac6c1ee4" + }, + "source": [ + "!pip install labml-nn" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting labml-nn\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/69/4d/ab1bc1578d83bae243118abe5c89bc9995d0195ee1d03960cae42ff39879/labml_nn-0.4.77-py3-none-any.whl (103kB)\n", + "\u001b[K |████████████████████████████████| 112kB 13.7MB/s \n", + "\u001b[?25hCollecting labml>=0.4.86\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n", + "\u001b[K |████████████████████████████████| 102kB 11.1MB/s \n", + "\u001b[?25hRequirement already satisfied: numpy in 
/usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.4)\n", + "Collecting einops\n", + " Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", + "Collecting labml-helpers>=0.4.72\n", + " Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n", + "Collecting pyyaml>=5.3.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", + "\u001b[K |████████████████████████████████| 276kB 40.5MB/s \n", + "\u001b[?25hCollecting gitpython\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n", + "\u001b[K |████████████████████████████████| 163kB 50.4MB/s \n", + "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n", + "Collecting gitdb<5,>=4.0.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", + "\u001b[K |████████████████████████████████| 71kB 12.7MB/s \n", + "\u001b[?25hCollecting smmap<4,>=3.0.1\n", + " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", + "Building 
wheels for collected packages: pyyaml\n", + " Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=132f39d02b291cdc60b9eff7c14b051ee0f520f790ac3e3bdaf6823e9bf7fda3\n", + " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n", + "Successfully built pyyaml\n", + "Installing collected packages: pyyaml, smmap, gitdb, gitpython, labml, einops, labml-helpers, labml-nn\n", + " Found existing installation: PyYAML 3.13\n", + " Uninstalling PyYAML-3.13:\n", + " Successfully uninstalled PyYAML-3.13\n", + "Successfully installed einops-0.3.0 gitdb-4.0.5 gitpython-3.1.12 labml-0.4.89 labml-helpers-0.4.72 labml-nn-0.4.77 pyyaml-5.3.1 smmap-3.0.4\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SE2VUQ6L5zxI" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0hJXx_g0wS2C" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from labml import experiment\n", + "from labml.configs import option\n", + "from labml_helpers.module import Module\n", + "from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pO86KIJS52eR" + }, + "source": [ + "## Autoregressive model that uses the transformer" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WQ8VGpMGwZuj" + }, + "source": [ + "class AutoregressiveModel(Module):\n", + " \"\"\"\n", + " ## Auto regressive model\n", + " \"\"\"\n", + "\n", + " def __init__(self, n_vocab: int, d_model: int, transformer: Module):\n", + " super().__init__()\n", + " # Token embedding module\n", + " self.src_embed = nn.Embedding(n_vocab, d_model)\n", + " self.transformer = transformer\n", + " self.generator = nn.Linear(d_model, n_vocab)\n", + 
"\n", + "    def __call__(self, x: torch.Tensor):\n", + "        x = self.src_embed(x)\n", + "        # Embed the tokens (`src`) and run it through the transformer\n", + "        res = self.transformer(x)\n", + "        # Generate logits of the next token\n", + "        return self.generator(res), None" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "JkoWbOdI58jg" }, "source": [ "## Configs\n", "\n", "We extend from [`NLPAutoRegressionConfigs`](https://github.com/lab-ml/nn/blob/master/labml_nn/experiments/nlp_autoregression.py) that defines base configurations, including datasets and dataloaders.\n", "\n", "The values we set here are the defaults. These can be overridden with a configs dictionary when starting the experiment." ] }, { "cell_type": "code", "metadata": { "id": "f07vAOaHwumr" }, "source": [ "class Configs(NLPAutoRegressionConfigs):\n", "    model: AutoregressiveModel\n", "\n", "    d_model: int = 512\n", "    heads: int = 8\n", "    dropout: float = 0.0\n", "    d_ff: int = 2048\n", "    n_layers: int = 6" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "IgX3Au_p6Z36" }, "source": [ "Set the function to initialize `AutoregressiveModel`. This will be called when\n", "`Configs.model` is accessed. 
" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "crH6MzKmw-SY" + }, + "source": [ + "@option(Configs.model)\n", + "def autoregressive_model(c: Configs):\n", + " from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \\\n", + " FeedbackAttention, FeedForward\n", + "\n", + " return AutoregressiveModel(\n", + " c.n_tokens, c.d_model,\n", + " FeedbackTransformer(\n", + " FeedbackTransformerLayer(d_model=c.d_model,\n", + " attn=FeedbackAttention(c.heads, c.d_model, c.dropout),\n", + " feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),\n", + " dropout_prob=c.dropout),\n", + " c.n_layers)).to(c.device)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lpggo0wM6qb-" + }, + "source": [ + "Create an experiment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bFcr9k-l4cAg" + }, + "source": [ + "experiment.create(name=\"feedback_transformer\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-OnHLi626tJt" + }, + "source": [ + "Initialize configurations" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Piz0c5f44hRo" + }, + "source": [ + "conf = Configs()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wwMzCqpD6vkL" + }, + "source": [ + "Set experiment configurations and assign a configurations dictionary to override configurations" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "e6hmQhTw4nks", + "outputId": "91d99011-7a61-48fa-ee5c-2fa845883cec" + }, + "source": [ + "experiment.configs(conf,\n", + " {'tokenizer': 'character',\n", + " 'text': 'tiny_shakespeare',\n", + " 'optimizer.learning_rate': 1.0,\n", + " 'optimizer.optimizer': 'Noam',\n", + " 'prompt': 'It is',\n", + " 'prompt_separator': '',\n", + "\n", + " 'train_loader': 
'shuffled_train_loader',\n", + " 'valid_loader': 'shuffled_valid_loader',\n", + "\n", + " 'seq_len': 64,\n", + " 'epochs': 128,\n", + " 'batch_size': 80,\n", + " 'inner_iterations': 25})" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "Prepare model...\n", + " Prepare n_tokens...\n", + " Prepare text...\n", + " Prepare tokenizer...[DONE]\t3.28ms\n", + " Download...[DONE]\t223.64ms\n", + " Load data...[DONE]\t3.95ms\n", + " Tokenize...[DONE]\t31.15ms\n", + " Build vocabulary...[DONE]\t103.48ms\n", + " Prepare text...[DONE]\t381.54ms\n", + " Prepare n_tokens...[DONE]\t388.76ms\n", + " Prepare device...\n", + " Prepare device_info...[DONE]\t68.00ms\n", + " Prepare device...[DONE]\t70.80ms\n", + "Prepare model...[DONE]\t10,879.84ms\n", + "" + ], + "text/plain": [ + "
Prepare mode...[DONE]\t1.50ms\n", + "" + ], + "text/plain": [ + "
\n", + "feedback_transformer: d8eb9416530a11eb8fb50242ac1c0002\n", + "\t[dirty]: \"\"\n", + "Initialize...[DONE]\t1.30ms\n", + "Prepare validator...\n", + " Prepare valid_loader...[DONE]\t70.14ms\n", + "\n", + "--------------------------------------------------\n", + "LABML WARNING\n", + "LabML App Warning: empty_token: Please create a valid token at https://web.lab-ml.com.\n", + "Click on the experiment link to monitor the experiment and add it to your experiments list.\n", + "--------------------------------------------------\n", + "Monitor experiment at https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002\n", + "Prepare validator...[DONE]\t168.60ms\n", + "Prepare trainer...\n", + " Prepare train_loader...[DONE]\t93.88ms\n", + "Prepare trainer...[DONE]\t125.90ms\n", + "Prepare training_loop...\n", + " Prepare loop_count...[DONE]\t35.06ms\n", + "Prepare training_loop...[DONE]\t268.53ms\n", + "It is?aMrPaDPYBnOrPWrrrrrrPBnr\n", + "It is?aMrPaDPYBnrrrrrrrrrrrrrr\n", + "It is?aM ssssssssssosososososo\n", + "It is \n", + "It is \n", + "It is \n", + "It is t t t t t t t t t t t t \n", + "It is an the the the the the t\n", + "It is anour thererererererer t\n", + "It is an the the the the thand\n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is and the the the the the \n", + "It is and the the the the the \n", + "It is and the the the the the \n", + "It is the the the the the the \n", + "It is and the the the the the \n", + "It is an the the the the the t\n", + "It is and the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "1,003,840: Sample: 100% 2,532ms Train: 100% 447,670ms Valid: 100% 19,478ms accuracy.train: 0.239323 loss.train: 2.31443 accuracy.valid: 0.250179 loss.valid: 2.39115 415,476ms 0:06m/ 14:39m \n", + "It is 
the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is the the theat the theat \n", + "It is the the the the the the \n", + "It is the the the the the the \n", + "It is and the the the the the \n", + "It is the the the the the the \n", + "It is and the the sorengeng th\n", + "It is the the the the the the \n", + "It is and the the sear the the\n", + "It is and the sheall he here h\n", + "It is and the mear the the the\n", + "It is the sould the the sould \n", + "It is a the sore the sore the \n", + "It is a the the the the the th\n", + "It is a the do the to the to t\n", + "It is a the have hath hear his\n", + "It is and the the the the the \n", + "It is a may the sore the the t\n", + "2,007,680: Sample: 100% 2,528ms Train: 100% 445,410ms Valid: 100% 19,593ms accuracy.train: 0.378694 loss.train: 1.96300 accuracy.valid: 0.367457 loss.valid: 2.02065 418,682ms 0:13m/ 14:39m \n", + "It is a the prove the prove th\n", + "It is the stake the stake the \n", + "It is the sore the sore the ha\n", + "It is a the hat her the hath h\n", + "It is the sould the son the wo\n", + "It is me so so the son the son\n", + "It is the cour to the cour to \n", + "It is the will the will the wi\n", + "It is the so the so the seen t\n", + "It is the stall the stall the \n", + "It is the world the world the \n", + "It is the sear the country the\n", + "It is the sould the say the sa\n", + "It is the send the stain the s\n", + "It is the controunter of the c\n", + "It is the rest the course,\n", + "And\n", + "It is the courter of the court\n", + "It is a sonder to the down the\n", + "It is the death of the death o\n", + "It is the sould of the see out\n", + "It is the prove to the course \n", + "It is the have to the world to\n", + "It is the prove 
the stronger t\n", + "It is the some of the common o\n", + "It is the senter of the senter\n", + "3,011,520: Sample: 100% 2,548ms Train: 100% 445,601ms Valid: 100% 19,525ms accuracy.train: 0.482339 loss.train: 1.35200 accuracy.valid: 0.451241 loss.valid: 1.79424 419,907ms 0:20m/ 14:34m \n", + "It is the see to the sent to t\n", + "It is the stand of the stand o\n", + "It is not the reath,\n", + "That when\n", + "It is the come to the come to \n", + "It is the stand of the senters\n", + "It is the stand of the stand o\n", + "It is the perpereter of the pe\n", + "It is the common of the courte\n", + "It is a stand the world of the\n", + "It is the see of the world of \n", + "It is the sent of the peace of\n", + "It is the stand of the country\n", + "It is the sentreate the stand \n", + "It is the senter of the senter\n", + "It is the country of the sent \n", + "It is the stand the stand of t\n", + "It is not the day the day the \n", + "It is the man of the more of t\n", + "It is the counterness of the c\n", + "It is the lord of the prince o\n", + "It is the foul of the son,\n", + "And\n", + "It is the stree of the days of\n", + "It is the fair of the son of t\n", + "It is the death of the thing t\n", + "It is the death,\n", + "And the that \n", + "4,015,360: Sample: 100% 2,574ms Train: 100% 445,980ms Valid: 100% 19,625ms accuracy.train: 0.532161 loss.train: 1.32791 accuracy.valid: 0.493120 loss.valid: 1.64298 420,841ms 0:28m/ 14:29m \n", + "It is the stronger of the seat\n", + "It is the prince of the prince\n", + "It is the course of the cousin\n", + "It is not so see him and the p\n", + "It is the common of the comman\n", + "It is the state of the time of\n", + "It is the fair of the fairest \n", + "It is not the world of the peo\n", + "It is the senate of the world,\n", + "It is a man and dear the death\n", + "It is the state of the earth,\n", + "\n", + "It is the like of the world wi\n", + "It is the more of the people,\n", + "\n", + "It is the law of son,\n", + 
"And so t\n", + "It is the more of the common t\n", + "It is the dearer of the perpos\n", + "It is the word of the should b\n", + "It is the last of the world of\n", + "It is the world of the prince \n", + "It is the sentence of the cour\n", + "It is the common of the common\n", + "It is not the world of men and\n", + "It is not the hand of the hous\n", + "4,931,840: Sample: 100% 2,593ms Train: 90% 464,871ms Valid: 86% 12,860ms accuracy.train: 0.555951 loss.train: 1.44885 accuracy.valid: 0.512315 loss.valid: 1.59942 420,841ms 0:34m/ 14:23m \n", + "Saving model...\n", + "Killing loop...\n", + "Still updating LabML App, please wait for it to complete...\n", + "Updating App. Please wait..." + ], + "text/plain": [ + "