diff --git a/labml_nn/rl/ppo/experiment.ipynb b/labml_nn/rl/ppo/experiment.ipynb new file mode 100644 index 00000000..29e4afca --- /dev/null +++ b/labml_nn/rl/ppo/experiment.ipynb @@ -0,0 +1,270 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Proximal Policy Optimization - PPO", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AYV_dMVDxyc2" + }, + "source": [ + "[![Github](https://img.shields.io/github/stars/lab-ml/nn?style=social)](https://github.com/lab-ml/nn)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb) \n", + "\n", + "## Fast Weights Transformer\n", + "\n", + "This is an experiment training Shakespeare dataset with a Compressive Transformer model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AahG_i2y5tY9" + }, + "source": [ + "Install the `labml-nn` package" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZCzmCrAIVg0L", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "028e759e-0c9f-472e-b4b8-fdcf3e4604ee" + }, + "source": [ + "!pip install labml-nn" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting labml-nn\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a4/07/d33ead6f84fad2a4e8ff31ccd42864ff7b942785ad9f80d7c98df1c20a02/labml_nn-0.4.94-py3-none-any.whl (171kB)\n", + "\u001b[K |████████████████████████████████| 174kB 16.4MB/s \n", + "\u001b[?25hCollecting labml>=0.4.110\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/eb/c8/98b18d0dda3811998838734f9a32e944397fdd6bb0597cef0ae2b57338e3/labml-0.4.110-py3-none-any.whl (106kB)\n", + "\u001b[K |████████████████████████████████| 112kB 34.7MB/s \n", + "\u001b[?25hCollecting labml-helpers>=0.4.76\n", + " Downloading https://files.pythonhosted.org/packages/49/df/4d920a4a221acd3cfa384dddb909ed0691b08682c0d8aeaabeee2138624f/labml_helpers-0.4.76-py3-none-any.whl\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.19.5)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.8.0+cu101)\n", + "Collecting einops\n", + " Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", + "Collecting gitpython\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)\n", + "\u001b[K |████████████████████████████████| 163kB 47.1MB/s \n", + "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from labml>=0.4.110->labml-nn) (3.13)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->labml-nn) (3.7.4.3)\n", + "Collecting gitdb<5,>=4.0.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n", + "\u001b[K |████████████████████████████████| 71kB 11.1MB/s \n", + "\u001b[?25hCollecting smmap<5,>=3.0.1\n", + " Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n", + "Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n", + "Successfully installed einops-0.3.0 gitdb-4.0.7 gitpython-3.1.14 labml-0.4.110 labml-helpers-0.4.76 labml-nn-0.4.94 smmap-4.0.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SE2VUQ6L5zxI" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0hJXx_g0wS2C" + }, + "source": [ + "from labml import experiment\n", + "from labml.configs import FloatDynamicHyperParam\n", + "from labml_nn.rl.ppo.experiment import Trainer" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lpggo0wM6qb-" + }, + "source": [ + "Create an experiment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bFcr9k-l4cAg" + }, + "source": [ + "experiment.create(name=\"ppo\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-OnHLi626tJt" + }, + "source": [ + "Configurations" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Piz0c5f44hRo" + }, + "source": [ + "configs = {\n", + " # number of updates\n", + " 'updates': 10000,\n", + " # number of epochs to train the model with sampled data\n", + " 'epochs': 4,\n", + " # number of worker processes\n", + " 'n_workers': 8,\n", + " # number of steps to run on each process for a single update\n", + " 'worker_steps': 128,\n", + " # number of mini batches\n", + " 'batches': 4,\n", + " # Value loss coefficient\n", + " 'value_loss_coef': FloatDynamicHyperParam(0.5),\n", + " # Entropy bonus coefficient\n", + " 'entropy_bonus_coef': FloatDynamicHyperParam(0.01),\n", + " # Clip range\n", + " 'clip_range': FloatDynamicHyperParam(0.1),\n", + " # Learning rate\n", + " 'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n", + "}" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wwMzCqpD6vkL" + }, + "source": [ + "Set experiment configurations" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "e6hmQhTw4nks", + "outputId": "0e978879-5dcd-4140-ec53-24a3fbd547de" + }, + "source": [ + "experiment.configs(configs)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "
"
+            ],
+            "text/plain": [
+              ""
+            ]
+          },
+          "metadata": {
+            "tags": []
+          }
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "qYQCFt_JYsjd"
+      },
+      "source": [
+        "Create trainer"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "8LB7XVViYuPG"
+      },
+      "source": [
+        "trainer = Trainer(\n",
+        "    updates=configs['updates'],\n",
+        "    epochs=configs['epochs'],\n",
+        "    n_workers=configs['n_workers'],\n",
+        "    worker_steps=configs['worker_steps'],\n",
+        "    batches=configs['batches'],\n",
+        "    value_loss_coef=configs['value_loss_coef'],\n",
+        "    entropy_bonus_coef=configs['entropy_bonus_coef'],\n",
+        "    clip_range=configs['clip_range'],\n",
+        "    learning_rate=configs['learning_rate'],\n",
+        ")"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "KJZRf8527GxL"
+      },
+      "source": [
+        "Start the experiment and run the training loop."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "aIAWo7Fw5DR8"
+      },
+      "source": [
+        "with experiment.start():\n",
+        "    trainer.run_training_loop()"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "oBXXlP2b7XZO"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file