diff --git a/labml_nn/hypernetworks/experiment.py b/labml_nn/hypernetworks/experiment.py
index eb401ca7..0b3e7c1b 100644
--- a/labml_nn/hypernetworks/experiment.py
+++ b/labml_nn/hypernetworks/experiment.py
@@ -92,7 +92,6 @@ def main():
 
     # Set models for saving and loading
     experiment.add_pytorch_models(get_modules(conf))
-    conf.init()
     # Start the experiment
     with experiment.start():
         # `TrainValidConfigs.run`
diff --git a/labml_nn/transformers/switch/__init__.py b/labml_nn/transformers/switch/__init__.py
index 4ed76321..75ef9e4a 100644
--- a/labml_nn/transformers/switch/__init__.py
+++ b/labml_nn/transformers/switch/__init__.py
@@ -32,8 +32,8 @@ discusses dropping tokens when routing is not balanced.
 Here's [the training code](experiment.html) and a notebook for training a switch transformer
 on Tiny Shakespeare dataset.
 
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
-[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/switch/experiment.ipynb)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=c4656c605b9311eba13d0242ac1c0002)
 """
 
 import torch
diff --git a/labml_nn/transformers/switch/experiment.ipynb b/labml_nn/transformers/switch/experiment.ipynb
new file mode 100644
index 00000000..095ea783
--- /dev/null
+++ b/labml_nn/transformers/switch/experiment.ipynb
@@ -0,0 +1,711 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "Switch Transformer",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "AYV_dMVDxyc2"
+      },
+      "source": [
+        "[![Github](https://img.shields.io/github/stars/lab-ml/nn?style=social)](https://github.com/lab-ml/nn)\n",
+        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/switch/experiment.ipynb) \n",
+        "\n",
+        "## Switch Transformer\n",
+        "\n",
+        "This is an experiment training a small Switch Transformer on the Tiny Shakespeare dataset."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "AahG_i2y5tY9"
+      },
+      "source": [
+        "Install the `labml-nn` package"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "ZCzmCrAIVg0L",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "41bb262e-d7e4-4dd9-cf8c-b2a1724889b7"
+      },
+      "source": [
+        "!pip install labml-nn"
+      ],
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Collecting labml-nn\n",
+            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fa/de/b8bea1493162cc4a845d043a84bce937f68751b9a463b2318672e0a46fed/labml_nn-0.4.79-py3-none-any.whl (117kB)\n",
+            "\r\u001b[K |██▉ | 10kB 19.5MB/s eta 0:00:01\r\u001b[K |█████▋ | 20kB 13.7MB/s eta 0:00:01\r\u001b[K |████████▍ | 30kB 9.9MB/s eta 0:00:01\r\u001b[K |███████████▏ | 40kB 8.5MB/s eta 0:00:01\r\u001b[K |██████████████ | 51kB 5.5MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 61kB 5.8MB/s eta 0:00:01\r\u001b[K |███████████████████▌ | 71kB 6.2MB/s eta 0:00:01\r\u001b[K |██████████████████████▎ | 81kB 6.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 92kB 6.2MB/s eta 0:00:01\r\u001b[K |███████████████████████████▉ | 102kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▋ | 112kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 122kB 5.3MB/s \n",
+            "\u001b[?25hCollecting labml-helpers>=0.4.72\n",
+            "  Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n",
+            "Collecting labml>=0.4.86\n",
+            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/cd/fb/d11117ce4b2f8bc4e592fada9b416de414e2f566b0b7c10c7304c05908f9/labml-0.4.92-py3-none-any.whl (99kB)\n",
+            "\u001b[K |████████████████████████████████| 102kB 5.9MB/s \n",
+            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.5)\n",
+            "Collecting einops\n",
+            "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
+            "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n",
+            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from labml>=0.4.86->labml-nn) (3.13)\n",
+            "Collecting gitpython\n",
+            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n",
+            "\u001b[K |████████████████████████████████| 163kB 8.5MB/s \n",
+            "\u001b[?25hRequirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n",
+            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n",
+            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
+            "Collecting gitdb<5,>=4.0.1\n",
+            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
+            "\u001b[K |████████████████████████████████| 71kB 6.8MB/s \n",
+            "\u001b[?25hCollecting smmap<4,>=3.0.1\n",
+            "  Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
+            "Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n",
+            "Successfully installed einops-0.3.0 gitdb-4.0.5 gitpython-3.1.12 labml-0.4.92 labml-helpers-0.4.72 labml-nn-0.4.79 smmap-3.0.4\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "SE2VUQ6L5zxI"
+      },
+      "source": [
+        "Imports"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "0hJXx_g0wS2C"
+      },
+      "source": [
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "\n",
+        "from labml import experiment\n",
+        "from labml.configs import option\n",
+        "from labml_helpers.module import Module\n",
+        "from labml_nn.transformers.switch.experiment import Configs"
+      ],
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "Lpggo0wM6qb-"
+      },
+      "source": [
+        "Create an experiment"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "bFcr9k-l4cAg"
+      },
+      "source": [
+        "experiment.create(name=\"switch_transformer\")"
+      ],
+      "execution_count": 3,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "-OnHLi626tJt"
+      },
+      "source": [
+        "Initialize configurations"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "Piz0c5f44hRo"
+      },
+      "source": [
+        "conf = Configs()"
+      ],
+      "execution_count": 4,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "wwMzCqpD6vkL"
+      },
+      "source": [
+        "Set experiment configurations and pass a dictionary of configurations to override"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 17
+        },
+        "id": "e6hmQhTw4nks",
+        "outputId": "0bc4e738-adc7-4003-a030-4080df882bbb"
+      },
+      "source": [
+        "experiment.configs(conf,\n",
+        "                   # A dictionary of configurations to override\n",
+        "                   {'tokenizer': 'character',\n",
+        "                    'text': 'tiny_shakespeare',\n",
+        "                    'optimizer.learning_rate': 1.,\n",
+        "                    'optimizer.optimizer': 'Noam',\n",
+        "                    'prompt': 'It is',\n",
+        "                    'prompt_separator': '',\n",
+        "\n",
+        "                    'transformer': 'switch_transformer',\n",
+        "                    'is_scale_prob': False,\n",
+        "                    'n_experts': 4,\n",
+        "\n",
+        "                    'drop_tokens': True,\n",
+        "                    'capacity_factor': 1.2,\n",
+        "\n",
+        "                    'train_loader': 'shuffled_train_loader',\n",
+        "                    'valid_loader': 'shuffled_valid_loader',\n",
+        "\n",
+        "                    'seq_len': 64,\n",
+        "                    'epochs': 128,\n",
+        "                    'batch_size': 32,\n",
+        "                    'inner_iterations': 25,\n",
+        "                    })"
+      ],
+      "execution_count": 5,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/html": [
"
+            ],
+            "text/plain": [
+              ""
+            ]
+          },
+          "metadata": {
+            "tags": []
+          }
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "EvI7MtgJ61w5"
+      },
+      "source": [
+        "Set PyTorch models for loading and saving"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 272
+        },
+        "id": "GDlt7dp-5ALt",
+        "outputId": "93e0f3b1-d0fe-4525-d9f6-9ffab9ea7f9b"
+      },
+      "source": [
+        "experiment.add_pytorch_models({'model': conf.model})"
+      ],
+      "execution_count": 6,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/html": [
+              "Prepare model...\n",
+              "  Prepare n_tokens...\n",
+              "    Prepare text...\n",
+              "      Prepare tokenizer...[DONE]\t3.00ms\n",
+              "      Download...[DONE]\t347.54ms\n",
+              "      Load data...[DONE]\t7.88ms\n",
+              "      Tokenize...[DONE]\t31.26ms\n",
+              "      Build vocabulary...[DONE]\t93.52ms\n",
+              "    Prepare text...[DONE]\t501.13ms\n",
+              "  Prepare n_tokens...[DONE]\t513.24ms\n",
+              "  Prepare transformer...[DONE]\t108.46ms\n",
+              "  Prepare device...\n",
+              "    Prepare device_info...[DONE]\t68.72ms\n",
+              "  Prepare device...[DONE]\t72.12ms\n",
+              "Prepare model...[DONE]\t10,419.47ms\n",
+              ""
+            ],
+            "text/plain": [
+              ""
+            ]
+          },
+          "metadata": {
+            "tags": []
+          }
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "KJZRf8527GxL"
+      },
+      "source": [
+        "Start the experiment and run the training loop."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 1000
+        },
+        "id": "aIAWo7Fw5DR8",
+        "outputId": "12a92c2e-d248-436b-a6f1-7cf92b5289e9"
+      },
+      "source": [
+        "# Start the experiment\n",
+        "with experiment.start():\n",
+        "    conf.run()"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/html": [
\n",
+              "switch_transformer: c4656c605b9311eba13d0242ac1c0002\n",
+              "\t[dirty]: \"\"\n",
+              "Initialize...\n",
+              "  Prepare mode...[DONE]\t5.83ms\n",
+              "Initialize...[DONE]\t91.16ms\n",
+              "Prepare validator...\n",
+              "  Prepare valid_loader...[DONE]\t104.91ms\n",
+              "\n",
+              "--------------------------------------------------\n",
+              "LABML WARNING\n",
+              "LabML App Warning: empty_token: Please create a valid token at https://web.lab-ml.com.\n",
+              "Click on the experiment link to monitor the experiment and add it to your experiments list.\n",
+              "--------------------------------------------------\n",
+              "Monitor experiment at https://web.lab-ml.com/run?uuid=c4656c605b9311eba13d0242ac1c0002\n",
+              "Prepare validator...[DONE]\t235.39ms\n",
+              "Prepare trainer...\n",
+              "  Prepare train_loader...[DONE]\t120.81ms\n",
+              "Prepare trainer...[DONE]\t164.43ms\n",
+              "Prepare training_loop...\n",
+              "  Prepare loop_count...[DONE]\t48.22ms\n",
+              "Prepare training_loop...[DONE]\t386.33ms\n",
+              "It isvkTX!oooOkT'HX!vkT'HX!vkT\n",
+              "It isvkT'HXE?L3SXEoooooOkT'HXo\n",
+              "It ise$  ;ES$ ;ES$ ;ES$    c ;\n",
+              "It is        b c  it   c c    \n",
+              "It is  the  the the t  the the\n",
+              "It is the the the the the the \n",
+              "It is we the be the the bethe \n",
+              "It is the the the the the the \n",
+              "It is the the the the the the \n",
+              "It is the the the the the the \n",
+              "It is ind and the the be the i\n",
+              "It iso ino the the the the t t\n",
+              "It iso the the the the the the\n",
+              "It is the anour the the the th\n",
+              "It is theno bou an the t the t\n",
+              "It is beno the t there inouren\n",
+              "It is ben the the inore be t a\n",
+              "It is t the ano the t the t th\n",
+              "It is bsour there therer there\n",
+              "It is thean the be t theande t\n",
+              "It is and the the the the the \n",
+              "It is ing theng theno the ang \n",
+              "It is ithe t bean the the the \n",
+              "It is is the it theno be t the\n",
+              "It is and and the ind the the \n",
+              "1,003,840:  Sample: 100%   567ms  Train: 100%    57,534ms  Valid: 100% 2,134ms   loss.train:  2.37842 accuracy.train: 0.251326 loss.valid:  2.33528 accuracy.valid: 0.255193  72,583ms  0:01m/  2:33m  \n",
+              "It is ingh the at ingour the t\n",
+              "It is ise thet o therere at th\n",
+              "It is isen on the at the the t\n",
+              "It is and theat and the the at\n",
+              "It is the the theat the t at o\n",
+              "It is it and it ithe ive inghe\n",
+              "It is and and the the at the o\n",
+              "It is and the the thee the the\n",
+              "It is ine the the the ine the \n",
+              "It is and and the is the ar in\n",
+              "It is and the the at and t and\n",
+              "It is and the the the the the \n",
+              "It is and and the wend theere \n",
+              "It is and the the athe the wer\n",
+              "It is the at and and of and th\n",
+              "It is inghe t the the steare t\n",
+              "It is and the at the at and th\n",
+              "It is and and and ind and it a\n",
+              "It is in ande to o the ine wit\n",
+              "It is an and and anither the a\n",
+              "It is and ithe in an the with \n",
+              "It is in and ifor of and and i\n",
+              "It is the so set the sof thee \n",
+              "It is and we ith wnd the were \n",
+              "It ise and the are and the are\n",
+              "2,007,680:  Sample: 100%   575ms  Train: 100%    50,720ms  Valid: 100% 2,431ms   loss.train:  1.69642 accuracy.train: 0.371176 loss.valid:  2.11274 accuracy.valid: 0.359420  69,753ms  0:02m/  2:26m  \n",
+              "It is and the ward the ware th\n",
+              "It is and the sear the mean th\n",
+              "It is and to alond the stay al\n",
+              "It is as the willl and and the\n",
+              "It is and the with a with with\n",
+              "It is and and the stand the se\n",
+              "It is a the stare in the of th\n",
+              "It is and iver the with the wi\n",
+              "It is and and an the the words\n",
+              "It is a the ivers and and in t\n",
+              "It is and and the heare seave \n",
+              "It is ame the sone the son the\n",
+              "It is and would:\n",
+              "The wethe the\n",
+              "It is an the see the so to the\n",
+              "It is and the with in the of t\n",
+              "It is a mean the see to the wi\n",
+              "It is am and ither thee would \n",
+              "It is and the with the will th\n",
+              "It is ame the part a the part \n",
+              "It is and and your with the wo\n",
+              "It is are aware the stay to th\n",
+              "It is and the words and the wo\n",
+              "It is will the the some the so\n",
+              "It is have so me to the see to\n",
+              "It is with thee steeent the se\n",
+              "3,011,520:  Sample: 100%   590ms  Train: 100%    51,173ms  Valid: 100% 2,562ms   loss.train:  1.92572 accuracy.train: 0.455776 loss.valid:  1.83140 accuracy.valid: 0.423265  69,136ms  0:03m/  2:24m  \n",
+              "It is andert and the shall and\n",
+              "It is the would the sould the \n",
+              "It is and the standes of the s\n",
+              "It is away are the have art an\n",
+              "It is am be are son the sould \n",
+              "It is the what the be of the t\n",
+              "It is the so mean the some to \n",
+              "It is are with her will the wi\n",
+              "It is and am the with and a th\n",
+              "It is the princes to the soul.\n",
+              "It is if the is a the world th\n",
+              "It is am the wire in the said \n",
+              "It is accannnot thee so tear t\n",
+              "It is and the stay and\n",
+              "That sa\n",
+              "It is them some the wore the w\n",
+              "It is a be the some the word t\n",
+              "It is a be the some the be of \n",
+              "It is thenough the seek of the\n",
+              "It is amper the with the with \n",
+              "It is a may hear her if the so\n",
+              "It is an a my and the world of\n",
+              "It is the steeps the son,\n",
+              "And \n",
+              "It is the single of the shall \n",
+              "It is if revenger the shall be\n",
+              "It is the some of the so the s\n",
+              "4,015,360:  Sample: 100%   641ms  Train: 100%    51,850ms  Valid: 100% 2,669ms   loss.train:  1.65512 accuracy.train: 0.494954 loss.valid:  1.75496 accuracy.valid: 0.455520  69,381ms  0:04m/  2:23m  \n",
+              "It is in the soulding the hand\n",
+              "It is if your hands,\n",
+              "And the w\n",
+              "It is it the worthy the such o\n",
+              "It is bettth of the deatth the\n",
+              "It is the dear the the worse t\n",
+              "It is in the the sould of the \n",
+              "It is the say shall be the wor\n",
+              "It is in the see for the world\n",
+              "It is in the soul have the sou\n",
+              "It is poid your have your be t\n",
+              "It is the have have have have \n",
+              "It is the have have seen the s\n",
+              "It is the soul half to the wor\n",
+              "It is the some the soul so sou\n",
+              "It is the will see the shall o\n",
+              "It is and but the soul so be t\n",
+              "It is the have been her so bea\n",
+              "It is enter the shall be so sh\n",
+              "It is and the so the sorrow of\n",
+              "It is the was the son the worl\n",
+              "It is not the word:\n",
+              "The will s\n",
+              "It is thee repeit the words of\n",
+              "It is an the world of him the \n",
+              "It is all be the would have th\n",
+              "It is the be into enter the se\n",
+              "5,019,200:  Sample: 100%   661ms  Train: 100%    53,431ms  Valid: 100% 2,854ms   loss.train:  1.61666 accuracy.train: 0.518118 loss.valid:  1.60673 accuracy.valid: 0.478114  70,052ms  0:05m/  2:23m  \n",
+              "It is and the seem of remain o\n",
+              "It is an her be the hand tear \n",
+              "It is and the shall be the sha\n",
+              "It is an the word the word of \n",
+              "It is all be so mean and a for\n",
+              "It is a by your to heart,\n",
+              "And \n",
+              "It is a better the stands of t\n",
+              "It is am in the shall of the r\n",
+              "It is a many hire son the sent\n",
+              "It is bettey your to be the so\n",
+              "It is an the seat of the sound\n",
+              "It is the comfort:\n",
+              "The soul sh\n",
+              "It is the stand the graving of\n",
+              "It is the tell terms the wors \n",
+              "It is an the such of the such \n",
+              "It is all the sea the still be\n",
+              "It is as the seem of the seem \n",
+              "It is a priced in the son of t\n",
+              "It is a man the faither of con\n",
+              "It is the see for the sentry t\n",
+              "It is a dead iver to the sent\n",
+              "\n",
+              "It is be words the speak of th\n",
+              "It is an the seen of the saits\n",
+              "It is not be so more the for t\n",
+              "It is and the counter to the s\n",
+              "6,023,040:  Sample: 100%   730ms  Train: 100%    56,863ms  Valid: 100% 3,057ms   loss.train:  1.40999 accuracy.train: 0.536962 loss.valid:  1.49942 accuracy.valid: 0.497049  71,344ms  0:07m/  2:25m  \n",
+              "It is my lord:\n",
+              "The shall be sh\n",
+              "It is the shall of the sound o\n",
+              "It is and the soul have should\n",
+              "It is the son the souls be som\n",
+              "It is a the world in the remor\n",
+              "It is a fare with the rest,\n",
+              "An\n",
+              "It isablel, the speak of the s\n",
+              "It istern the world of the wor\n",
+              "It is a my love,\n",
+              "Thy sound and\n",
+              "It is the dower in ease of the\n",
+              "It is the seems of the sun of \n",
+              "It is the world if your son,\n",
+              "A\n",
+              "It is and in the shall be the \n",
+              "It is and the word of your fai\n",
+              "It is a the wardon in the sena\n",
+              "It is not the father of the de\n",
+              "It is be the sent of the souls\n",
+              "It is a man the shall of the s\n",
+              "It is and with the silent of h\n",
+              "It is not so more for his most\n",
+              "It is be  affecter you to my l\n",
+              "It is an my lord.\n",
+              "\n",
+              "First Servi\n",
+              "It is a make and the sense of \n",
+              "It is a man the world of the s\n",
+              "It is a many all that you fair\n",
+              "7,026,880:  Sample: 100%   727ms  Train: 100%    55,166ms  Valid: 100% 3,062ms   loss.train:  1.37415 accuracy.train: 0.548260 loss.valid:  1.72075 accuracy.valid: 0.506108  72,188ms  0:08m/  2:25m  \n",
+              "It is the still of the siver a\n",
+              "It is a me the prove of the se\n",
+              "It is a man that the sent of t\n",
+              "It is a speak in the subjects \n",
+              "It is the shallow of your hous\n",
+              "It is the world if the souls o\n",
+              "It is a prinches and the senat\n",
+              "It is a many of the former of \n",
+              "It is an the sea wife and the \n",
+              "It is a for the serving of his\n",
+              "It is a straw of the straight \n",
+              "It is and your grace.\n",
+              "\n",
+              "LADY CA\n",
+              "It is a many of your father on\n",
+              "It is the substiness of the so\n",
+              "It is all be in the sense to t\n",
+              "It is unto the such offer of t\n",
+              "It is a man the father off the\n",
+              "It is not the world of the sta\n",
+              "It is a man the rest after of \n",
+              "It is not of the say officer\n",
+              "T\n",
+              "It is the seems of the rest of\n",
+              "It is a man the subjection of \n",
+              "It is by the shall be son to t\n",
+              "It is the plack of the poor of\n",
+              "It is any a farewell to be so \n",
+              "8,030,720:  Sample: 100%   752ms  Train: 100%    56,920ms  Valid: 100% 3,210ms   loss.train:  1.39653 accuracy.train: 0.557187 loss.valid:  1.66937 accuracy.valid: 0.513688  73,094ms  0:09m/  2:26m  \n",
+              "It is a  man the seal of the s\n",
+              "It is a the reason of the seas\n",
+              "It is believe the earth a man \n",
+              "It is a prity the prince of hi\n",
+              "It is the part of heaven the s\n",
+              "It is not soul a prity to the \n",
+              "It is there: thy house that sh\n",
+              "It is a the better the world o\n",
+              "It is an the seems and the sep\n",
+              "It is the sall strike of the s\n",
+              "It is now the fair of his and \n",
+              "It is a man of your tongue to \n",
+              "It is the proper of the seast \n",
+              "It is a with hire if resolved \n",
+              "It is a man that of the royal \n",
+              "It is not the some offend the \n",
+              "It is the world in the sun of \n",
+              "It is a prince shall be seen t\n",
+              "It is a quent, and the seems o\n",
+              "It is an the rest of the resti\n",
+              "It is a the rest of the rest.\n",
+              "\n",
+              "It is an the world of your han\n",
+              "It is the world in the reasons\n",
+              "It is a  prest as a fair there\n",
+              "It is not one that be the stan\n",
+              "9,034,560:  Sample: 100%   773ms  Train: 100%    56,862ms  Valid: 100% 3,326ms   loss.train:  1.49838 accuracy.train: 0.564164 loss.valid:  1.70500 accuracy.valid: 0.517500  73,952ms  0:10m/  2:26m  \n",
+              "It is a man the prison,\n",
+              "And th\n",
+              "It is undertain the straight o\n",
+              "It is an that you said some so\n",
+              "It is loved for his father of \n",
+              "It is a many accoursed to me t\n",
+              "It is and the with the subject\n",
+              "It is the sun any the seat of \n",
+              "It is us the world of the worl\n",
+              "It is not the son,\n",
+              "And the sen\n",
+              "It is the regal of the seas of\n",
+              "It islent the seal of the seas\n",
+              "It is the warrant the power of\n",
+              "It is a sped to the state of t\n",
+              "It is an the sad makes of the \n",
+              "It is now the some fortune of \n",
+              "It is the constrant of the sea\n",
+              "It is a may ready be so for th\n",
+              "It is and the best of expedity\n",
+              "It is there is the rest:\n",
+              "The s\n",
+              "It is a part in the world in t\n",
+              "It is a that your stay in the \n",
+              "It is the warlize son,\n",
+              "And wha\n",
+              "It is a man time that you have\n",
+              "It isle there in the substanti\n",
+              "It is a stand to make her soul\n",
+              "10,038,400:  Sample: 100%   812ms  Train: 100%    58,431ms  Valid: 100% 3,565ms   loss.train:  1.32136 accuracy.train: 0.569829 loss.valid:  1.57321 accuracy.valid: 0.522639  74,977ms  0:12m/  2:27m  \n",
+              "It is not report the son,\n",
+              "And \n",
+              "It is unto the ssuch and the s\n",
+              "It is a many a thoughts of the\n",
+              "It is not the world of our hea\n",
+              "It is not the stir,\n",
+              "The still \n",
+              "It is a words the county.\n",
+              "\n",
+              "Clo\n",
+              "It is a done that you govern'd\n",
+              "It is an quiter and the state,\n",
+              "It is and my son of the sun an\n",
+              "It is and the restigninesss of\n",
+              "It is a bear your son:\n",
+              "The sen\n",
+              "It is an the range of the stat\n",
+              "It is an the and the faar if r\n",
+              "It is an the stand for the rea\n",
+              "It is and ended so much a powe\n",
+              "It is a man one that the world\n",
+              "It is the wars of the son and \n",
+              "It is the world of the world.\n",
+              "\n",
+              "It is thalt youth sent the sea\n",
+              "It is note the same of the sou\n",
+              "It is a my soul, and the senat\n",
+              "It is then the showers that wi\n",
+              "It is the worn if you have bee\n",
+              "It is the world the restign of\n",
+              "It is and all unto the rest,\n",
+              "T\n",
+              "11,042,240:  Sample: 100%   856ms  Train: 100%    58,542ms  Valid: 100% 3,513ms   loss.train:  1.29441 accuracy.train: 0.575404 loss.valid:  1.53764 accuracy.valid: 0.524586  76,085ms  0:13m/  2:28m  \n",
+              "It is an the rest,\n",
+              "And then th\n",
+              "It is an thou restraing:\n",
+              "The s\n",
+              "It is a man of off?\n",
+              "\n",
+              "PETER:\n",
+              "I \n",
+              "It is thy soul be so for the s\n",
+              "It is the acking of the sease \n",
+              "It is the world of a bark that\n",
+              "It is the world so languared t\n",
+              "It is the words of ears,\n",
+              "And t\n",
+              "It is the day of the rank of t\n",
+              "It is and the wars the strange\n",
+              "It is a man to tell the world \n",
+              "It is then the saints and the \n",
+              "It is learn the rascal of the \n",
+              "It is a prized in the old and \n",
+              "It is and the rest of reasons,\n",
+              "It is looked to the sea thou s\n",
+              "It is and the dead of our stro\n",
+              "It is a spettrance if the seal\n",
+              "It is and the sake the season\n",
+              "\n",
+              "It is the still to the salute \n",
+              "It is the way as eat your hous\n",
+              "It is the point of the world o\n",
+              "It is a part to make the senat\n",
+              "It is not the son of the son.\n",
+              "\n",
+              "It is a man of our little upon\n",
+              "12,046,080:  Sample: 100%   878ms  Train: 100%    60,251ms  Valid: 100% 3,737ms   loss.train:  1.43067 accuracy.train: 0.579984 loss.valid:  1.71077 accuracy.valid: 0.528093  77,279ms  0:15m/  2:29m  \n",
+              "It is an the achires of the se\n",
+              "It is Has is a doubt and the s\n",
+              "It is a parting to the senate.\n",
+              "It is the senter of the sea,\n",
+              "A\n",
+              "It is the wars of your head,\n",
+              "A\n",
+              "It is all the sentence of the \n",
+              "It is not the pardon of the se\n",
+              "It is a seen the speks of the \n",
+              "It is the seeks of the souls o\n",
+              "It is aweinger in the state of\n",
+              "It is a paith in the compt of \n",
+              "It is an the state of the sena\n",
+              "12,519,168:  Sample: 100%   901ms  Train:  47%    59,552ms  Valid:  43% 3,543ms   loss.train:  1.35161 accuracy.train: 0.585834 loss.valid:  1.68848 accuracy.valid: 0.530090  77,279ms  0:15m/  2:29m  
" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "oBXXlP2b7XZO" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/labml_nn/transformers/switch/experiment.py b/labml_nn/transformers/switch/experiment.py index f0a1fc17..162593e9 100644 --- a/labml_nn/transformers/switch/experiment.py +++ b/labml_nn/transformers/switch/experiment.py @@ -1,3 +1,14 @@ +""" +--- +title: Switch Transformer Experiment +summary: This experiment trains a small switch transformer on tiny Shakespeare dataset. +--- + +# Switch Transformer Experiment + +This is an annotated PyTorch experiment to train a switch transformer. +""" + import torch import torch.nn as nn @@ -17,19 +28,25 @@ class AutoregressiveModel(Module): super().__init__() # Token embedding module self.src_embed = nn.Embedding(n_vocab, d_model) + # Transformer self.transformer = transformer + # Final layer self.generator = nn.Linear(d_model, n_vocab) self.mask = None def __call__(self, x: torch.Tensor): + # Initialize the subsequent mask if self.mask is None or self.mask.size(0) != len(x): from labml_nn.transformers.utils import subsequent_mask self.mask = subsequent_mask(len(x)).to(x.device) + # Token embeddings x = self.src_embed(x) - # Embed the tokens (`src`) and run it through the the transformer + # Run it through the transformer res, counts, route_prob, n_dropped = self.transformer(x, self.mask) # Generate logits of the next token - return self.generator(res), counts, route_prob, n_dropped + res = self.generator(res) + # + return res, counts, route_prob, n_dropped class Configs(NLPAutoRegressionConfigs): @@ -42,19 +59,30 @@ class Configs(NLPAutoRegressionConfigs): model: AutoregressiveModel transformer: Module + # Token embedding size d_model: int = 128 + # Number of attention heads heads: int = 4 + # Dropout probability dropout: float = 0.0 + # Number of features in FFN hidden layer d_ff: int = 256 + # Number of transformer layers n_layers: int = 6 + # Number of experts n_experts: int = 4 + # Load balancing coefficient load_balancing_loss_ceof = 0.01 + # Whether to scale the chosen expert outputs by the routing probability is_scale_prob: bool = True + # Whether to drop tokens drop_tokens: bool = False + # Capacity factor to determine capacity of each model capacity_factor: float = 1.0 def init(self): super().init() + # Initialize tracking indicators tracker.set_scalar("lb_loss.*", False) tracker.set_scalar("route.*", False) tracker.set_scalar("dropped.*", False) @@ -74,28 +102,35 @@ class Configs(NLPAutoRegressionConfigs): # Whether to capture model outputs with self.mode.update(is_log_activations=batch_idx.is_last): # Get model outputs. - # It's returning a tuple for states when using RNNs. - # This is not implemented yet. 😜 output, counts, route_prob, n_dropped = self.model(data) - # Calculate and log loss - loss = self.loss_func(output, target) + # Calculate and cross entropy loss + cross_entropy_loss = self.loss_func(output, target) + # Total number of tokens processed, $T$, in the current batch $\mathscr{B}$ total = counts.sum(dim=-1, keepdims=True) + # Fraction of tokens routed to each expert # $$f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \unicode{x1D7D9} \{ \mathop{argmax} p(x), i \}$$ - # where $\mathscr{B}$ is the batch and $T$ is the number of tokens in the batch. # $f_i$ is the count of tokens where the argmax of $p(x)$ is equal to $i$. 
         route_frac = counts / total
