{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "AYV_dMVDxyc2" }, "source": [ "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb) \n", "\n", "## Deep Q Networks (DQN)\n", "\n", "This experiment trains an agent to play the Atari Breakout game using Deep Q Networks (DQN)." ] }, { "cell_type": "markdown", "metadata": { "id": "AahG_i2y5tY9" }, "source": [ "Install the `labml-nn` package" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZCzmCrAIVg0L", "outputId": "6c416266-1e99-4e60-a665-06ff9fba22a6" }, "outputs": [], "source": [ "!pip install labml-nn" ] }, { "cell_type": "markdown", "metadata": { "id": "3-G5kplRFmsO" }, "source": [ "Add the Atari ROMs (the experiment doesn't work in Google Colab without them)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SByhklD1FlSj", "outputId": "74075a5e-ec1c-43dc-8859-8f7c3b3b8402" }, "outputs": [], "source": [ "! wget http://www.atarimania.com/roms/Roms.rar\n", "! mkdir -p /content/ROM/\n", "! unrar e /content/Roms.rar /content/ROM/\n", "! 
python -m atari_py.import_roms /content/ROM/" ] }, { "cell_type": "markdown", "metadata": { "id": "SE2VUQ6L5zxI" }, "source": [ "Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hJXx_g0wS2C" }, "outputs": [], "source": [ "from labml import experiment\n", "from labml.configs import FloatDynamicHyperParam\n", "from labml_nn.rl.dqn.experiment import Trainer" ] }, { "cell_type": "markdown", "metadata": { "id": "Lpggo0wM6qb-" }, "source": [ "Create an experiment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bFcr9k-l4cAg" }, "outputs": [], "source": [ "experiment.create(name=\"dqn\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Hw6uVl1_GaPv" }, "source": [ "### Configurations\n", "\n", "`FloatDynamicHyperParam` is a dynamic hyper-parameter\n", "that you can change while the experiment is running." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "id": "L8bUtLD6GksC", "outputId": "c7d4efe7-490e-4153-e691-ca31df1e1275" }, "outputs": [], "source": [ "configs = {\n", " # Number of updates\n", " 'updates': 1_000_000,\n", " # Number of epochs to train the model with sampled data.\n", " 'epochs': 8,\n", " # Number of worker processes\n", " 'n_workers': 8,\n", " # Number of steps to run on each process for a single update\n", " 'worker_steps': 4,\n", " # Mini batch size\n", " 'mini_batch_size': 32,\n", " # Target model updating interval\n", " 'update_target_model': 250,\n", " # Learning rate.\n", " 'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set experiment configurations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "experiment.configs(configs)" ] }, { "cell_type": "markdown", "metadata": { "id": "qYQCFt_JYsjd" }, "source": [ "Create trainer" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "id": "8LB7XVViYuPG" }, "outputs": [], "source": [ "trainer = Trainer(**configs)" ] }, { "cell_type": "markdown", "metadata": { "id": "KJZRf8527GxL" }, "source": [ "Start the experiment and run the training loop." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 520 }, "id": "aIAWo7Fw5DR8", "outputId": "f2bca844-662d-4bfb-a295-d8529f538eaa" }, "outputs": [], "source": [ "with experiment.start():\n", " trainer.run_training_loop()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "Deep Q Networks (DQN)", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }