From 8d1be06af5f127d376463da6596857cfc0e417a3 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Sat, 2 Oct 2021 14:01:42 +0530 Subject: [PATCH] rl colab notebooks --- labml_nn/rl/dqn/experiment.ipynb | 243 +++++++++++++++++++++++++++++++ labml_nn/rl/ppo/experiment.ipynb | 35 +++-- 2 files changed, 265 insertions(+), 13 deletions(-) create mode 100644 labml_nn/rl/dqn/experiment.ipynb diff --git a/labml_nn/rl/dqn/experiment.ipynb b/labml_nn/rl/dqn/experiment.ipynb new file mode 100644 index 00000000..291080e9 --- /dev/null +++ b/labml_nn/rl/dqn/experiment.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AYV_dMVDxyc2" + }, + "source": [ + "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb) \n", + "\n", + "## Deep Q Networks (DQN)\n", + "\n", + "This is an experiment training an agent to play Atari Breakout game using Deep Q Networks (DQN)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AahG_i2y5tY9" + }, + "source": [ + "Install the `labml-nn` package" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZCzmCrAIVg0L", + "outputId": "6c416266-1e99-4e60-a665-06ff9fba22a6" + }, + "outputs": [], + "source": [ + "!pip install labml-nn" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3-G5kplRFmsO" + }, + "source": [ + "Add Atari ROMs (Doesn't work without this in Google Colab)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SByhklD1FlSj", + "outputId": "74075a5e-ec1c-43dc-8859-8f7c3b3b8402" + }, + "outputs": [], + "source": [ + "! wget http://www.atarimania.com/roms/Roms.rar\n", + "! mkdir /content/ROM/\n", + "! unrar e /content/Roms.rar /content/ROM/\n", + "! python -m atari_py.import_roms /content/ROM/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SE2VUQ6L5zxI" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0hJXx_g0wS2C" + }, + "outputs": [], + "source": [ + "from labml import experiment\n", + "from labml.configs import FloatDynamicHyperParam\n", + "from labml_nn.rl.dqn.experiment import Trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lpggo0wM6qb-" + }, + "source": [ + "Create an experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bFcr9k-l4cAg" + }, + "outputs": [], + "source": [ + "experiment.create(name=\"dqn\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hw6uVl1_GaPv" + }, + "source": [ + "### Configurations\n", + "\n", + "`FloatDynamicHyperParam` is a dynamic hyper-parameter\n", + "that you can change while the experiment is running." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "L8bUtLD6GksC", + "outputId": "c7d4efe7-490e-4153-e691-ca31df1e1275" + }, + "outputs": [], + "source": [ + "configs = {\n", + " # Number of updates\n", + " 'updates': 1_000_000,\n", + " # Number of epochs to train the model with sampled data.\n", + " 'epochs': 8,\n", + " # Number of worker processes\n", + " 'n_workers': 8,\n", + " # Number of steps to run on each process for a single update\n", + " 'worker_steps': 4,\n", + " # Mini batch size\n", + " 'mini_batch_size': 32,\n", + " # Target model updating interval\n", + " 'update_target_model': 250,\n", + " # Learning rate.\n", + " 'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set experiment configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "experiment.configs(configs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qYQCFt_JYsjd" + }, + "source": [ + "Create trainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8LB7XVViYuPG" + }, + "outputs": [], + "source": [ + "trainer = Trainer(**configs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KJZRf8527GxL" + }, + "source": [ + "Start the experiment and run the training loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 520 + }, + "id": "aIAWo7Fw5DR8", + "outputId": "f2bca844-662d-4bfb-a295-d8529f538eaa" + }, + "outputs": [], + "source": [ + "with experiment.start():\n", + " trainer.run_training_loop()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Deep Q Networks (DQN)", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/labml_nn/rl/ppo/experiment.ipynb b/labml_nn/rl/ppo/experiment.ipynb index c185a1fe..fcd1ecda 100644 --- a/labml_nn/rl/ppo/experiment.ipynb +++ b/labml_nn/rl/ppo/experiment.ipynb @@ -38,6 +38,25 @@ "!pip install labml-nn" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add Atari ROMs (Doesn't work without this in Google Colab)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! wget http://www.atarimania.com/roms/Roms.rar\n", + "! mkdir /content/ROM/\n", + "! unrar e /content/Roms.rar /content/ROM/\n", + "! python -m atari_py.import_roms /content/ROM/" + ] + }, { "cell_type": "markdown", "metadata": { @@ -164,17 +183,7 @@ }, "outputs": [], "source": [ - "trainer = Trainer(\n", - " updates=configs['updates'],\n", - " epochs=configs['epochs'],\n", - " n_workers=configs['n_workers'],\n", - " worker_steps=configs['worker_steps'],\n", - " batches=configs['batches'],\n", - " value_loss_coef=configs['value_loss_coef'],\n", - " entropy_bonus_coef=configs['entropy_bonus_coef'],\n", - " clip_range=configs['clip_range'],\n", - " learning_rate=configs['learning_rate'],\n", - ")" + "trainer = Trainer(**configs)" ] }, { @@ -221,9 +230,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +}