Varuna Jayasiri
2021-03-30 11:37:06 +05:30
parent 6fba4c0957
commit ac40d0a7c9
2 changed files with 228 additions and 271 deletions

View File: labml_nn/rl/ppo/experiment.ipynb

@@ -1,18 +1,4 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "Proximal Policy Optimization - PPO",
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "accelerator": "GPU"
-  },
   "cells": [
     {
       "cell_type": "markdown",
@@ -21,11 +7,11 @@
       },
       "source": [
         "[![Github](https://img.shields.io/github/stars/lab-ml/nn?style=social)](https://github.com/lab-ml/nn)\n",
-        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb) \n",
+        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb) \n",
         "\n",
-        "## Fast Weights Transformer\n",
+        "## Proximal Policy Optimization - PPO\n",
         "\n",
-        "This is an experiment training Shakespeare dataset with a Compressive Transformer model."
+        "This is an experiment training an agent to play Atari Breakout game using Proximal Policy Optimization - PPO"
       ]
     },
     {
@@ -39,48 +25,17 @@
     },
     {
       "cell_type": "code",
+      "execution_count": null,
       "metadata": {
-        "id": "ZCzmCrAIVg0L",
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
+        "id": "ZCzmCrAIVg0L",
         "outputId": "028e759e-0c9f-472e-b4b8-fdcf3e4604ee"
       },
+      "outputs": [],
       "source": [
         "!pip install labml-nn"
-      ],
-      "execution_count": null,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "text": [
-            "Collecting labml-nn\n",
-            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a4/07/d33ead6f84fad2a4e8ff31ccd42864ff7b942785ad9f80d7c98df1c20a02/labml_nn-0.4.94-py3-none-any.whl (171kB)\n",
-            "\u001b[K |████████████████████████████████| 174kB 16.4MB/s \n",
-            "\u001b[?25hCollecting labml>=0.4.110\n",
-            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/eb/c8/98b18d0dda3811998838734f9a32e944397fdd6bb0597cef0ae2b57338e3/labml-0.4.110-py3-none-any.whl (106kB)\n",
-            "\u001b[K |████████████████████████████████| 112kB 34.7MB/s \n",
-            "\u001b[?25hCollecting labml-helpers>=0.4.76\n",
-            " Downloading https://files.pythonhosted.org/packages/49/df/4d920a4a221acd3cfa384dddb909ed0691b08682c0d8aeaabeee2138624f/labml_helpers-0.4.76-py3-none-any.whl\n",
-            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.19.5)\n",
-            "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.8.0+cu101)\n",
-            "Collecting einops\n",
-            " Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
-            "Collecting gitpython\n",
-            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)\n",
-            "\u001b[K |████████████████████████████████| 163kB 47.1MB/s \n",
-            "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from labml>=0.4.110->labml-nn) (3.13)\n",
-            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
-            "Collecting gitdb<5,>=4.0.1\n",
-            "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n",
-            "\u001b[K |████████████████████████████████| 71kB 11.1MB/s \n",
-            "\u001b[?25hCollecting smmap<5,>=3.0.1\n",
-            " Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n",
-            "Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n",
-            "Successfully installed einops-0.3.0 gitdb-4.0.7 gitpython-3.1.14 labml-0.4.110 labml-helpers-0.4.76 labml-nn-0.4.94 smmap-4.0.0\n"
-          ],
-          "name": "stdout"
-        }
       ]
     },
     {
@@ -94,16 +49,16 @@
     },
     {
      "cell_type": "code",
+      "execution_count": null,
       "metadata": {
         "id": "0hJXx_g0wS2C"
       },
+      "outputs": [],
       "source": [
         "from labml import experiment\n",
-        "from labml.configs import FloatDynamicHyperParam\n",
+        "from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam\n",
         "from labml_nn.rl.ppo.experiment import Trainer"
-      ],
-      "execution_count": null,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "markdown",
@@ -116,14 +71,14 @@
     },
     {
       "cell_type": "code",
+      "execution_count": null,
       "metadata": {
         "id": "bFcr9k-l4cAg"
       },
+      "outputs": [],
       "source": [
         "experiment.create(name=\"ppo\")"
-      ],
-      "execution_count": null,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "markdown",
@@ -136,15 +91,17 @@
     },
     {
       "cell_type": "code",
+      "execution_count": null,
       "metadata": {
         "id": "Piz0c5f44hRo"
       },
+      "outputs": [],
       "source": [
         "configs = {\n",
         "    # number of updates\n",
         "    'updates': 10000,\n",
         "    # number of epochs to train the model with sampled data\n",
-        "    'epochs': 4,\n",
+        "    'epochs': IntDynamicHyperParam(8),\n",
         "    # number of worker processes\n",
         "    'n_workers': 8,\n",
         "    # number of steps to run on each process for a single update\n",
@@ -160,9 +117,7 @@
         "    # Learning rate\n",
         "    'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n",
         "}"
-      ],
-      "execution_count": null,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "markdown",
@@ -175,6 +130,7 @@
     },
     {
       "cell_type": "code",
+      "execution_count": null,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -183,25 +139,9 @@
         "id": "e6hmQhTw4nks",
         "outputId": "0e978879-5dcd-4140-ec53-24a3fbd547de"
       },
+      "outputs": [],
       "source": [
         "experiment.configs(configs)"
-      ],
-      "execution_count": null,
-      "outputs": [
-        {
-          "output_type": "display_data",
-          "data": {
-            "text/html": [
-              "<pre style=\"overflow-x: scroll;\"></pre>"
-            ],
-            "text/plain": [
-              "<IPython.core.display.HTML object>"
-            ]
-          },
-          "metadata": {
-            "tags": []
-          }
-        }
       ]
     },
     {
@@ -215,9 +155,11 @@
     },
     {
       "cell_type": "code",
+      "execution_count": null,
       "metadata": {
         "id": "8LB7XVViYuPG"
       },
+      "outputs": [],
       "source": [
         "trainer = Trainer(\n",
         "    updates=configs['updates'],\n",
@@ -230,9 +172,7 @@
         "    clip_range=configs['clip_range'],\n",
         "    learning_rate=configs['learning_rate'],\n",
         ")"
-      ],
-      "execution_count": null,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "markdown",
@@ -245,26 +185,42 @@
     },
     {
       "cell_type": "code",
+      "execution_count": null,
       "metadata": {
         "id": "aIAWo7Fw5DR8"
       },
+      "outputs": [],
       "source": [
         "with experiment.start():\n",
         "    trainer.run_training_loop()"
-      ],
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "metadata": {
-        "id": "oBXXlP2b7XZO"
-      },
-      "source": [
-        ""
-      ],
-      "execution_count": null,
-      "outputs": []
+      ]
     }
-  ]
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "collapsed_sections": [],
+      "name": "Proximal Policy Optimization - PPO",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.8.5"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 4
 }
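The notebook now builds `epochs` as an `IntDynamicHyperParam`, matching the `FloatDynamicHyperParam` values already used for the loss coefficients, clip range, and learning rate. A minimal sketch of the pattern, assuming only what the diff shows: the constructor takes an initial value, optionally followed by a `(min, max)` range for the float variant, and the object is called to read its current value (labml's dynamic hyper-parameters are meant to be adjustable from the labml app while a run is live):

from labml import experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam

# Initial values; both can be changed mid-run from the labml app.
configs = {
    'epochs': IntDynamicHyperParam(8),
    'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
}

experiment.create(name="ppo")
experiment.configs(configs)

# A dynamic hyper-parameter is read by calling it, not by attribute
# access, so every read picks up the latest value.
n_epochs = configs['epochs']()    # -> 8 until changed
lr = configs['learning_rate']()   # -> 2.5e-4 until changed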

View File: labml_nn/rl/ppo/experiment.py

@@ -19,7 +19,7 @@ from torch import optim
 from torch.distributions import Categorical
 from labml import monit, tracker, logger, experiment
-from labml.configs import FloatDynamicHyperParam
+from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
 from labml_helpers.module import Module
 from labml_nn.rl.game import Worker
 from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
@ -91,7 +91,8 @@ class Trainer:
""" """
def __init__(self, *, def __init__(self, *,
updates: int, epochs: int, n_workers: int, worker_steps: int, batches: int, updates: int, epochs: IntDynamicHyperParam,
n_workers: int, worker_steps: int, batches: int,
value_loss_coef: FloatDynamicHyperParam, value_loss_coef: FloatDynamicHyperParam,
entropy_bonus_coef: FloatDynamicHyperParam, entropy_bonus_coef: FloatDynamicHyperParam,
clip_range: FloatDynamicHyperParam, clip_range: FloatDynamicHyperParam,
@@ -231,7 +232,7 @@ class Trainer:
         # the average episode reward does not monotonically increase
         # over time.
         # May be reducing the clipping range might solve it.
-        for _ in range(self.epochs):
+        for _ in range(self.epochs()):
             # shuffle for each epoch
             indexes = torch.randperm(self.batch_size)
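The call-site change from `self.epochs` to `self.epochs()` is what makes the new type useful: the epoch count is re-read at the start of every PPO update instead of being frozen when the `Trainer` is constructed. A self-contained sketch of that semantics, using a hypothetical `DynamicInt` stand-in rather than the real labml class:

class DynamicInt:
    """Hypothetical stand-in for labml's IntDynamicHyperParam."""

    def __init__(self, default: int):
        self._value = default

    def set_value(self, value: int):
        # With labml, this update would arrive from the app mid-run.
        self._value = value

    def __call__(self) -> int:
        # Calling the object returns the current value.
        return self._value


epochs = DynamicInt(8)
for update in range(3):
    if update == 1:
        epochs.set_value(4)       # tweak between updates
    for _ in range(epochs()):     # re-read on every update
        pass                      # one training epoch over the sampled data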
@@ -356,7 +357,7 @@ def main():
         # number of updates
         'updates': 10000,
         # number of epochs to train the model with sampled data
-        'epochs': 4,
+        'epochs': IntDynamicHyperParam(8),
         # number of worker processes
         'n_workers': 8,
         # number of steps to run on each process for a single update
@@ -370,7 +371,7 @@ def main():
         # Clip range
         'clip_range': FloatDynamicHyperParam(0.1),
         # Learning rate
-        'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
+        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
     }
     experiment.configs(configs)
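`main()` also raises the default learning rate from `2.5e-4` to `1e-3`, the top of the allowed `(0, 1e-3)` range. Because the config is a `FloatDynamicHyperParam`, the trainer can fold the current value into the optimizer on each update; this diff does not show that call site, so treat the optimizer plumbing below as an assumed sketch of the pattern:

from torch import nn, optim

def apply_dynamic_lr(optimizer: optim.Optimizer, learning_rate) -> None:
    # Re-read the dynamic hyper-parameter and write it into every
    # parameter group, so a mid-run change takes effect on the next
    # optimizer step.
    lr = learning_rate()
    for pg in optimizer.param_groups:
        pg['lr'] = lr

model = nn.Linear(4, 2)
opt = optim.Adam(model.parameters(), lr=2.5e-4)
apply_dynamic_lr(opt, lambda: 1e-3)  # any callable returning a float works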