mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			204 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			204 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| {
 | |
|   "nbformat": 4,
 | |
|   "nbformat_minor": 0,
 | |
|   "metadata": {
 | |
|     "colab": {
 | |
|       "name": "HyperLSTM",
 | |
|       "provenance": [],
 | |
|       "collapsed_sections": []
 | |
|     },
 | |
|     "kernelspec": {
 | |
|       "name": "python3",
 | |
|       "display_name": "Python 3"
 | |
|     },
 | |
|     "accelerator": "GPU"
 | |
|   },
 | |
|   "cells": [
 | |
|     {
 | |
|       "cell_type": "markdown",
 | |
|       "metadata": {
 | |
|         "id": "AYV_dMVDxyc2"
 | |
|       },
 | |
|       "source": [
 | |
|         "[GitHub repository](https://github.com/lab-ml/nn)\n",
 | |
|         "[Open in Colab](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/hypernetworks/experiment.ipynb)                    \n",
 | |
|         "\n",
 | |
|         "## HyperLSTM\n",
 | |
|         "\n",
 | |
|         "This experiment trains a HyperLSTM model, from the paper HyperNetworks, on the Tiny Shakespeare dataset."
 | |
|       ]
 | |
|     },
 | |
|     {
 | |
|       "cell_type": "code",
 | |
|       "metadata": {
 | |
|         "id": "ZCzmCrAIVg0L"
 | |
|       },
 | |
|       "source": [
 | |
|         "!pip install labml-nn"
 | |
|       ],
 | |
|       "execution_count": null,
 | |
|       "outputs": []
 | |
|     },
 | |
|     {
 | |
|       "cell_type": "code",
 | |
|       "metadata": {
 | |
|         "id": "0hJXx_g0wS2C"
 | |
|       },
 | |
|       "source": [
 | |
|         "from labml import experiment\n",
 | |
|         "from labml_nn.hypernetworks.experiment import Configs"
 | |
|       ],
 | |
|       "execution_count": 3,
 | |
|       "outputs": []
 | |
|     },
 | |
|     {
 | |
|       "cell_type": "code",
 | |
|       "metadata": {
 | |
|         "colab": {
 | |
|           "base_uri": "https://localhost:8080/",
 | |
|           "height": 255
 | |
|         },
 | |
|         "id": "WQ8VGpMGwZuj",
 | |
|         "outputId": "5833cc50-26a8-496e-e729-88f42b3f4651"
 | |
|       },
 | |
|       "source": [
 | |
|         "# Create experiment\n",
 | |
|         "experiment.create(name=\"hyper_lstm\", comment='')\n",
 | |
|         "# Create configs\n",
 | |
|         "conf = Configs()\n",
 | |
|         "# Load configurations\n",
 | |
|         "experiment.configs(conf,\n",
 | |
|         "                    # A dictionary of configurations to override\n",
 | |
|         "                    {'tokenizer': 'character',\n",
 | |
|         "                    'text': 'tiny_shakespeare',\n",
 | |
|         "                    'optimizer.learning_rate': 2.5e-4,\n",
 | |
|         "                    'optimizer.optimizer': 'Adam',\n",
 | |
|         "                    'prompt': 'It is',\n",
 | |
|         "                    'prompt_separator': '',\n",
 | |
|         "\n",
 | |
|         "                    'rnn_model': 'hyper_lstm',\n",
 | |
|         "\n",
 | |
|         "                    'train_loader': 'shuffled_train_loader',\n",
 | |
|         "                    'valid_loader': 'shuffled_valid_loader',\n",
 | |
|         "\n",
 | |
|         "                    'seq_len': 512,\n",
 | |
|         "                    'epochs': 128,\n",
 | |
|         "                    'batch_size': 2,\n",
 | |
|         "                    'inner_iterations': 25})\n",
 | |
|         "\n",
 | |
|         "\n",
 | |
|         "# Set models for saving and loading\n",
 | |
|         "experiment.add_pytorch_models({'model': conf.model})\n",
 | |
|         "\n",
 | |
|         "conf.init()"
 | |
|       ],
 | |
|       "execution_count": 5,
 | |
|       "outputs": [
 | |
|         {
 | |
|           "output_type": "display_data",
 | |
|           "data": {
 | |
|             "text/html": [
 | |
|               "<pre style=\"overflow-x: scroll;\">\n",
 | |
|               "Prepare model...\n",
 | |
|               "  Prepare n_tokens...\n",
 | |
|               "    Prepare text...\n",
 | |
|               "      Prepare tokenizer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3.07ms</span>\n",
 | |
|               "      Load data<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t2.85ms</span>\n",
 | |
|               "      Tokenize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t33.69ms</span>\n",
 | |
|               "      Build vocabulary<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t103.52ms</span>\n",
 | |
|               "    Prepare text<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t153.38ms</span>\n",
 | |
|               "  Prepare n_tokens<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t160.21ms</span>\n",
 | |
|               "  Prepare rnn_model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t13.84ms</span>\n",
 | |
|               "Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t195.08ms</span>\n",
 | |
|               "Prepare mode<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.78ms</span>\n",
 | |
|               "</pre>"
 | |
|             ],
 | |
|             "text/plain": [
 | |
|               "<IPython.core.display.HTML object>"
 | |
|             ]
 | |
|           },
 | |
|           "metadata": {
 | |
|             "tags": []
 | |
|           }
 | |
|         },
 | |
|         {
 | |
|           "output_type": "stream",
 | |
|           "text": [
 | |
|             "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py:434: UserWarning: Setting attributes on ParameterList is not supported.\n",
 | |
|             "  warnings.warn(\"Setting attributes on ParameterList is not supported.\")\n"
 | |
|           ],
 | |
|           "name": "stderr"
 | |
|         }
 | |
|       ]
 | |
|     },
 | |
|     {
 | |
|       "cell_type": "code",
 | |
|       "metadata": {
 | |
|         "colab": {
 | |
|           "base_uri": "https://localhost:8080/",
 | |
|           "height": 425
 | |
|         },
 | |
|         "id": "f07vAOaHwumr",
 | |
|         "outputId": "6b51205e-3852-4dce-f7a7-f3ba4066ba21"
 | |
|       },
 | |
|       "source": [
 | |
|         "# Start the experiment\n",
 | |
|         "with experiment.start():\n",
 | |
|         "    # `TrainValidConfigs.run`\n",
 | |
|         "    conf.run()"
 | |
|       ],
 | |
|       "execution_count": null,
 | |
|       "outputs": [
 | |
|         {
 | |
|           "output_type": "display_data",
 | |
|           "data": {
 | |
|             "text/html": [
 | |
|               "<pre style=\"overflow-x: scroll;\">\n",
 | |
|               "<strong><span style=\"text-decoration: underline\">hyper_lstm</span></strong>: <span style=\"color: #208FFB\">5004f5724d8611eba84a0242ac1c0002</span>\n",
 | |
|               "\t[dirty]: <strong><span style=\"color: #DDB62B\">\"\"</span></strong>\n",
 | |
|               "Initialize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.12ms</span>\n",
 | |
|               "Prepare validator...\n",
 | |
|               "  Prepare valid_loader<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t76.72ms</span>\n",
 | |
|               "<span style=\"color: #C5C1B4\"></span>\n",
 | |
|               "<span style=\"color: #C5C1B4\">--------------------------------------------------</span><span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>\n",
 | |
|               "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\">LABML WARNING</span></strong></span>\n",
 | |
|               "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>LabML App Warning: <span style=\"color: #60C6C8\">empty_token: </span><strong>Please create a valid token at https://web.lab-ml.com.</strong>\n",
 | |
|               "<strong>Click on the experiment link to monitor the experiment and add it to your experiments list.</strong><span style=\"color: #C5C1B4\"></span>\n",
 | |
|               "<span style=\"color: #C5C1B4\">--------------------------------------------------</span>\n",
 | |
|               "<span style=\"color: #208FFB\">Monitor experiment at </span><a href='https://web.lab-ml.com/run?uuid=5004f5724d8611eba84a0242ac1c0002' target='blank'>https://web.lab-ml.com/run?uuid=5004f5724d8611eba84a0242ac1c0002</a>\n",
 | |
|               "Prepare validator<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t174.93ms</span>\n",
 | |
|               "Prepare trainer...\n",
 | |
|               "  Prepare train_loader<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t100.16ms</span>\n",
 | |
|               "Prepare trainer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t137.49ms</span>\n",
 | |
|               "Prepare training_loop...\n",
 | |
|               "  Prepare loop_count<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t37.12ms</span>\n",
 | |
|               "Prepare training_loop<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t301.04ms</span>\n",
 | |
|               "<span style=\"color: #C5C1B4\">It is</span><strong>?</strong><strong>?</strong><strong>?</strong><strong>?</strong><strong>?</strong><strong>?</strong><strong>?</strong><strong>?</strong><strong>n</strong><strong>n</strong><strong>?</strong><strong>n</strong><strong>?</strong><strong>n</strong><strong>?</strong><strong>n</strong><strong>?</strong><strong>?</strong><strong>?</strong><strong>n</strong><strong>n</strong><strong>?</strong><strong>n</strong><strong>?</strong><strong>n</strong>\n",
 | |
|               "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>a</strong><strong>n</strong><strong>d</strong><strong> </strong>\n",
 | |
|               "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong>\n",
 | |
|               "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong>\n",
 | |
|               "<strong><span style=\"color: #DDB62B\">  65,536:  </span></strong>Sample:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 1,288ms  </span>Train:<span style=\"color: #C5C1B4\">  13%</span><span style=\"color: #208FFB\"> 4,212,862ms  </span>Valid:<span style=\"color: #C5C1B4\">  11%</span><span style=\"color: #208FFB\"> 132,056ms  </span> accuracy.train: <strong>0.301926</strong> loss.train: <strong> 2.25940</strong> accuracy.valid: <span style=\"color: #C5C1B4\">0.330679</span> loss.valid: <span style=\"color: #C5C1B4\"> 2.48882</span>  <span style=\"color: #208FFB\">4,346,206ms</span><span style=\"color: #D160C4\">  0:08m/154:23m  </span></pre>"
 | |
|             ],
 | |
|             "text/plain": [
 | |
|               "<IPython.core.display.HTML object>"
 | |
|             ]
 | |
|           },
 | |
|           "metadata": {
 | |
|             "tags": []
 | |
|           }
 | |
|         }
 | |
|       ]
 | |
|     },
 | |
|     {
 | |
|       "cell_type": "code",
 | |
|       "metadata": {
 | |
|         "id": "crH6MzKmw-SY"
 | |
|       },
 | |
|       "source": [
 | |
|         ""
 | |
|       ],
 | |
|       "execution_count": null,
 | |
|       "outputs": []
 | |
|     }
 | |
|   ]
 | |
| } | 
