mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	✨ ppo colab
This commit is contained in:
		
							
								
								
									
										270
									
								
								labml_nn/rl/ppo/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										270
									
								
								labml_nn/rl/ppo/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,270 @@ | ||||
| { | ||||
|   "nbformat": 4, | ||||
|   "nbformat_minor": 0, | ||||
|   "metadata": { | ||||
|     "colab": { | ||||
|       "name": "Proximal Policy Optimization - PPO", | ||||
|       "provenance": [], | ||||
|       "collapsed_sections": [] | ||||
|     }, | ||||
|     "kernelspec": { | ||||
|       "name": "python3", | ||||
|       "display_name": "Python 3" | ||||
|     }, | ||||
|     "accelerator": "GPU" | ||||
|   }, | ||||
|   "cells": [ | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "AYV_dMVDxyc2" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "[](https://github.com/lab-ml/nn)\n", | ||||
|         "[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb)                    \n", | ||||
|         "\n", | ||||
|         "## Fast Weights Transformer\n", | ||||
|         "\n", | ||||
|         "This is an experiment training Shakespeare dataset with a Compressive Transformer model." | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "AahG_i2y5tY9" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Install the `labml-nn` package" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "ZCzmCrAIVg0L", | ||||
|         "colab": { | ||||
|           "base_uri": "https://localhost:8080/" | ||||
|         }, | ||||
|         "outputId": "028e759e-0c9f-472e-b4b8-fdcf3e4604ee" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "!pip install labml-nn" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [ | ||||
|         { | ||||
|           "output_type": "stream", | ||||
|           "text": [ | ||||
|             "Collecting labml-nn\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a4/07/d33ead6f84fad2a4e8ff31ccd42864ff7b942785ad9f80d7c98df1c20a02/labml_nn-0.4.94-py3-none-any.whl (171kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 174kB 16.4MB/s \n", | ||||
|             "\u001b[?25hCollecting labml>=0.4.110\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/eb/c8/98b18d0dda3811998838734f9a32e944397fdd6bb0597cef0ae2b57338e3/labml-0.4.110-py3-none-any.whl (106kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 112kB 34.7MB/s \n", | ||||
|             "\u001b[?25hCollecting labml-helpers>=0.4.76\n", | ||||
|             "  Downloading https://files.pythonhosted.org/packages/49/df/4d920a4a221acd3cfa384dddb909ed0691b08682c0d8aeaabeee2138624f/labml_helpers-0.4.76-py3-none-any.whl\n", | ||||
|             "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.19.5)\n", | ||||
|             "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.8.0+cu101)\n", | ||||
|             "Collecting einops\n", | ||||
|             "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", | ||||
|             "Collecting gitpython\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 163kB 47.1MB/s \n", | ||||
|             "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from labml>=0.4.110->labml-nn) (3.13)\n", | ||||
|             "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->labml-nn) (3.7.4.3)\n", | ||||
|             "Collecting gitdb<5,>=4.0.1\n", | ||||
|             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n", | ||||
|             "\u001b[K     |████████████████████████████████| 71kB 11.1MB/s \n", | ||||
|             "\u001b[?25hCollecting smmap<5,>=3.0.1\n", | ||||
|             "  Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n", | ||||
|             "Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n", | ||||
|             "Successfully installed einops-0.3.0 gitdb-4.0.7 gitpython-3.1.14 labml-0.4.110 labml-helpers-0.4.76 labml-nn-0.4.94 smmap-4.0.0\n" | ||||
|           ], | ||||
|           "name": "stdout" | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "SE2VUQ6L5zxI" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Imports" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "0hJXx_g0wS2C" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "from labml import experiment\n", | ||||
|         "from labml.configs import FloatDynamicHyperParam\n", | ||||
|         "from labml_nn.rl.ppo.experiment import Trainer" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "Lpggo0wM6qb-" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Create an experiment" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "bFcr9k-l4cAg" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "experiment.create(name=\"ppo\")" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "-OnHLi626tJt" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Configurations" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "Piz0c5f44hRo" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "configs = {\n", | ||||
|         "    # number of updates\n", | ||||
|         "    'updates': 10000,\n", | ||||
|         "    # number of epochs to train the model with sampled data\n", | ||||
|         "    'epochs': 4,\n", | ||||
|         "    # number of worker processes\n", | ||||
|         "    'n_workers': 8,\n", | ||||
|         "    # number of steps to run on each process for a single update\n", | ||||
|         "    'worker_steps': 128,\n", | ||||
|         "    # number of mini batches\n", | ||||
|         "    'batches': 4,\n", | ||||
|         "    # Value loss coefficient\n", | ||||
|         "    'value_loss_coef': FloatDynamicHyperParam(0.5),\n", | ||||
|         "    # Entropy bonus coefficient\n", | ||||
|         "    'entropy_bonus_coef': FloatDynamicHyperParam(0.01),\n", | ||||
|         "    # Clip range\n", | ||||
|         "    'clip_range': FloatDynamicHyperParam(0.1),\n", | ||||
|         "    # Learning rate\n", | ||||
|         "    'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n", | ||||
|         "}" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "wwMzCqpD6vkL" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Set experiment configurations" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "colab": { | ||||
|           "base_uri": "https://localhost:8080/", | ||||
|           "height": 17 | ||||
|         }, | ||||
|         "id": "e6hmQhTw4nks", | ||||
|         "outputId": "0e978879-5dcd-4140-ec53-24a3fbd547de" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "experiment.configs(configs)" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [ | ||||
|         { | ||||
|           "output_type": "display_data", | ||||
|           "data": { | ||||
|             "text/html": [ | ||||
|               "<pre style=\"overflow-x: scroll;\"></pre>" | ||||
|             ], | ||||
|             "text/plain": [ | ||||
|               "<IPython.core.display.HTML object>" | ||||
|             ] | ||||
|           }, | ||||
|           "metadata": { | ||||
|             "tags": [] | ||||
|           } | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "qYQCFt_JYsjd" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Create trainer" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "8LB7XVViYuPG" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "trainer = Trainer(\n", | ||||
|         "    updates=configs['updates'],\n", | ||||
|         "    epochs=configs['epochs'],\n", | ||||
|         "    n_workers=configs['n_workers'],\n", | ||||
|         "    worker_steps=configs['worker_steps'],\n", | ||||
|         "    batches=configs['batches'],\n", | ||||
|         "    value_loss_coef=configs['value_loss_coef'],\n", | ||||
|         "    entropy_bonus_coef=configs['entropy_bonus_coef'],\n", | ||||
|         "    clip_range=configs['clip_range'],\n", | ||||
|         "    learning_rate=configs['learning_rate'],\n", | ||||
|         ")" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "markdown", | ||||
|       "metadata": { | ||||
|         "id": "KJZRf8527GxL" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "Start the experiment and run the training loop." | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "aIAWo7Fw5DR8" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "with experiment.start():\n", | ||||
|         "    trainer.run_training_loop()" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     }, | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "metadata": { | ||||
|         "id": "oBXXlP2b7XZO" | ||||
|       }, | ||||
|       "source": [ | ||||
|         "" | ||||
|       ], | ||||
|       "execution_count": null, | ||||
|       "outputs": [] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri