mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 20:28:41 +08:00
notebook
This commit is contained in:
@ -1,18 +1,4 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "Proximal Policy Optimization - PPO",
|
||||
"provenance": [],
|
||||
"collapsed_sections": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -21,11 +7,11 @@
|
||||
},
|
||||
"source": [
|
||||
"[](https://github.com/lab-ml/nn)\n",
|
||||
"[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb) \n",
|
||||
"[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb) \n",
|
||||
"\n",
|
||||
"## Fast Weights Transformer\n",
|
||||
"## Proximal Policy Optimization - PPO\n",
|
||||
"\n",
|
||||
"This is an experiment training Shakespeare dataset with a Compressive Transformer model."
|
||||
"This is an experiment training an agent to play Atari Breakout game using Proximal Policy Optimization - PPO"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -39,48 +25,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZCzmCrAIVg0L",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "ZCzmCrAIVg0L",
|
||||
"outputId": "028e759e-0c9f-472e-b4b8-fdcf3e4604ee"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install labml-nn"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Collecting labml-nn\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a4/07/d33ead6f84fad2a4e8ff31ccd42864ff7b942785ad9f80d7c98df1c20a02/labml_nn-0.4.94-py3-none-any.whl (171kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 174kB 16.4MB/s \n",
|
||||
"\u001b[?25hCollecting labml>=0.4.110\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/eb/c8/98b18d0dda3811998838734f9a32e944397fdd6bb0597cef0ae2b57338e3/labml-0.4.110-py3-none-any.whl (106kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 112kB 34.7MB/s \n",
|
||||
"\u001b[?25hCollecting labml-helpers>=0.4.76\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/49/df/4d920a4a221acd3cfa384dddb909ed0691b08682c0d8aeaabeee2138624f/labml_helpers-0.4.76-py3-none-any.whl\n",
|
||||
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.19.5)\n",
|
||||
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from labml-nn) (1.8.0+cu101)\n",
|
||||
"Collecting einops\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
|
||||
"Collecting gitpython\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 163kB 47.1MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from labml>=0.4.110->labml-nn) (3.13)\n",
|
||||
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
|
||||
"Collecting gitdb<5,>=4.0.1\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 71kB 11.1MB/s \n",
|
||||
"\u001b[?25hCollecting smmap<5,>=3.0.1\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n",
|
||||
"Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers, einops, labml-nn\n",
|
||||
"Successfully installed einops-0.3.0 gitdb-4.0.7 gitpython-3.1.14 labml-0.4.110 labml-helpers-0.4.76 labml-nn-0.4.94 smmap-4.0.0\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -94,16 +49,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0hJXx_g0wS2C"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from labml import experiment\n",
|
||||
"from labml.configs import FloatDynamicHyperParam\n",
|
||||
"from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam\n",
|
||||
"from labml_nn.rl.ppo.experiment import Trainer"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -116,14 +71,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "bFcr9k-l4cAg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"experiment.create(name=\"ppo\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -136,15 +91,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Piz0c5f44hRo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"configs = {\n",
|
||||
" # number of updates\n",
|
||||
" 'updates': 10000,\n",
|
||||
" # number of epochs to train the model with sampled data\n",
|
||||
" 'epochs': 4,\n",
|
||||
" 'epochs': IntDynamicHyperParam(8),\n",
|
||||
" # number of worker processes\n",
|
||||
" 'n_workers': 8,\n",
|
||||
" # number of steps to run on each process for a single update\n",
|
||||
@ -160,9 +117,7 @@
|
||||
" # Learning rate\n",
|
||||
" 'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n",
|
||||
"}"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -175,6 +130,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -183,25 +139,9 @@
|
||||
"id": "e6hmQhTw4nks",
|
||||
"outputId": "0e978879-5dcd-4140-ec53-24a3fbd547de"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"experiment.configs(configs)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"overflow-x: scroll;\"></pre>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -215,9 +155,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "8LB7XVViYuPG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = Trainer(\n",
|
||||
" updates=configs['updates'],\n",
|
||||
@ -230,9 +172,7 @@
|
||||
" clip_range=configs['clip_range'],\n",
|
||||
" learning_rate=configs['learning_rate'],\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -245,26 +185,42 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "aIAWo7Fw5DR8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with experiment.start():\n",
|
||||
" trainer.run_training_loop()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "oBXXlP2b7XZO"
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "Proximal Policy Optimization - PPO",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
@ -19,7 +19,7 @@ from torch import optim
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from labml import monit, tracker, logger, experiment
|
||||
from labml.configs import FloatDynamicHyperParam
|
||||
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
|
||||
from labml_helpers.module import Module
|
||||
from labml_nn.rl.game import Worker
|
||||
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
|
||||
@ -91,7 +91,8 @@ class Trainer:
|
||||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
updates: int, epochs: int, n_workers: int, worker_steps: int, batches: int,
|
||||
updates: int, epochs: IntDynamicHyperParam,
|
||||
n_workers: int, worker_steps: int, batches: int,
|
||||
value_loss_coef: FloatDynamicHyperParam,
|
||||
entropy_bonus_coef: FloatDynamicHyperParam,
|
||||
clip_range: FloatDynamicHyperParam,
|
||||
@ -231,7 +232,7 @@ class Trainer:
|
||||
# the average episode reward does not monotonically increase
|
||||
# over time.
|
||||
# May be reducing the clipping range might solve it.
|
||||
for _ in range(self.epochs):
|
||||
for _ in range(self.epochs()):
|
||||
# shuffle for each epoch
|
||||
indexes = torch.randperm(self.batch_size)
|
||||
|
||||
@ -356,7 +357,7 @@ def main():
|
||||
# number of updates
|
||||
'updates': 10000,
|
||||
# number of epochs to train the model with sampled data
|
||||
'epochs': 4,
|
||||
'epochs': IntDynamicHyperParam(8),
|
||||
# number of worker processes
|
||||
'n_workers': 8,
|
||||
# number of steps to run on each process for a single update
|
||||
@ -370,7 +371,7 @@ def main():
|
||||
# Clip range
|
||||
'clip_range': FloatDynamicHyperParam(0.1),
|
||||
# Learning rate
|
||||
'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
|
||||
'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
|
||||
}
|
||||
|
||||
experiment.configs(configs)
|
||||
|
||||
Reference in New Issue
Block a user