mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	🧪 colab experiment
This commit is contained in:
		| @ -29,7 +29,7 @@ Annotated implementation of relative multi-headed attention is in [`relative_mha | ||||
| Here's [the training code](experiment.html) and a notebook for training a transformer XL model on Tiny Shakespeare dataset. | ||||
|  | ||||
| [](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/xl/experiment.ipynb) | ||||
| [](https://web.lab-ml.com/run?uuid=c4656c605b9311eba13d0242ac1c0002) | ||||
| [](https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002) | ||||
|  | ||||
| """ | ||||
|  | ||||
|  | ||||
							
								
								
									
										347
									
								
								labml_nn/transformers/xl/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										347
									
								
								labml_nn/transformers/xl/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,347 @@ | ||||
| { | ||||
|   "nbformat": 4, | ||||
|   "nbformat_minor": 0, | ||||
|   "metadata": { | ||||
|     "colab": { | ||||
|       "name": "Switch Transformer", | ||||
|       "provenance": [], | ||||
|       "collapsed_sections": [] | ||||
|     }, | ||||
|     "kernelspec": { | ||||
|       "name": "python3", | ||||
|       "display_name": "Python 3" | ||||
|     }, | ||||
|     "accelerator": "GPU" | ||||
|   }, | ||||
|   "cells": [ | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "AYV_dMVDxyc2" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "[](https://github.com/lab-ml/nn)\n", | ||||
|         "[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/xl/experiment.ipynb)                    \n", | ||||
|         "\n", | ||||
|         "## Transformer XL\n", | ||||
|         "\n", | ||||
|         "This is an experiment training Shakespeare dataset with a Transformer XL model." | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "AahG_i2y5tY9" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Install the `labml-nn` package" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "ZCzmCrAIVg0L", | ||||
|         "colab": { | ||||
|           "base_uri": "https://localhost:8080/" | ||||
|         }, | ||||
|         "outputId": "1a4a59ce-b300-4d9f-baee-15720d696773" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "!pip install labml-nn" | ||||
|       ], | ||||
|       "execution_count": 1, | ||||
|       "outputs": [ | ||||
|         { | ||||
|           "output_type": "stream", | ||||
|           "text": [ | ||||
|             "Collecting labml-nn\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/97/eb/122e36372ca246eced93fbe380730c660cde63ac26dcb6440610bc8709dc/labml_nn-0.4.86-py3-none-any.whl (145kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 153kB 17.4MB/s \n", | ||||
|             "\u001b[?25hCollecting labml>=0.4.97\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/c6/e5/4f297474555a3e80b01c02861bdb90c7cd088235553563e8f3fb790bc5ee/labml-0.4.99-py3-none-any.whl (101kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 102kB 12.9MB/s \n", | ||||
|             "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.5)\n", | ||||
|             "Collecting labml-helpers>=0.4.74\n", | ||||
|             "  Downloading https://files.pythonhosted.org/packages/71/40/c0b73ed57edbb05d98b5abbdc82d852e5efe5739898215274b423b97bddc/labml_helpers-0.4.74-py3-none-any.whl\n", | ||||
|             "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n", | ||||
|             "Collecting einops\n", | ||||
|             "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", | ||||
|             "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from labml>=0.4.97->labml-nn) (3.13)\n", | ||||
|             "Collecting gitpython\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 163kB 58.7MB/s \n", | ||||
|             "\u001b[?25hRequirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n", | ||||
|             "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n", | ||||
|             "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n", | ||||
|             "Collecting gitdb<5,>=4.0.1\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 71kB 8.2MB/s \n", | ||||
|             "\u001b[?25hCollecting smmap<4,>=3.0.1\n", | ||||
|             "  Downloading https://files.pythonhosted.org/packages/d5/1e/6130925131f639b2acde0f7f18b73e33ce082ff2d90783c436b52040af5a/smmap-3.0.5-py2.py3-none-any.whl\n", | ||||
|             "Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n", | ||||
|             "Successfully installed einops-0.3.0 gitdb-4.0.5 gitpython-3.1.12 labml-0.4.99 labml-helpers-0.4.74 labml-nn-0.4.86 smmap-3.0.5\n" | ||||
|           ], | ||||
|           "name": "stdout" | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "SE2VUQ6L5zxI" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Imports" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "0hJXx_g0wS2C" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "import torch\n", | ||||
|         "import torch.nn as nn\n", | ||||
|         "\n", | ||||
|         "from labml import experiment\n", | ||||
|         "from labml.configs import option\n", | ||||
|         "from labml_helpers.module import Module\n", | ||||
|         "from labml_nn.transformers.xl.experiment import Configs" | ||||
|       ], | ||||
|       "execution_count": 2, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "Lpggo0wM6qb-" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Create an experiment" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "bFcr9k-l4cAg" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "experiment.create(name=\"transformer_xl\")" | ||||
|       ], | ||||
|       "execution_count": 3, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "-OnHLi626tJt" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Initialize configurations" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "Piz0c5f44hRo" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "conf = Configs()" | ||||
|       ], | ||||
|       "execution_count": 4, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "wwMzCqpD6vkL" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Set experiment configurations and assign a configurations dictionary to override configurations" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "colab": { | ||||
|           "base_uri": "https://localhost:8080/", | ||||
|           "height": 17 | ||||
|         }, | ||||
|         "id": "e6hmQhTw4nks", | ||||
|         "outputId": "839820be-f5a9-476d-d458-7d2ff3278e89" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "experiment.configs(conf,\n", | ||||
|         "                # A dictionary of configurations to override\n", | ||||
|         "                {'tokenizer': 'character',\n", | ||||
|         "                'text': 'tiny_shakespeare',\n", | ||||
|         "                'optimizer.learning_rate': 1.,\n", | ||||
|         "                'optimizer.optimizer': 'Noam',\n", | ||||
|         "                'prompt': 'It is',\n", | ||||
|         "                'prompt_separator': '',\n", | ||||
|         "\n", | ||||
|         "                'train_loader': 'sequential_train_loader',\n", | ||||
|         "                'valid_loader': 'sequential_valid_loader',\n", | ||||
|         "\n", | ||||
|         "                'seq_len': 2,\n", | ||||
|         "                'mem_len': 32,\n", | ||||
|         "                'epochs': 128,\n", | ||||
|         "                'batch_size': 32,\n", | ||||
|         "                'inner_iterations': 25,\n", | ||||
|         "                })" | ||||
|       ], | ||||
|       "execution_count": 5, | ||||
|       "outputs": [ | ||||
|         { | ||||
|           "output_type": "display_data", | ||||
|           "data": { | ||||
|             "text/html": [ | ||||
|               "<pre style=\"overflow-x: scroll;\"></pre>" | ||||
|             ], | ||||
|             "text/plain": [ | ||||
|               "<IPython.core.display.HTML object>" | ||||
|             ] | ||||
|           }, | ||||
|           "metadata": { | ||||
|             "tags": [] | ||||
|           } | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "EvI7MtgJ61w5" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Set PyTorch models for loading and saving" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "colab": { | ||||
|           "base_uri": "https://localhost:8080/", | ||||
|           "height": 255 | ||||
|         }, | ||||
|         "id": "GDlt7dp-5ALt", | ||||
|         "outputId": "0543a726-fcf4-4493-dbe9-ab62c5bea94f" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "experiment.add_pytorch_models({'model': conf.model})" | ||||
|       ], | ||||
|       "execution_count": 6, | ||||
|       "outputs": [ | ||||
|         { | ||||
|           "output_type": "display_data", | ||||
|           "data": { | ||||
|             "text/html": [ | ||||
|               "<pre style=\"overflow-x: scroll;\">Prepare model...\n", | ||||
|               "  Prepare n_tokens...\n", | ||||
|               "    Prepare text...\n", | ||||
|               "      Prepare tokenizer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3.81ms</span>\n", | ||||
|               "      Download<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t133.10ms</span>\n", | ||||
|               "      Load data<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t7.14ms</span>\n", | ||||
|               "      Tokenize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t28.06ms</span>\n", | ||||
|               "      Build vocabulary<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t139.53ms</span>\n", | ||||
|               "    Prepare text<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t327.93ms</span>\n", | ||||
|               "  Prepare n_tokens<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t338.87ms</span>\n", | ||||
|               "  Prepare device...\n", | ||||
|               "    Prepare device_info<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t61.05ms</span>\n", | ||||
|               "  Prepare device<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t68.43ms</span>\n", | ||||
|               "Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t10,682.78ms</span>\n", | ||||
|               "</pre>" | ||||
|             ], | ||||
|             "text/plain": [ | ||||
|               "<IPython.core.display.HTML object>" | ||||
|             ] | ||||
|           }, | ||||
|           "metadata": { | ||||
|             "tags": [] | ||||
|           } | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "KJZRf8527GxL" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Start the experiment and run the training loop." | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "colab": { | ||||
|           "base_uri": "https://localhost:8080/", | ||||
|           "height": 493 | ||||
|         }, | ||||
|         "id": "aIAWo7Fw5DR8", | ||||
|         "outputId": "64764a3f-e8d7-4c06-b803-0da6ff756163" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "# Start the experiment\n", | ||||
|         "with experiment.start():\n", | ||||
|         "    conf.run()" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [ | ||||
|         { | ||||
|           "output_type": "display_data", | ||||
|           "data": { | ||||
|             "text/html": [ | ||||
|               "<pre style=\"overflow-x: scroll;\">\n", | ||||
|               "<strong><span style=\"text-decoration: underline\">transformer_xl</span></strong>: <span style=\"color: #208FFB\">d3b6760c692e11ebb6a70242ac1c0002</span>\n", | ||||
|               "\t[dirty]: <strong><span style=\"color: #DDB62B\">\"\"</span></strong>\n", | ||||
|               "Initialize...\n", | ||||
|               "  Prepare mode<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t6.08ms</span>\n", | ||||
|               "Initialize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t91.71ms</span>\n", | ||||
|               "Prepare validator...\n", | ||||
|               "  Prepare valid_loader<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t72.62ms</span>\n", | ||||
|               "<span style=\"color: #C5C1B4\"></span>\n", | ||||
|               "<span style=\"color: #C5C1B4\">--------------------------------------------------</span><span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>\n", | ||||
|               "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\">LABML WARNING</span></strong></span>\n", | ||||
|               "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>LabML App Warning: <span style=\"color: #60C6C8\">empty_token: </span><strong>Please create a valid token at https://web.lab-ml.com.</strong>\n", | ||||
|               "<strong>Click on the experiment link to monitor the experiment and add it to your experiments list.</strong><span style=\"color: #C5C1B4\"></span>\n", | ||||
|               "<span style=\"color: #C5C1B4\">--------------------------------------------------</span>\n", | ||||
|               "<span style=\"color: #208FFB\">Monitor experiment at </span><a href='https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002' target='blank'>https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002</a>\n", | ||||
|               "Prepare validator<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t159.97ms</span>\n", | ||||
|               "Prepare trainer...\n", | ||||
|               "  Prepare train_loader<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t105.71ms</span>\n", | ||||
|               "Prepare trainer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t138.48ms</span>\n", | ||||
|               "Prepare training_loop...\n", | ||||
|               "  Prepare loop_count<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t34.64ms</span>\n", | ||||
|               "Prepare training_loop<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t276.11ms</span>\n", | ||||
|               "<span style=\"color: #C5C1B4\">It is</span><strong>m</strong><strong>D</strong><strong>$</strong><strong>I</strong><strong>m</strong><strong>?</strong><strong>P</strong><strong>h</strong><strong>g</strong><strong>p</strong><strong>Q</strong><strong>x</strong><strong>P</strong><strong>P</strong><strong>P</strong><strong>P</strong><strong>,</strong><strong>o</strong><strong>F</strong><strong>r</strong><strong>F</strong><strong>r</strong><strong>F</strong><strong>r</strong><strong>F</strong>\n", | ||||
|               "<span style=\"color: #C5C1B4\">It is</span><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong>\n", | ||||
|               "<span style=\"color: #C5C1B4\">It is</span><strong>e</strong><strong>e</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>a</strong><strong>t</strong>\n", | ||||
|               "<span style=\"color: #C5C1B4\">It is</span><strong>e</strong><strong>d</strong><strong>,</strong><strong></strong>\n", | ||||
|               "<strong></strong><strong>T</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>w</strong><strong>h</strong><strong>i</strong><strong>s</strong><strong> </strong><strong>w</strong><strong>i</strong><strong>t</strong><strong>h</strong><strong> </strong><strong>w</strong><strong>e</strong><strong> </strong><strong>w</strong><strong>i</strong><strong>t</strong><strong>h</strong>\n", | ||||
|               "<span style=\"color: #C5C1B4\">It is</span><strong>e</strong><strong>n</strong><strong>t</strong><strong>e</strong><strong>r</strong><strong> </strong><strong>h</strong><strong>e</strong><strong>a</strong><strong>v</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>m</strong><strong>e</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>a</strong><strong>n</strong>\n", | ||||
|               "<strong><span style=\"color: #DDB62B\"> 200,768:  </span></strong>Sample:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">   393ms  </span>Train:<span style=\"color: #C5C1B4\">  20%</span><span style=\"color: #208FFB\">  1,161,309ms  </span>Valid:<span style=\"color: #C5C1B4\">  16%</span><span style=\"color: #208FFB\"> 32,270ms  </span> loss.train: <span style=\"color: #C5C1B4\"> 2.12580</span> accuracy.train: <span style=\"color: #C5C1B4\">0.335542</span> loss.valid: <strong> 2.23436</strong> accuracy.valid: <strong>0.326474</strong>  <span style=\"color: #208FFB\">1,193,972ms</span><span style=\"color: #D160C4\">  0:04m/ 42:23m  </span></pre>" | ||||
|             ], | ||||
|             "text/plain": [ | ||||
|               "<IPython.core.display.HTML object>" | ||||
|             ] | ||||
|           }, | ||||
|           "metadata": { | ||||
|             "tags": [] | ||||
|           } | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "oBXXlP2b7XZO" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri