diff --git a/labml_nn/transformers/gpt/__init__.py b/labml_nn/transformers/gpt/__init__.py index 4bef688b..ade65bac 100644 --- a/labml_nn/transformers/gpt/__init__.py +++ b/labml_nn/transformers/gpt/__init__.py @@ -24,6 +24,11 @@ Main differences of this to a standard autoregressive transformer are the parameter initialization, weight decay, and learning rate schedule. For the transformer we reuse the [existing labml/nn transformer implementation](https://lab-ml.com/labml_nn/transformers/). + +Here's a notebook for training a GPT model on the Tiny Shakespeare dataset. + +[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/gpt/experiment.ipynb) +[](https://web.lab-ml.com/run?uuid=0324c6d0562111eba65d0242ac1c0002) """ import torch diff --git a/labml_nn/transformers/gpt/experiment.ipynb b/labml_nn/transformers/gpt/experiment.ipynb new file mode 100644 index 00000000..459aa10d --- /dev/null +++ b/labml_nn/transformers/gpt/experiment.ipynb @@ -0,0 +1,624 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "GPT", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AYV_dMVDxyc2" + }, + "source": [ + "[](https://github.com/lab-ml/nn)\n", + "[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/gpt/experiment.ipynb) \n", + "\n", + "## Training a model with GPT architecture\n", + "\n", + "This is an experiment training a GPT architecture model on the Tiny Shakespeare dataset." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AahG_i2y5tY9" + }, + "source": [ + "Install the `labml-nn` package" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZCzmCrAIVg0L", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "9c544df4-3fc7-4152-b50d-07919dbfb9de" + }, + "source": [ + "!pip install labml-nn" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting labml-nn\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f5/92/c454c38d613449e9cfee59809b83589bfc5463ebcf39a72126c268e31a77/labml_nn-0.4.78-py3-none-any.whl (111kB)\n", + "\u001b[K |████████████████████████████████| 112kB 8.4MB/s \n", + "\u001b[?25hCollecting labml>=0.4.86\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a7/d3/f8708934e0062e6403faa2a36d97e1677097740c94f90fd7c04ea986d7cf/labml-0.4.89-py3-none-any.whl (97kB)\n", + "\u001b[K |████████████████████████████████| 102kB 6.1MB/s \n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.5)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n", + "Collecting einops\n", + " Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", + "Collecting labml-helpers>=0.4.72\n", + " Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n", + "Collecting pyyaml>=5.3.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", + "\u001b[K |████████████████████████████████| 276kB 11.7MB/s \n", + "\u001b[?25hCollecting gitpython\n", + "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n", + "\u001b[K |████████████████████████████████| 163kB 13.8MB/s \n", + "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n", + "Collecting gitdb<5,>=4.0.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", + "\u001b[K |████████████████████████████████| 71kB 8.7MB/s \n", + "\u001b[?25hCollecting smmap<4,>=3.0.1\n", + " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", + "Building wheels for collected packages: pyyaml\n", + " Building wheel for pyyaml (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=8a5f1cee9425b47a5677f6229a8cdaea31a4c0c52473052981369bc0f865efe2\n", + " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n", + "Successfully built pyyaml\n", + "Installing collected packages: pyyaml, smmap, gitdb, gitpython, labml, einops, labml-helpers, labml-nn\n", + " Found existing installation: PyYAML 3.13\n", + " Uninstalling PyYAML-3.13:\n", + " Successfully uninstalled PyYAML-3.13\n", + "Successfully installed einops-0.3.0 gitdb-4.0.5 gitpython-3.1.12 labml-0.4.89 labml-helpers-0.4.72 labml-nn-0.4.78 pyyaml-5.3.1 smmap-3.0.4\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SE2VUQ6L5zxI" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0hJXx_g0wS2C" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from labml import experiment\n", + "from labml.configs import option\n", + "from labml_helpers.module import Module\n", + "from labml_nn.transformers.gpt import Configs" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lpggo0wM6qb-" + }, + "source": [ + "Create an experiment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bFcr9k-l4cAg" + }, + "source": [ + "experiment.create(name=\"gpt\")" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-OnHLi626tJt" + }, + "source": [ + "Initialize [GPT configurations](https://lab-ml.com/labml_nn/transformers/gpt/)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Piz0c5f44hRo" + }, + "source": [ + "conf = Configs()" + ], + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wwMzCqpD6vkL" + }, + "source": [ + "Set experiment configurations and 
assign a configurations dictionary to override configurations" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "e6hmQhTw4nks", + "outputId": "018b39d7-7d84-4651-8b33-80de77d58ace" + }, + "source": [ + "experiment.configs(conf, {\n", + " # Use character level tokenizer\n", + " 'tokenizer': 'character',\n", + " # Prompt separator is blank\n", + " 'prompt_separator': '',\n", + " # Starting prompt for sampling\n", + " 'prompt': 'It is ',\n", + " # Use Tiny Shakespeare dataset\n", + " 'text': 'tiny_shakespeare',\n", + "\n", + " # Use a context size of $128$\n", + " 'seq_len': 128,\n", + " # Train for $32$ epochs\n", + " 'epochs': 32,\n", + " # Batch size $128$\n", + " 'batch_size': 128,\n", + " # Switch between training and validation for $10$ times\n", + " # per epoch\n", + " 'inner_iterations': 10,\n", + "\n", + " # Transformer configurations\n", + " 'transformer.d_model': 512,\n", + " 'transformer.d_ff': 2048,\n", + " 'transformer.n_heads': 8,\n", + " 'transformer.n_layers': 6\n", + "})" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "Prepare model...\n", + " Prepare transformer...\n", + " Prepare n_tokens...\n", + " Prepare text...\n", + " Prepare tokenizer...[DONE]\t4.91ms\n", + " Download...[DONE]\t261.22ms\n", + " Load data...[DONE]\t7.26ms\n", + " Tokenize...[DONE]\t27.96ms\n", + " Build vocabulary...[DONE]\t137.31ms\n", + " Prepare text...[DONE]\t458.67ms\n", + " Prepare n_tokens...[DONE]\t465.05ms\n", + " Prepare transformer...[DONE]\t472.81ms\n", + " Prepare encoder...\n", + " Prepare encoder_layer...\n", + " Prepare encoder_attn...[DONE]\t46.65ms\n", + " Prepare feed_forward...\n", + " Prepare feed_forward_activation...[DONE]\t4.62ms\n", + " Prepare feed_forward...[DONE]\t25.12ms\n", + " Prepare encoder_layer...[DONE]\t115.91ms\n", + " Prepare encoder...[DONE]\t165.88ms\n", + " Prepare src_embed...[DONE]\t76.85ms\n", + " Prepare generator...[DONE]\t3.92ms\n", + " Prepare device...\n", + " Prepare device_info...[DONE]\t63.02ms\n", + " Prepare device...[DONE]\t66.20ms\n", + "Prepare model...[DONE]\t11,075.54ms\n", + "" + ], + "text/plain": [ + "
\n", + "gpt: 0324c6d0562111eba65d0242ac1c0002\n", + "\t[dirty]: \"\"\n", + "Initialize...\n", + " Prepare mode...[DONE]\t2.91ms\n", + "Initialize...[DONE]\t105.13ms\n", + "Prepare validator...\n", + " Prepare valid_loader...[DONE]\t112.67ms\n", + "Prepare validator...[DONE]\t259.72ms\n", + "\n", + "--------------------------------------------------\n", + "LABML WARNING\n", + "LabML App Warning: empty_token: Please create a valid token at https://web.lab-ml.com.\n", + "Click on the experiment link to monitor the experiment and add it to your experiments list.\n", + "--------------------------------------------------\n", + "Monitor experiment at https://web.lab-ml.com/run?uuid=0324c6d0562111eba65d0242ac1c0002\n", + "\n", + " Prepare train_loader...[DONE]\t105.04ms\n", + "Prepare trainer...[DONE]\t207.61ms\n", + "Prepare training_loop...\n", + " Prepare loop_count...[DONE]\t51.30ms\n", + "Prepare training_loop...[DONE]\t374.91ms\n", + "It is LmBsuuwsdtUUQUHHHs\n", + "\n", + "\n", + "BsYh\n", + "It is s s an s t a a a s t the \n", + "It is t the t the t t t t t the\n", + "It is the the the the the t the\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the and the and and the a\n", + "1,003,776: Sample: 100% 303ms Train: 100% 53,938ms Valid: 100% 1,973ms loss.train: 2.38163 accuracy.train: 0.250702 loss.valid: 2.36796 accuracy.valid: 0.269186 61,419ms 0:01m/ 0:31m \n", + "It is the an the the the the th\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "It is man the the the the the t\n", + "It is an the the the the the th\n", + "It is the the the the the the t\n", + "It is the the the the the the t\n", + "2,007,552: Sample: 100% 320ms Train: 100% 50,987ms Valid: 
100% 2,146ms loss.train: 2.11509 accuracy.train: 0.343499 loss.valid: 2.17901 accuracy.valid: 0.343077 59,401ms 0:01m/ 0:29m \n", + "It is the to the the the the th\n", + "It is the to the the the the th\n", + "It is and the the the the the t\n", + "It is and the thee thee the the\n", + "It is and the the the the she s\n", + "It is the so the so dee the son\n", + "It is to the the with the and t\n", + "It is and the so sould the soul\n", + "It is and the the thee strees t\n", + "It is the the with so the with \n", + "3,011,328: Sample: 100% 326ms Train: 100% 51,276ms Valid: 100% 2,203ms loss.train: 1.86443 accuracy.train: 0.415378 loss.valid: 1.99342 accuracy.valid: 0.396061 58,651ms 0:02m/ 0:28m \n", + "It is the the would the with th\n", + "It is the shall so thee the sha\n", + "It is the the the the the the t\n", + "It is the shall the the come th\n", + "It is the the his shall the sha\n", + "It is the sould the storrow the\n", + "It is the soul the word the so \n", + "It is the so so me to mee the s\n", + "It is the so the stay the shall\n", + "It is the so the so the so the \n", + "4,015,104: Sample: 100% 341ms Train: 100% 51,141ms Valid: 100% 2,202ms loss.train: 1.74658 accuracy.train: 0.464521 loss.valid: 1.86725 accuracy.valid: 0.433598 58,268ms 0:03m/ 0:27m \n", + "It is the so so he her heart th\n", + "It is the the counter the so th\n", + "It is the seep the word the wor\n", + "It is and the cannot to the can\n", + "It is a the shall the was the d\n", + "It is a say a say a so the soun\n", + "It is the shall the hands and t\n", + "It is the sounder the shall the\n", + "It is the shall be the dear to \n", + "It is a the come to the sould o\n", + "5,018,880: Sample: 100% 351ms Train: 100% 51,208ms Valid: 100% 2,213ms loss.train: 1.67815 accuracy.train: 0.499096 loss.valid: 1.76973 accuracy.valid: 0.463422 58,122ms 0:04m/ 0:26m \n", + "It is the have the be the have \n", + "It is the so so so him him him \n", + "It is the the send the will the\n", + "It is 
the soul to the the the c\n", + "It is the stand the come of the\n", + "It is the heart the so the see \n", + "It is the would the world the w\n", + "It is the served the service th\n", + "It is the shall the shout shalt\n", + "It is the should be so the for \n", + "6,022,656: Sample: 100% 354ms Train: 100% 50,986ms Valid: 100% 2,220ms loss.train: 1.54598 accuracy.train: 0.521924 loss.valid: 1.66606 accuracy.valid: 0.481630 57,965ms 0:05m/ 0:25m \n", + "It is the see the sent the to t\n", + "It is the her shall the world.\n", + "\n", + "It is a strain.\n", + "\n", + "COMINIUS:\n", + "I wi\n", + "It is the should the hand the b\n", + "It is the sent of the shall be \n", + "It is the sentery the stander t\n", + "It is the shall be the see the \n", + "It is not to the should be so s\n", + "It is the such a man the world \n", + "It is the shall be the shall be\n", + "7,026,432: Sample: 100% 360ms Train: 100% 51,237ms Valid: 100% 2,235ms loss.train: 1.46728 accuracy.train: 0.540290 loss.valid: 1.66135 accuracy.valid: 0.495802 57,923ms 0:06m/ 0:24m \n", + "It is not the son of the see th\n", + "It is then the should of the co\n", + "It is not the so so so my son\n", + "A\n", + "It is the world to the world of\n", + "It is the send the death of the\n", + "It is the present the to mad be\n", + "It is the send the to him the h\n", + "It is the profater of the sea s\n", + "It is not the shall be strike o\n", + "It is the some to the some of t\n", + "8,030,208: Sample: 100% 371ms Train: 100% 51,149ms Valid: 100% 2,215ms loss.train: 1.42535 accuracy.train: 0.554010 loss.valid: 1.63348 accuracy.valid: 0.507633 57,901ms 0:07m/ 0:23m \n", + "It is the come of the sons,\n", + "The\n", + "It is not the strike of the sou\n", + "It is not so so so make the see\n", + "It is not the stand of the stan\n", + "It is the so offend the seat of\n", + "It is the strike of the strike \n", + "It is the country the worling o\n", + "It is the strike a man and the \n", + "It is the strike a more 
of the \n", + "It is not the season of the way\n", + "9,033,984: Sample: 100% 382ms Train: 100% 51,031ms Valid: 100% 2,218ms loss.train: 1.39141 accuracy.train: 0.565689 loss.valid: 1.65321 accuracy.valid: 0.512145 57,884ms 0:08m/ 0:22m \n", + "It is not so may so more to sta\n", + "It is the soul be so man of the\n", + "It is not so so so so dead and \n", + "It is not so much a so so so so\n", + "It is the death of the comes an\n", + "It is the desire the death,\n", + "And\n", + "It is not so more that thou was\n", + "It is not so so much a some to \n", + "It is the son of the souls of t\n", + "It is not the senate of the cou\n", + "10,037,760: Sample: 100% 394ms Train: 100% 51,195ms Valid: 100% 2,231ms loss.train: 1.33575 accuracy.train: 0.575717 loss.valid: 1.61892 accuracy.valid: 0.519455 57,919ms 0:09m/ 0:21m \n", + "It is the fair of the season.\n", + "\n", + "\n", + "It is not the comple of the cou\n", + "It is not the senate of the cou\n", + "It is the straight and the worl\n", + "It is a man a more of that have\n", + "It is the man of my soul stand \n", + "It is the senators of the state\n", + "It is the content the seal of t\n", + "It is the content to the provok\n", + "It is the sense of the prince o\n", + "11,041,536: Sample: 100% 402ms Train: 100% 51,201ms Valid: 100% 2,238ms loss.train: 1.32079 accuracy.train: 0.585922 loss.valid: 1.54672 accuracy.valid: 0.521141 57,958ms 0:10m/ 0:20m \n", + "It is not to see him to him.\n", + "\n", + "L\n", + "It is not the consent to the co\n", + "It is the death of the day of t\n", + "It is the soon of the court of \n", + "It is the senators of the strea\n", + "It is the compless of the world\n", + "It is not the command of my sou\n", + "It is the senate of the state o\n", + "It is not a sorrraw to him.\n", + "\n", + "AU\n", + "It is not the sound of the cour\n", + "12,045,312: Sample: 100% 418ms Train: 100% 51,201ms Valid: 100% 2,240ms loss.train: 1.29904 accuracy.train: 0.594461 loss.valid: 1.58210 
accuracy.valid: 0.524935 58,013ms 0:11m/ 0:19m \n", + "It is not so man is my soul sou\n", + "It is the strong and the straig\n", + "It is the world the field of th\n", + "It is the death of the state,\n", + "A\n", + "It is not the sorrow of the dea\n", + "It is the strange the sea of th\n", + "It is the present of the house \n", + "It is not the son of the should\n", + "It is the contract that he will\n", + "It is not the seas of the monta\n", + "13,049,088: Sample: 100% 437ms Train: 100% 51,117ms Valid: 100% 2,237ms loss.train: 1.26697 accuracy.train: 0.603765 loss.valid: 1.57502 accuracy.valid: 0.526765 58,092ms 0:12m/ 0:18m \n", + "It is not the man of the man of\n", + "It is not the stand of the cour\n", + "It is the prisoner of the head,\n", + "It is not the singled of his so\n", + "It is a man the heavens of the \n", + "It is not the stands of the sid\n", + "It is not the senators of the c\n", + "It is the sense that the presen\n", + "It is the dead of the court of \n", + "It is not the common of the com\n", + "14,052,864: Sample: 100% 443ms Train: 100% 51,152ms Valid: 100% 2,223ms loss.train: 1.21225 accuracy.train: 0.612812 loss.valid: 1.58159 accuracy.valid: 0.528407 58,138ms 0:13m/ 0:17m \n", + "It is not to the senate of the \n", + "It is not but the duke of death\n", + "It is not the stand of the stat\n", + "It is not the season of the par\n", + "It is not the langue of the fat\n", + "It is not a subject to the coul\n", + "It is not the duke of the death\n", + "It is not the side of the state\n", + "It is not the standed to the st\n", + "It is not the duke of the duke \n", + "15,056,640: Sample: 100% 452ms Train: 100% 51,115ms Valid: 100% 2,240ms loss.train: 1.22173 accuracy.train: 0.622129 loss.valid: 1.58946 accuracy.valid: 0.527285 58,204ms 0:14m/ 0:16m \n", + "It is not so sorrried to him ha\n", + "It is not a man of the prince o\n", + "It is not to the beloned and de\n", + "It is not the sorrow of my soul\n", + "It is the seated of the death a\n", 
+ "It is not the sorrow of your ho\n", + "It is not the stand of my son,\n", + "\n", + "It is the sense of the deep of \n", + "It is not the soldiers of the h\n", + "It is the conqueror of the deat\n", + "16,060,416: Sample: 100% 457ms Train: 100% 51,418ms Valid: 100% 2,239ms loss.train: 1.17331 accuracy.train: 0.630410 loss.valid: 1.63163 accuracy.valid: 0.527393 58,299ms 0:15m/ 0:15m \n", + "It is the world and the sea of \n", + "It is not the court of the coll\n", + "It is not to the more of the fa\n", + "It is not the sore of the more \n", + "It is a sorrow in a last a grea\n", + "It is not the sun that have mad\n", + "It is not to see him to him to \n", + "It is not the sun of the deed i\n", + "It is the dead?\n", + "\n", + "DUKE OF YORK:\n", + "\n", + "It is not the soldiers of the c\n", + "17,064,192: Sample: 100% 480ms Train: 100% 51,259ms Valid: 100% 2,239ms loss.train: 1.16549 accuracy.train: 0.639204 loss.valid: 1.60157 accuracy.valid: 0.527572 58,395ms 0:16m/ 0:14m \n", + "It is the foot of all,\n", + "That I m\n", + "It is not the sea that the man \n", + "It is the deep and death death \n", + "It is not the sea,\n", + "That make th\n", + "It is not to be a power of the \n", + "17,522,944: Sample: 100% 491ms Train: 43% 52,020ms Valid: 28% 2,205ms loss.train: 1.12190 accuracy.train: 0.650947 loss.valid: 1.63582 accuracy.valid: 0.532135 58,395ms 0:16m/ 0:14m \n", + "Saving model...\n", + "Killing loop...\n", + "Still updating LabML App, please wait for it to complete..." + ], + "text/plain": [ + "
Updating App. Please wait..."
+ ],
+ "text/plain": [
+ "