diff --git a/tutorials/11 - Deep Q Network/ReplayMemory.ipynb b/tutorials/11 - Deep Q Network/ReplayMemory.ipynb
new file mode 100644
index 0000000..262b483
--- /dev/null
+++ b/tutorials/11 - Deep Q Network/ReplayMemory.ipynb
@@ -0,0 +1,359 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# PyTorch DQN Implementation\n",
+    "\n",
+    ""
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import gym\n", + "import random\n", + "import numpy as np\n", + "import torchvision.transforms as transforms\n", + "import matplotlib.pyplot as plt\n", + "from torch.autograd import Variable\n", + "from collections import deque, namedtuple" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2017-03-09 21:31:48,174] Making new env: CartPole-v0\n" + ] + } + ], + "source": [ + "env = gym.envs.make(\"CartPole-v0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " self.fc1 = nn.Linear(4, 128)\n", + " self.tanh = nn.Tanh()\n", + " self.fc2 = nn.Linear(128, 2)\n", + " self.init_weights()\n", + " \n", + " def init_weights(self):\n", + " self.fc1.weight.data.uniform_(-0.1, 0.1)\n", + " self.fc2.weight.data.uniform_(-0.1, 0.1)\n", + " \n", + " def forward(self, x):\n", + " out = self.fc1(x)\n", + " out = self.tanh(out)\n", + " out = self.fc2(out)\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def make_epsilon_greedy_policy(network, epsilon, nA):\n", + " def policy(state):\n", + " sample = random.random()\n", + " if sample < (1-epsilon) + (epsilon/nA):\n", + " q_values = network(state.view(1, -1))\n", + " action = q_values.data.max(1)[1][0, 0]\n", + " else:\n", + " action = random.randrange(nA)\n", + " return action\n", + " return policy" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "class ReplayMemory(object):\n", + " \n", + " def __init__(self, capacity):\n", + " self.memory = deque()\n", + " self.capacity = capacity\n", + " \n", + " def push(self, transition):\n", + " if len(self.memory) > self.capacity:\n", + " self.memory.popleft()\n", + " self.memory.append(transition)\n", + " \n", + " def sample(self, batch_size):\n", + " return random.sample(self.memory, batch_size)\n", + " \n", + " def __len__(self):\n", + " return len(self.memory)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def to_tensor(ndarray, volatile=False):\n", + " return Variable(torch.from_numpy(ndarray), volatile=volatile).float()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def deep_q_learning(num_episodes=10, batch_size=100, \n", + " discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):\n", + "\n", + " # Q-Network and memory \n", + " net = Net()\n", + " memory = ReplayMemory(10000)\n", + " \n", + " # Loss and Optimizer\n", + " criterion = nn.MSELoss()\n", + " optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n", + " \n", + " for i_episode in range(num_episodes):\n", + " \n", + " # Set policy (TODO: decaying epsilon)\n", + " #if (i_episode+1) % 100 == 0:\n", + " # epsilon *= 0.9\n", + " \n", + " policy = make_epsilon_greedy_policy(\n", + " net, epsilon, env.action_space.n)\n", + " \n", + " # Start an episode\n", + " state 
+    "        \n",
+    "        for t in range(10000):\n",
+    "            \n",
+    "            # Sample an action from the epsilon-greedy policy\n",
+    "            action = policy(to_tensor(state))\n",
+    "            next_state, reward, done, _ = env.step(action)\n",
+    "            \n",
+    "            # Store the transition in replay memory\n",
+    "            memory.push([state, action, reward, next_state])\n",
+    "            \n",
+    "            if len(memory) >= batch_size:\n",
+    "                # Sample a mini-batch of transitions from memory\n",
+    "                batch = memory.sample(batch_size)\n",
+    "                state_batch = np.vstack([trans[0] for trans in batch])\n",
+    "                action_batch = np.vstack([trans[1] for trans in batch])\n",
+    "                reward_batch = np.vstack([trans[2] for trans in batch])\n",
+    "                next_state_batch = np.vstack([trans[3] for trans in batch])\n",
+    "                \n",
+    "                # Forward + backward + optimize\n",
+    "                net.zero_grad()\n",
+    "                q_values = net(to_tensor(state_batch))\n",
+    "                next_q_values = net(to_tensor(next_state_batch, volatile=True))\n",
+    "                next_q_values.volatile = False\n",
+    "                \n",
+    "                # TODO: mask the bootstrap term for terminal transitions (store `done` in memory)\n",
+    "                td_target = to_tensor(reward_batch) + discount_factor * next_q_values.max(1)[0]\n",
+    "                loss = criterion(q_values.gather(1,\n",
+    "                    to_tensor(action_batch).long().view(-1, 1)), td_target)\n",
+    "                loss.backward()\n",
+    "                optimizer.step()\n",
+    "            \n",
+    "            if done:\n",
+    "                break\n",
+    "            \n",
+    "            state = next_state\n",
+    "        \n",
+    "        if len(memory) >= batch_size and (i_episode+1) % 10 == 0:\n",
+    "            print('episode: %d, time: %d, loss: %.4f' % (i_episode, t, loss.data[0]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "episode: 9, time: 9, loss: 0.9945\n",
+      "episode: 19, time: 9, loss: 1.8221\n",
+      "episode: 29, time: 9, loss: 4.3124\n",
+      "episode: 39, time: 8, loss: 6.9764\n",
+      "episode: 49, time: 9, loss: 6.8300\n",
+      "episode: 59, time: 8, loss: 5.5186\n",
+      "episode: 69, time: 9, loss: 4.1160\n",
+      "episode: 79, time: 9, loss: 2.4802\n",
+      "episode: 89, time: 13, loss: 0.7890\n",
+      "episode: 99, time: 10, loss: 0.2805\n",
+      "episode: 109, time: 12, loss: 0.1323\n",
+      "episode: 119, time: 13, loss: 0.0519\n",
+      "episode: 129, time: 18, loss: 0.0176\n",
+      "episode: 139, time: 22, loss: 0.0067\n",
+      "episode: 149, time: 17, loss: 0.0114\n",
+      "episode: 159, time: 26, loss: 0.0017\n",
+      "episode: 169, time: 23, loss: 0.0018\n",
+      "episode: 179, time: 21, loss: 0.0023\n",
+      "episode: 189, time: 11, loss: 0.0024\n",
+      "episode: 199, time: 7, loss: 0.0040\n",
+      "episode: 209, time: 8, loss: 0.0030\n",
+      "episode: 219, time: 7, loss: 0.0070\n",
+      "episode: 229, time: 9, loss: 0.0031\n",
+      "episode: 239, time: 9, loss: 0.0029\n",
+      "episode: 249, time: 8, loss: 0.0046\n",
+      "episode: 259, time: 8, loss: 0.0009\n",
+      "episode: 269, time: 10, loss: 0.0020\n",
+      "episode: 279, time: 9, loss: 0.0025\n",
+      "episode: 289, time: 8, loss: 0.0015\n",
+      "episode: 299, time: 10, loss: 0.0009\n",
+      "episode: 309, time: 8, loss: 0.0012\n",
+      "episode: 319, time: 8, loss: 0.0034\n",
+      "episode: 329, time: 8, loss: 0.0008\n",
+      "episode: 339, time: 9, loss: 0.0021\n",
+      "episode: 349, time: 8, loss: 0.0018\n",
+      "episode: 359, time: 9, loss: 0.0017\n",
+      "episode: 369, time: 9, loss: 0.0006\n",
+      "episode: 379, time: 9, loss: 0.0023\n",
+      "episode: 389, time: 10, loss: 0.0017\n",
+      "episode: 399, time: 8, loss: 0.0018\n",
+      "episode: 409, time: 8, loss: 0.0023\n",
+      "episode: 419, time: 9, loss: 0.0020\n",
+      "episode: 429, time: 9, loss: 0.0006\n",
+      "episode: 439, time: 10, loss: 0.0006\n",
+      "episode: 449, time: 10, loss: 0.0025\n",
+      "episode: 459, time: 9, loss: 0.0013\n",
+      "episode: 469, time: 8, loss: 0.0011\n",
+      "episode: 479, time: 8, loss: 0.0005\n",
+      "episode: 489, time: 8, loss: 0.0004\n",
+      "episode: 499, time: 7, loss: 0.0017\n",
+      "episode: 509, time: 7, loss: 0.0004\n",
+      "episode: 519, time: 10, loss: 0.0008\n",
+      "episode: 529, time: 11, loss: 0.0006\n",
+      "episode: 539, time: 9, loss: 0.0010\n",
+      "episode: 549, time: 8, loss: 0.0006\n",
+      "episode: 559, time: 8, loss: 0.0012\n",
+      "episode: 569, time: 9, loss: 0.0011\n",
+      "episode: 579, time: 8, loss: 0.0010\n",
+      "episode: 589, time: 8, loss: 0.0008\n",
+      "episode: 599, time: 10, loss: 0.0010\n",
+      "episode: 609, time: 8, loss: 0.0005\n",
+      "episode: 619, time: 9, loss: 0.0004\n",
+      "episode: 629, time: 8, loss: 0.0007\n",
+      "episode: 639, time: 10, loss: 0.0014\n",
+      "episode: 649, time: 10, loss: 0.0004\n",
+      "episode: 659, time: 9, loss: 0.0008\n",
+      "episode: 669, time: 8, loss: 0.0005\n",
+      "episode: 679, time: 8, loss: 0.0003\n",
+      "episode: 689, time: 9, loss: 0.0009\n",
+      "episode: 699, time: 8, loss: 0.0004\n",
+      "episode: 709, time: 8, loss: 0.0013\n",
+      "episode: 719, time: 8, loss: 0.0006\n",
+      "episode: 729, time: 7, loss: 0.0021\n",
+      "episode: 739, time: 9, loss: 0.0023\n",
+      "episode: 749, time: 9, loss: 0.0039\n",
+      "episode: 759, time: 8, loss: 0.0030\n",
+      "episode: 769, time: 9, loss: 0.0016\n",
+      "episode: 779, time: 7, loss: 0.0041\n",
+      "episode: 789, time: 8, loss: 0.0050\n",
+      "episode: 799, time: 8, loss: 0.0041\n",
+      "episode: 809, time: 11, loss: 0.0053\n",
+      "episode: 819, time: 7, loss: 0.0018\n",
+      "episode: 829, time: 9, loss: 0.0019\n",
+      "episode: 839, time: 11, loss: 0.0017\n",
+      "episode: 849, time: 8, loss: 0.0029\n",
+      "episode: 859, time: 9, loss: 0.0012\n",
+      "episode: 869, time: 9, loss: 0.0036\n",
+      "episode: 879, time: 7, loss: 0.0017\n",
+      "episode: 889, time: 9, loss: 0.0016\n",
+      "episode: 899, time: 10, loss: 0.0023\n",
+      "episode: 909, time: 8, loss: 0.0032\n",
+      "episode: 919, time: 8, loss: 0.0015\n",
+      "episode: 929, time: 9, loss: 0.0021\n",
+      "episode: 939, time: 9, loss: 0.0015\n",
+      "episode: 949, time: 9, loss: 0.0016\n",
+      "episode: 959, time: 9, loss: 0.0013\n",
+      "episode: 969, time: 12, loss: 0.0029\n",
+      "episode: 979, time: 7, loss: 0.0016\n",
+      "episode: 989, time: 7, loss: 0.0012\n",
+      "episode: 999, time: 9, loss: 0.0013\n"
+     ]
+    }
+   ],
+   "source": [
+    "deep_q_learning(1000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "anaconda-cloud": {},
+  "kernelspec": {
+   "display_name": "Python 2",
+   "language": "python",
+   "name": "python2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/tutorials/11 - Deep Q Network/Untitled.ipynb b/tutorials/11 - Deep Q Network/Untitled.ipynb
new file mode 100644
index 0000000..10b10f6
--- /dev/null
+++ b/tutorials/11 - Deep Q Network/Untitled.ipynb
@@ -0,0 +1,154 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "\n",
+    "import gym\n",
+    "import numpy as np\n",
+    "from matplotlib import pyplot as plt"
+   ]
+  },
+  {
+   
"cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2017-03-08 21:13:15,268] Making new env: Breakout-v0\n" + ] + }, + { + "ename": "DependencyNotInstalled", + "evalue": "No module named 'atari_py'. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/atari_env.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0matari_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mImportError\u001b[0m: No module named 'atari_py'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mDependencyNotInstalled\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Breakout-v0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Action space size: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action_meanings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mregistry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self, id)\u001b[0m\n\u001b[1;32m 
117\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Making new env: %s'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimestep_limit\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'vnc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrappers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime_limit\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTimeLimit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Attempting to make deprecated env {}. 
(HINT: is there a newer registered version of this env?)'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_entry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mentry_point\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpkg_resources\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEntryPoint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'x={}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mentry_point\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/setuptools-27.2.0-py3.5.egg/pkg_resources/__init__.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, require, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2256\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrequire\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2257\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2258\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresolve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2259\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2260\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mresolve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/setuptools-27.2.0-py3.5.egg/pkg_resources/__init__.py\u001b[0m in \u001b[0;36mresolve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 2262\u001b[0m \u001b[0mResolve\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mentry\u001b[0m \u001b[0mpoint\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mits\u001b[0m 
\u001b[0mmodule\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mattrs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2263\u001b[0m \"\"\"\n\u001b[0;32m-> 2264\u001b[0;31m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__import__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfromlist\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'__name__'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2265\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2266\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari_env\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAtariEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/atari_env.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0matari_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDependencyNotInstalled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"{}. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDependencyNotInstalled\u001b[0m: No module named 'atari_py'. 
(HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)" + ] + } + ], + "source": [ + "env = gym.envs.make(\"Breakout-v0\")\n", + "\n", + "print(\"Action space size: {}\".format(env.action_space.n))\n", + "print(env.get_action_meanings())\n", + "\n", + "observation = env.reset()\n", + "print(\"Observation space shape: {}\".format(observation.shape))\n", + "\n", + "plt.figure()\n", + "plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + "[env.step(2) for x in range(1)]\n", + "plt.figure()\n", + "plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + "env.render(close=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2017-03-08 21:12:44,474] Making new env: CartPole-v0\n" + ] + }, + { + "ename": "NameError", + "evalue": "name 'base' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mobservation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobservation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. 
(Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'human'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 288\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_close\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. 
(Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/classic_control/cartpole.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscreen_width\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscreen_height\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprefix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Error occured while running `from pyglet.gl import 
*`\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msuffix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \\\"-screen 0 1400x900x24\\\" python '\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/pyglet/gl/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcarbon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarbonConfig\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mConfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0;32mdel\u001b[0m \u001b[0mbase\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;31m# XXX remove\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'base' is not defined" + ] + } + ], + "source": [ + "import gym\n", + "env = gym.make('CartPole-v0')\n", + "for i_episode in range(20):\n", + " observation = env.reset()\n", + " for t in range(100):\n", + " env.render()\n", + " print(observation)\n", + " action = env.action_space.sample()\n", + " observation, reward, done, info = env.step(action)\n", + " if done:\n", + " print(\"Episode finished after {} timesteps\".format(t+1))\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python [conda root]", + "language": "python", + "name": "conda-root-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/tutorials/11 - Deep Q Network/dqn13.py b/tutorials/11 - Deep Q Network/dqn13.py new file mode 100644 index 0000000..442b609 --- /dev/null +++ b/tutorials/11 - Deep Q Network/dqn13.py @@ -0,0 +1,124 @@ +%matplotlib inline + +import torch +import torch.nn as nn +import gym +import random +import numpy as np +import torchvision.transforms as transforms +import matplotlib.pyplot as plt +from torch.autograd import Variable +from collections import deque, namedtuple + +env = gym.envs.make("CartPole-v0") + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(4, 128) + self.tanh = nn.Tanh() + self.fc2 = nn.Linear(128, 2) + self.init_weights() + + def init_weights(self): + self.fc1.weight.data.uniform_(-0.1, 0.1) + self.fc2.weight.data.uniform_(-0.1, 0.1) + + def forward(self, x): + out = self.fc1(x) + out = self.tanh(out) + out = self.fc2(out) + return out + +def make_epsilon_greedy_policy(network, epsilon, nA): + def policy(state): + sample = random.random() + if sample < (1-epsilon) + (epsilon/nA): + q_values = network(state.view(1, -1)) + action = q_values.data.max(1)[1][0, 0] + else: + action = random.randrange(nA) + return action + return policy + +class ReplayMemory(object): + + def __init__(self, 
capacity):
+        self.memory = deque()
+        self.capacity = capacity
+
+    def push(self, transition):
+        # Evict the oldest transition once the buffer is full.
+        if len(self.memory) >= self.capacity:
+            self.memory.popleft()
+        self.memory.append(transition)
+
+    def sample(self, batch_size):
+        return random.sample(self.memory, batch_size)
+
+    def __len__(self):
+        return len(self.memory)
+
+def to_tensor(ndarray, volatile=False):
+    return Variable(torch.from_numpy(ndarray), volatile=volatile).float()
+
+def deep_q_learning(num_episodes=10, batch_size=100,
+                    discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):
+
+    # Q-Network and replay memory
+    net = Net()
+    memory = ReplayMemory(10000)
+
+    # Loss and optimizer
+    criterion = nn.MSELoss()
+    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
+
+    for i_episode in range(num_episodes):
+
+        # Set policy (TODO: decaying epsilon)
+        #if (i_episode+1) % 100 == 0:
+        #    epsilon *= epsilon_decay
+
+        policy = make_epsilon_greedy_policy(
+            net, epsilon, env.action_space.n)
+
+        # Start an episode
+        state = env.reset()
+
+        for t in range(10000):
+
+            # Sample an action from the epsilon-greedy policy
+            action = policy(to_tensor(state))
+            next_state, reward, done, _ = env.step(action)
+
+            # Store the transition in replay memory
+            memory.push([state, action, reward, next_state])
+
+            if len(memory) >= batch_size:
+                # Sample a mini-batch of transitions from memory
+                batch = memory.sample(batch_size)
+                state_batch = np.vstack([trans[0] for trans in batch])
+                action_batch = np.vstack([trans[1] for trans in batch])
+                reward_batch = np.vstack([trans[2] for trans in batch])
+                next_state_batch = np.vstack([trans[3] for trans in batch])
+
+                # Forward + backward + optimize
+                net.zero_grad()
+                q_values = net(to_tensor(state_batch))
+                next_q_values = net(to_tensor(next_state_batch, volatile=True))
+                next_q_values.volatile = False
+
+                # TODO: mask the bootstrap term for terminal transitions (store `done` in memory)
+                td_target = to_tensor(reward_batch) + discount_factor * next_q_values.max(1)[0]
+                loss = criterion(q_values.gather(1,
+                    to_tensor(action_batch).long().view(-1, 1)), td_target)
+                loss.backward()
+                optimizer.step()
+
+            if done:
+                break
+
+            state = next_state
+
+        if len(memory) >= batch_size and (i_episode+1) % 10 == 0:
+            print('episode: %d, time: %d, loss: %.4f' % (i_episode, t, loss.data[0]))
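+
+# A minimal entry point so the script also trains when run directly -- a sketch
+# mirroring the notebook's `deep_q_learning(1000)` call; the episode count of
+# 1000 is carried over from that run and is otherwise an arbitrary choice.
+if __name__ == '__main__':
+    deep_q_learning(num_episodes=1000)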