mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 18:27:03 +08:00
✨ ppo colab
This commit is contained in:
270
labml_nn/rl/ppo/experiment.ipynb
Normal file
270
labml_nn/rl/ppo/experiment.ipynb
Normal file
@ -0,0 +1,270 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "Proximal Policy Optimization - PPO",
|
||||
"provenance": [],
|
||||
"collapsed_sections": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AYV_dMVDxyc2"
|
||||
},
|
||||
"source": [
|
||||
"[](https://github.com/lab-ml/nn)\n",
|
||||
"[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb) \n",
|
||||
"\n",
|
||||
"## Fast Weights Transformer\n",
|
||||
"\n",
|
||||
"This is an experiment training Shakespeare dataset with a Compressive Transformer model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AahG_i2y5tY9"
|
||||
},
|
||||
"source": [
|
||||
"Install the `labml-nn` package"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "ZCzmCrAIVg0L",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "028e759e-0c9f-472e-b4b8-fdcf3e4604ee"
|
||||
},
|
||||
"source": [
|
||||
"!pip install labml-nn"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Collecting labml-nn\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a4/07/d33ead6f84fad2a4e8ff31ccd42864ff7b942785ad9f80d7c98df1c20a02/labml_nn-0.4.94-py3-none-any.whl (171kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 174kB 16.4MB/s \n",
|
||||
"\u001b[?25hCollecting labml>=0.4.110\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/eb/c8/98b18d0dda3811998838734f9a32e944397fdd6bb0597cef0ae2b57338e3/labml-0.4.110-py3-none-any.whl (106kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 112kB 34.7MB/s \n",
|
||||
"\u001b[?25hCollecting labml-helpers>=0.4.76\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/49/df/4d920a4a221acd3cfa384dddb909ed0691b08682c0d8aeaabeee2138624f/labml_helpers-0.4.76-py3-none-any.whl\n",
|
||||
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.19.5)\n",
|
||||
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.8.0+cu101)\n",
|
||||
"Collecting einops\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
|
||||
"Collecting gitpython\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 163kB 47.1MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from labml>=0.4.110->labml-nn) (3.13)\n",
|
||||
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
|
||||
"Collecting gitdb<5,>=4.0.1\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 71kB 11.1MB/s \n",
|
||||
"\u001b[?25hCollecting smmap<5,>=3.0.1\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n",
|
||||
"Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n",
|
||||
"Successfully installed einops-0.3.0 gitdb-4.0.7 gitpython-3.1.14 labml-0.4.110 labml-helpers-0.4.76 labml-nn-0.4.94 smmap-4.0.0\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SE2VUQ6L5zxI"
|
||||
},
|
||||
"source": [
|
||||
"Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "0hJXx_g0wS2C"
|
||||
},
|
||||
"source": [
|
||||
"from labml import experiment\n",
|
||||
"from labml.configs import FloatDynamicHyperParam\n",
|
||||
"from labml_nn.rl.ppo.experiment import Trainer"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Lpggo0wM6qb-"
|
||||
},
|
||||
"source": [
|
||||
"Create an experiment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "bFcr9k-l4cAg"
|
||||
},
|
||||
"source": [
|
||||
"experiment.create(name=\"ppo\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-OnHLi626tJt"
|
||||
},
|
||||
"source": [
|
||||
"Configurations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Piz0c5f44hRo"
|
||||
},
|
||||
"source": [
|
||||
"configs = {\n",
|
||||
" # number of updates\n",
|
||||
" 'updates': 10000,\n",
|
||||
" # number of epochs to train the model with sampled data\n",
|
||||
" 'epochs': 4,\n",
|
||||
" # number of worker processes\n",
|
||||
" 'n_workers': 8,\n",
|
||||
" # number of steps to run on each process for a single update\n",
|
||||
" 'worker_steps': 128,\n",
|
||||
" # number of mini batches\n",
|
||||
" 'batches': 4,\n",
|
||||
" # Value loss coefficient\n",
|
||||
" 'value_loss_coef': FloatDynamicHyperParam(0.5),\n",
|
||||
" # Entropy bonus coefficient\n",
|
||||
" 'entropy_bonus_coef': FloatDynamicHyperParam(0.01),\n",
|
||||
" # Clip range\n",
|
||||
" 'clip_range': FloatDynamicHyperParam(0.1),\n",
|
||||
" # Learning rate\n",
|
||||
" 'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n",
|
||||
"}"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "wwMzCqpD6vkL"
|
||||
},
|
||||
"source": [
|
||||
"Set experiment configurations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 17
|
||||
},
|
||||
"id": "e6hmQhTw4nks",
|
||||
"outputId": "0e978879-5dcd-4140-ec53-24a3fbd547de"
|
||||
},
|
||||
"source": [
|
||||
"experiment.configs(configs)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"overflow-x: scroll;\"></pre>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qYQCFt_JYsjd"
|
||||
},
|
||||
"source": [
|
||||
"Create trainer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "8LB7XVViYuPG"
|
||||
},
|
||||
"source": [
|
||||
"trainer = Trainer(\n",
|
||||
" updates=configs['updates'],\n",
|
||||
" epochs=configs['epochs'],\n",
|
||||
" n_workers=configs['n_workers'],\n",
|
||||
" worker_steps=configs['worker_steps'],\n",
|
||||
" batches=configs['batches'],\n",
|
||||
" value_loss_coef=configs['value_loss_coef'],\n",
|
||||
" entropy_bonus_coef=configs['entropy_bonus_coef'],\n",
|
||||
" clip_range=configs['clip_range'],\n",
|
||||
" learning_rate=configs['learning_rate'],\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "KJZRf8527GxL"
|
||||
},
|
||||
"source": [
|
||||
"Start the experiment and run the training loop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "aIAWo7Fw5DR8"
|
||||
},
|
||||
"source": [
|
||||
"with experiment.start():\n",
|
||||
" trainer.run_training_loop()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "oBXXlP2b7XZO"
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user