mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 06:16:05 +08:00 
			
		
		
		
	rl colab notebooks
This commit is contained in:
		
							
								
								
									
										243
									
								
								labml_nn/rl/dqn/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										243
									
								
								labml_nn/rl/dqn/experiment.ipynb
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,243 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					 "cells": [
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "AYV_dMVDxyc2"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "[](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
 | 
				
			||||||
 | 
					    "[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb)                    \n",
 | 
				
			||||||
 | 
					    "\n",
 | 
				
			||||||
 | 
					    "## Deep Q Networks (DQN)\n",
 | 
				
			||||||
 | 
					    "\n",
 | 
				
			||||||
 | 
					    "This is an experiment training an agent to play Atari Breakout game using Deep Q Networks (DQN)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "AahG_i2y5tY9"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Install the `labml-nn` package"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "colab": {
 | 
				
			||||||
 | 
					     "base_uri": "https://localhost:8080/"
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "id": "ZCzmCrAIVg0L",
 | 
				
			||||||
 | 
					    "outputId": "6c416266-1e99-4e60-a665-06ff9fba22a6"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "!pip install labml-nn"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "3-G5kplRFmsO"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Add Atari ROMs (Doesn't work without this in Google Colab)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "colab": {
 | 
				
			||||||
 | 
					     "base_uri": "https://localhost:8080/"
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "id": "SByhklD1FlSj",
 | 
				
			||||||
 | 
					    "outputId": "74075a5e-ec1c-43dc-8859-8f7c3b3b8402"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "! wget http://www.atarimania.com/roms/Roms.rar\n",
 | 
				
			||||||
 | 
					    "! mkdir /content/ROM/\n",
 | 
				
			||||||
 | 
					    "! unrar e /content/Roms.rar /content/ROM/\n",
 | 
				
			||||||
 | 
					    "! python -m atari_py.import_roms /content/ROM/"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "SE2VUQ6L5zxI"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Imports"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "0hJXx_g0wS2C"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "from labml import experiment\n",
 | 
				
			||||||
 | 
					    "from labml.configs import FloatDynamicHyperParam\n",
 | 
				
			||||||
 | 
					    "from labml_nn.rl.dqn.experiment import Trainer"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "Lpggo0wM6qb-"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Create an experiment"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "bFcr9k-l4cAg"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "experiment.create(name=\"dqn\")"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "Hw6uVl1_GaPv"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "### Configurations\n",
 | 
				
			||||||
 | 
					    "\n",
 | 
				
			||||||
 | 
					    "`FloatDynamicHyperParam` is a dynamic hyper-parameter\n",
 | 
				
			||||||
 | 
					    "that you can change while the experiment is running."
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "colab": {
 | 
				
			||||||
 | 
					     "base_uri": "https://localhost:8080/",
 | 
				
			||||||
 | 
					     "height": 17
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "id": "L8bUtLD6GksC",
 | 
				
			||||||
 | 
					    "outputId": "c7d4efe7-490e-4153-e691-ca31df1e1275"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "configs = {\n",
 | 
				
			||||||
 | 
					    "    # Number of updates\n",
 | 
				
			||||||
 | 
					    "    'updates': 1_000_000,\n",
 | 
				
			||||||
 | 
					    "    # Number of epochs to train the model with sampled data.\n",
 | 
				
			||||||
 | 
					    "    'epochs': 8,\n",
 | 
				
			||||||
 | 
					    "    # Number of worker processes\n",
 | 
				
			||||||
 | 
					    "    'n_workers': 8,\n",
 | 
				
			||||||
 | 
					    "    # Number of steps to run on each process for a single update\n",
 | 
				
			||||||
 | 
					    "    'worker_steps': 4,\n",
 | 
				
			||||||
 | 
					    "    # Mini batch size\n",
 | 
				
			||||||
 | 
					    "    'mini_batch_size': 32,\n",
 | 
				
			||||||
 | 
					    "    # Target model updating interval\n",
 | 
				
			||||||
 | 
					    "    'update_target_model': 250,\n",
 | 
				
			||||||
 | 
					    "    # Learning rate.\n",
 | 
				
			||||||
 | 
					    "    'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),\n",
 | 
				
			||||||
 | 
					    "}"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Set experiment configurations"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "experiment.configs(configs)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "qYQCFt_JYsjd"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Create trainer"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "8LB7XVViYuPG"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "trainer = Trainer(**configs)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "id": "KJZRf8527GxL"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Start the experiment and run the training loop."
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {
 | 
				
			||||||
 | 
					    "colab": {
 | 
				
			||||||
 | 
					     "base_uri": "https://localhost:8080/",
 | 
				
			||||||
 | 
					     "height": 520
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "id": "aIAWo7Fw5DR8",
 | 
				
			||||||
 | 
					    "outputId": "f2bca844-662d-4bfb-a295-d8529f538eaa"
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "with experiment.start():\n",
 | 
				
			||||||
 | 
					    "    trainer.run_training_loop()"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					 ],
 | 
				
			||||||
 | 
					 "metadata": {
 | 
				
			||||||
 | 
					  "accelerator": "GPU",
 | 
				
			||||||
 | 
					  "colab": {
 | 
				
			||||||
 | 
					   "collapsed_sections": [],
 | 
				
			||||||
 | 
					   "name": "Deep Q Networks (DQN)",
 | 
				
			||||||
 | 
					   "provenance": []
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "kernelspec": {
 | 
				
			||||||
 | 
					   "display_name": "Python 3",
 | 
				
			||||||
 | 
					   "language": "python",
 | 
				
			||||||
 | 
					   "name": "python3"
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "language_info": {
 | 
				
			||||||
 | 
					   "codemirror_mode": {
 | 
				
			||||||
 | 
					    "name": "ipython",
 | 
				
			||||||
 | 
					    "version": 3
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "file_extension": ".py",
 | 
				
			||||||
 | 
					   "mimetype": "text/x-python",
 | 
				
			||||||
 | 
					   "name": "python",
 | 
				
			||||||
 | 
					   "nbconvert_exporter": "python",
 | 
				
			||||||
 | 
					   "pygments_lexer": "ipython3",
 | 
				
			||||||
 | 
					   "version": "3.7.5"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					 },
 | 
				
			||||||
 | 
					 "nbformat": 4,
 | 
				
			||||||
 | 
					 "nbformat_minor": 4
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -38,6 +38,25 @@
 | 
				
			|||||||
    "!pip install labml-nn"
 | 
					    "!pip install labml-nn"
 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "Add Atari ROMs (Doesn't work without this in Google Colab)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "! wget http://www.atarimania.com/roms/Roms.rar\n",
 | 
				
			||||||
 | 
					    "! mkdir /content/ROM/\n",
 | 
				
			||||||
 | 
					    "! unrar e /content/Roms.rar /content/ROM/\n",
 | 
				
			||||||
 | 
					    "! python -m atari_py.import_roms /content/ROM/"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
   "cell_type": "markdown",
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
   "metadata": {
 | 
					   "metadata": {
 | 
				
			||||||
@ -164,17 +183,7 @@
 | 
				
			|||||||
   },
 | 
					   },
 | 
				
			||||||
   "outputs": [],
 | 
					   "outputs": [],
 | 
				
			||||||
   "source": [
 | 
					   "source": [
 | 
				
			||||||
    "trainer = Trainer(\n",
 | 
					    "trainer = Trainer(**configs)"
 | 
				
			||||||
    "    updates=configs['updates'],\n",
 | 
					 | 
				
			||||||
    "    epochs=configs['epochs'],\n",
 | 
					 | 
				
			||||||
    "    n_workers=configs['n_workers'],\n",
 | 
					 | 
				
			||||||
    "    worker_steps=configs['worker_steps'],\n",
 | 
					 | 
				
			||||||
    "    batches=configs['batches'],\n",
 | 
					 | 
				
			||||||
    "    value_loss_coef=configs['value_loss_coef'],\n",
 | 
					 | 
				
			||||||
    "    entropy_bonus_coef=configs['entropy_bonus_coef'],\n",
 | 
					 | 
				
			||||||
    "    clip_range=configs['clip_range'],\n",
 | 
					 | 
				
			||||||
    "    learning_rate=configs['learning_rate'],\n",
 | 
					 | 
				
			||||||
    ")"
 | 
					 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
@ -221,7 +230,7 @@
 | 
				
			|||||||
   "name": "python",
 | 
					   "name": "python",
 | 
				
			||||||
   "nbconvert_exporter": "python",
 | 
					   "nbconvert_exporter": "python",
 | 
				
			||||||
   "pygments_lexer": "ipython3",
 | 
					   "pygments_lexer": "ipython3",
 | 
				
			||||||
   "version": "3.8.5"
 | 
					   "version": "3.7.5"
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 },
 | 
					 },
 | 
				
			||||||
 "nbformat": 4,
 | 
					 "nbformat": 4,
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user