+        # Mean routing probability
+        # $$P_i = \frac{1}{T} \sum_{x \in \mathscr{B}} p_i (x)$$
         route_prob = route_prob / total
+        # Load balancing loss
+        # $$\mathscr{L} = N \sum_{i=1}^N f_i \cdot P_i$$
+        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
+
+        # Track stats
         tracker.add('dropped.', total.new_tensor(n_dropped) / total)
         tracker.add('route.min.', route_frac.min())
         tracker.add('route.max.', route_frac.max())
         tracker.add('route.std.', route_frac.std())
-        # for i in range(self.n_switches):
-        #     tracker.add(f'route.{i}', route_frac[:, i].mean())
-        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
-        tracker.add("loss.", loss)
-        tracker.add("lb_loss.", loss)
-        loss = loss + self.load_balancing_loss_ceof * load_balancing_loss
+        tracker.add("loss.", cross_entropy_loss)
+        tracker.add("lb_loss.", load_balancing_loss)
+
+        # Combined loss.
+        # The load balancing loss is multiplied by a coefficient $\alpha$ which is
+        # set to something small like $\alpha = 0.01$.
+        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
 
         # Calculate and log accuracy
         self.accuracy(output, target)
@@ -121,12 +156,18 @@
 
 @option(Configs.model)
 def autoregressive_model(c: Configs):
+    """
+    ### Initialize the auto-regressive model
+    """
     m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
     return m.to(c.device)
 
 
 @option(Configs.transformer)
 def switch_transformer(c: Configs):
+    """
+    ### Initialize the switch transformer
+    """
     from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
     from labml_nn.transformers import MultiHeadAttention
 
@@ -145,6 +186,9 @@
 
 
 def main():
+    """
+    ### Run the experiment
+    """
     # Create experiment
     experiment.create(name="switch_transformer", comment='')
     # Create configs
@@ -172,17 +216,18 @@
                         'seq_len': 64,
                         'epochs': 128,
                         'batch_size': 32,
-                        'inner_iterations': 25})
+                        'inner_iterations': 25,
+                        })
 
     # Set models for saving and loading
     experiment.add_pytorch_models({'model': conf.model})
-    conf.init()
     # Start the experiment
     with experiment.start():
         # `TrainValidConfigs.run`
         conf.run()
 
 
+#
 if __name__ == '__main__':
     main()
 
diff --git a/setup.py b/setup.py
index b2a96023..343cde5e 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
 
 setuptools.setup(
     name='labml-nn',
-    version='0.4.78',
+    version='0.4.79',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",