mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	✨ ppo colab
This commit is contained in:
		
							
								
								
									
										270
									
								
								labml_nn/rl/ppo/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										270
									
								
								labml_nn/rl/ppo/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,270 @@ | |||||||
|  | { | ||||||
|  |   "nbformat": 4, | ||||||
|  |   "nbformat_minor": 0, | ||||||
|  |   "metadata": { | ||||||
|  |     "colab": { | ||||||
|  |       "name": "Proximal Policy Optimization - PPO", | ||||||
|  |       "provenance": [], | ||||||
|  |       "collapsed_sections": [] | ||||||
|  |     }, | ||||||
|  |     "kernelspec": { | ||||||
|  |       "name": "python3", | ||||||
|  |       "display_name": "Python 3" | ||||||
|  |     }, | ||||||
|  |     "accelerator": "GPU" | ||||||
|  |   }, | ||||||
|  |   "cells": [ | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "AYV_dMVDxyc2" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "[](https://github.com/lab-ml/nn)\n", | ||||||
|  |         "[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb)                    \n", | ||||||
|  |         "\n", | ||||||
|  |         "## Fast Weights Transformer\n", | ||||||
|  |         "\n", | ||||||
|  |         "This is an experiment training Shakespeare dataset with a Compressive Transformer model." | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "AahG_i2y5tY9" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Install the `labml-nn` package" | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "ZCzmCrAIVg0L", | ||||||
|  |         "colab": { | ||||||
|  |           "base_uri": "https://localhost:8080/" | ||||||
|  |         }, | ||||||
|  |         "outputId": "028e759e-0c9f-472e-b4b8-fdcf3e4604ee" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "!pip install labml-nn" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [ | ||||||
|  |         { | ||||||
|  |           "output_type": "stream", | ||||||
|  |           "text": [ | ||||||
|  |             "Collecting labml-nn\n", | ||||||
|  |             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a4/07/d33ead6f84fad2a4e8ff31ccd42864ff7b942785ad9f80d7c98df1c20a02/labml_nn-0.4.94-py3-none-any.whl (171kB)\n", | ||||||
|  |             "\u001b[K     |████████████████████████████████| 174kB 16.4MB/s \n", | ||||||
|  |             "\u001b[?25hCollecting labml>=0.4.110\n", | ||||||
|  |             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/eb/c8/98b18d0dda3811998838734f9a32e944397fdd6bb0597cef0ae2b57338e3/labml-0.4.110-py3-none-any.whl (106kB)\n", | ||||||
|  |             "\u001b[K     |████████████████████████████████| 112kB 34.7MB/s \n", | ||||||
|  |             "\u001b[?25hCollecting labml-helpers>=0.4.76\n", | ||||||
|  |             "  Downloading https://files.pythonhosted.org/packages/49/df/4d920a4a221acd3cfa384dddb909ed0691b08682c0d8aeaabeee2138624f/labml_helpers-0.4.76-py3-none-any.whl\n", | ||||||
|  |             "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.19.5)\n", | ||||||
|  |             "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.8.0+cu101)\n", | ||||||
|  |             "Collecting einops\n", | ||||||
|  |             "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", | ||||||
|  |             "Collecting gitpython\n", | ||||||
|  |             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)\n", | ||||||
|  |             "\u001b[K     |████████████████████████████████| 163kB 47.1MB/s \n", | ||||||
|  |             "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from labml>=0.4.110->labml-nn) (3.13)\n", | ||||||
|  |             "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->labml-nn) (3.7.4.3)\n", | ||||||
|  |             "Collecting gitdb<5,>=4.0.1\n", | ||||||
|  |             "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n", | ||||||
|  |             "\u001b[K     |████████████████████████████████| 71kB 11.1MB/s \n", | ||||||
|  |             "\u001b[?25hCollecting smmap<5,>=3.0.1\n", | ||||||
|  |             "  Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n", | ||||||
|  |             "Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n", | ||||||
|  |             "Successfully installed einops-0.3.0 gitdb-4.0.7 gitpython-3.1.14 labml-0.4.110 labml-helpers-0.4.76 labml-nn-0.4.94 smmap-4.0.0\n" | ||||||
|  |           ], | ||||||
|  |           "name": "stdout" | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "SE2VUQ6L5zxI" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Imports" | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "0hJXx_g0wS2C" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "from labml import experiment\n", | ||||||
|  |         "from labml.configs import FloatDynamicHyperParam\n", | ||||||
|  |         "from labml_nn.rl.ppo.experiment import Trainer" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "Lpggo0wM6qb-" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Create an experiment" | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "bFcr9k-l4cAg" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "experiment.create(name=\"ppo\")" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "-OnHLi626tJt" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Configurations" | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "Piz0c5f44hRo" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "configs = {\n", | ||||||
|  |         "    # number of updates\n", | ||||||
|  |         "    'updates': 10000,\n", | ||||||
|  |         "    # number of epochs to train the model with sampled data\n", | ||||||
|  |         "    'epochs': 4,\n", | ||||||
|  |         "    # number of worker processes\n", | ||||||
|  |         "    'n_workers': 8,\n", | ||||||
|  |         "    # number of steps to run on each process for a single update\n", | ||||||
|  |         "    'worker_steps': 128,\n", | ||||||
|  |         "    # number of mini batches\n", | ||||||
|  |         "    'batches': 4,\n", | ||||||
|  |         "    # Value loss coefficient\n", | ||||||
|  |         "    'value_loss_coef': FloatDynamicHyperParam(0.5),\n", | ||||||
|  |         "    # Entropy bonus coefficient\n", | ||||||
|  |         "    'entropy_bonus_coef': FloatDynamicHyperParam(0.01),\n", | ||||||
|  |         "    # Clip range\n", | ||||||
|  |         "    'clip_range': FloatDynamicHyperParam(0.1),\n", | ||||||
|  |         "    # Learning rate\n", | ||||||
|  |         "    'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n", | ||||||
|  |         "}" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "wwMzCqpD6vkL" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Set experiment configurations" | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "colab": { | ||||||
|  |           "base_uri": "https://localhost:8080/", | ||||||
|  |           "height": 17 | ||||||
|  |         }, | ||||||
|  |         "id": "e6hmQhTw4nks", | ||||||
|  |         "outputId": "0e978879-5dcd-4140-ec53-24a3fbd547de" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "experiment.configs(configs)" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [ | ||||||
|  |         { | ||||||
|  |           "output_type": "display_data", | ||||||
|  |           "data": { | ||||||
|  |             "text/html": [ | ||||||
|  |               "<pre style=\"overflow-x: scroll;\"></pre>" | ||||||
|  |             ], | ||||||
|  |             "text/plain": [ | ||||||
|  |               "<IPython.core.display.HTML object>" | ||||||
|  |             ] | ||||||
|  |           }, | ||||||
|  |           "metadata": { | ||||||
|  |             "tags": [] | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "qYQCFt_JYsjd" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Create trainer" | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "8LB7XVViYuPG" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "trainer = Trainer(\n", | ||||||
|  |         "    updates=configs['updates'],\n", | ||||||
|  |         "    epochs=configs['epochs'],\n", | ||||||
|  |         "    n_workers=configs['n_workers'],\n", | ||||||
|  |         "    worker_steps=configs['worker_steps'],\n", | ||||||
|  |         "    batches=configs['batches'],\n", | ||||||
|  |         "    value_loss_coef=configs['value_loss_coef'],\n", | ||||||
|  |         "    entropy_bonus_coef=configs['entropy_bonus_coef'],\n", | ||||||
|  |         "    clip_range=configs['clip_range'],\n", | ||||||
|  |         "    learning_rate=configs['learning_rate'],\n", | ||||||
|  |         ")" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "markdown", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "KJZRf8527GxL" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "Start the experiment and run the training loop." | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "aIAWo7Fw5DR8" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "with experiment.start():\n", | ||||||
|  |         "    trainer.run_training_loop()" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "cell_type": "code", | ||||||
|  |       "metadata": { | ||||||
|  |         "id": "oBXXlP2b7XZO" | ||||||
|  |       }, | ||||||
|  |       "source": [ | ||||||
|  |         "" | ||||||
|  |       ], | ||||||
|  |       "execution_count": null, | ||||||
|  |       "outputs": [] | ||||||
|  |     } | ||||||
|  |   ] | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri