diff --git a/tutorials/11 - Deep Q Network/ReplayMemory.ipynb b/tutorials/11 - Deep Q Network/ReplayMemory.ipynb
new file mode 100644
index 0000000..262b483
--- /dev/null
+++ b/tutorials/11 - Deep Q Network/ReplayMemory.ipynb
@@ -0,0 +1,359 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PyTorch DQN Implemenation\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import gym\n",
+ "import random\n",
+ "import numpy as np\n",
+ "import torchvision.transforms as transforms\n",
+ "import matplotlib.pyplot as plt\n",
+ "from torch.autograd import Variable\n",
+ "from collections import deque, namedtuple"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2017-03-09 21:31:48,174] Making new env: CartPole-v0\n"
+ ]
+ }
+ ],
+ "source": [
+ "env = gym.envs.make(\"CartPole-v0\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "class Net(nn.Module):\n",
+ " def __init__(self):\n",
+ " super(Net, self).__init__()\n",
+ " self.fc1 = nn.Linear(4, 128)\n",
+ " self.tanh = nn.Tanh()\n",
+ " self.fc2 = nn.Linear(128, 2)\n",
+ " self.init_weights()\n",
+ " \n",
+ " def init_weights(self):\n",
+ " self.fc1.weight.data.uniform_(-0.1, 0.1)\n",
+ " self.fc2.weight.data.uniform_(-0.1, 0.1)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " out = self.fc1(x)\n",
+ " out = self.tanh(out)\n",
+ " out = self.fc2(out)\n",
+ " return out"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def make_epsilon_greedy_policy(network, epsilon, nA):\n",
+ " def policy(state):\n",
+ " sample = random.random()\n",
+ " if sample < (1-epsilon) + (epsilon/nA):\n",
+ " q_values = network(state.view(1, -1))\n",
+ " action = q_values.data.max(1)[1][0, 0]\n",
+ " else:\n",
+ " action = random.randrange(nA)\n",
+ " return action\n",
+ " return policy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "class ReplayMemory(object):\n",
+ " \n",
+ " def __init__(self, capacity):\n",
+ " self.memory = deque()\n",
+ " self.capacity = capacity\n",
+ " \n",
+ " def push(self, transition):\n",
+ " if len(self.memory) > self.capacity:\n",
+ " self.memory.popleft()\n",
+ " self.memory.append(transition)\n",
+ " \n",
+ " def sample(self, batch_size):\n",
+ " return random.sample(self.memory, batch_size)\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.memory)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "def to_tensor(ndarray, volatile=False):\n",
+ " return Variable(torch.from_numpy(ndarray), volatile=volatile).float()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def deep_q_learning(num_episodes=10, batch_size=100, \n",
+ " discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):\n",
+ "\n",
+ " # Q-Network and memory \n",
+ " net = Net()\n",
+ " memory = ReplayMemory(10000)\n",
+ " \n",
+ " # Loss and Optimizer\n",
+ " criterion = nn.MSELoss()\n",
+ " optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
+ " \n",
+ " for i_episode in range(num_episodes):\n",
+ " \n",
+ " # Set policy (TODO: decaying epsilon)\n",
+ " #if (i_episode+1) % 100 == 0:\n",
+ " # epsilon *= 0.9\n",
+ " \n",
+ " policy = make_epsilon_greedy_policy(\n",
+ " net, epsilon, env.action_space.n)\n",
+ " \n",
+ " # Start an episode\n",
+ " state = env.reset()\n",
+ " \n",
+ " for t in range(10000):\n",
+ " \n",
+ " # Sample action from epsilon greed policy\n",
+ " action = policy(to_tensor(state)) \n",
+ " next_state, reward, done, _ = env.step(action)\n",
+ " \n",
+ " \n",
+ " # Restore transition in memory\n",
+ " memory.push([state, action, reward, next_state])\n",
+ " \n",
+ " \n",
+ " if len(memory) >= batch_size:\n",
+ " # Sample mini-batch transitions from memory\n",
+ " batch = memory.sample(batch_size)\n",
+ " state_batch = np.vstack([trans[0] for trans in batch])\n",
+ " action_batch =np.vstack([trans[1] for trans in batch]) \n",
+ " reward_batch = np.vstack([trans[2] for trans in batch])\n",
+ " next_state_batch = np.vstack([trans[3] for trans in batch])\n",
+ " \n",
+ " # Forward + Backward + Opimize\n",
+ " net.zero_grad()\n",
+ " q_values = net(to_tensor(state_batch))\n",
+ " next_q_values = net(to_tensor(next_state_batch, volatile=True))\n",
+ " next_q_values.volatile = False\n",
+ " \n",
+ " td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]\n",
+ " loss = criterion(q_values.gather(1, \n",
+ " to_tensor(action_batch).long().view(-1, 1)), td_target)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " \n",
+ " if done:\n",
+ " break\n",
+ " \n",
+ " state = next_state\n",
+ " \n",
+ " if len(memory) >= batch_size and (i_episode+1) % 10 == 0:\n",
+ " print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "episode: 9, time: 9, loss: 0.9945\n",
+ "episode: 19, time: 9, loss: 1.8221\n",
+ "episode: 29, time: 9, loss: 4.3124\n",
+ "episode: 39, time: 8, loss: 6.9764\n",
+ "episode: 49, time: 9, loss: 6.8300\n",
+ "episode: 59, time: 8, loss: 5.5186\n",
+ "episode: 69, time: 9, loss: 4.1160\n",
+ "episode: 79, time: 9, loss: 2.4802\n",
+ "episode: 89, time: 13, loss: 0.7890\n",
+ "episode: 99, time: 10, loss: 0.2805\n",
+ "episode: 109, time: 12, loss: 0.1323\n",
+ "episode: 119, time: 13, loss: 0.0519\n",
+ "episode: 129, time: 18, loss: 0.0176\n",
+ "episode: 139, time: 22, loss: 0.0067\n",
+ "episode: 149, time: 17, loss: 0.0114\n",
+ "episode: 159, time: 26, loss: 0.0017\n",
+ "episode: 169, time: 23, loss: 0.0018\n",
+ "episode: 179, time: 21, loss: 0.0023\n",
+ "episode: 189, time: 11, loss: 0.0024\n",
+ "episode: 199, time: 7, loss: 0.0040\n",
+ "episode: 209, time: 8, loss: 0.0030\n",
+ "episode: 219, time: 7, loss: 0.0070\n",
+ "episode: 229, time: 9, loss: 0.0031\n",
+ "episode: 239, time: 9, loss: 0.0029\n",
+ "episode: 249, time: 8, loss: 0.0046\n",
+ "episode: 259, time: 8, loss: 0.0009\n",
+ "episode: 269, time: 10, loss: 0.0020\n",
+ "episode: 279, time: 9, loss: 0.0025\n",
+ "episode: 289, time: 8, loss: 0.0015\n",
+ "episode: 299, time: 10, loss: 0.0009\n",
+ "episode: 309, time: 8, loss: 0.0012\n",
+ "episode: 319, time: 8, loss: 0.0034\n",
+ "episode: 329, time: 8, loss: 0.0008\n",
+ "episode: 339, time: 9, loss: 0.0021\n",
+ "episode: 349, time: 8, loss: 0.0018\n",
+ "episode: 359, time: 9, loss: 0.0017\n",
+ "episode: 369, time: 9, loss: 0.0006\n",
+ "episode: 379, time: 9, loss: 0.0023\n",
+ "episode: 389, time: 10, loss: 0.0017\n",
+ "episode: 399, time: 8, loss: 0.0018\n",
+ "episode: 409, time: 8, loss: 0.0023\n",
+ "episode: 419, time: 9, loss: 0.0020\n",
+ "episode: 429, time: 9, loss: 0.0006\n",
+ "episode: 439, time: 10, loss: 0.0006\n",
+ "episode: 449, time: 10, loss: 0.0025\n",
+ "episode: 459, time: 9, loss: 0.0013\n",
+ "episode: 469, time: 8, loss: 0.0011\n",
+ "episode: 479, time: 8, loss: 0.0005\n",
+ "episode: 489, time: 8, loss: 0.0004\n",
+ "episode: 499, time: 7, loss: 0.0017\n",
+ "episode: 509, time: 7, loss: 0.0004\n",
+ "episode: 519, time: 10, loss: 0.0008\n",
+ "episode: 529, time: 11, loss: 0.0006\n",
+ "episode: 539, time: 9, loss: 0.0010\n",
+ "episode: 549, time: 8, loss: 0.0006\n",
+ "episode: 559, time: 8, loss: 0.0012\n",
+ "episode: 569, time: 9, loss: 0.0011\n",
+ "episode: 579, time: 8, loss: 0.0010\n",
+ "episode: 589, time: 8, loss: 0.0008\n",
+ "episode: 599, time: 10, loss: 0.0010\n",
+ "episode: 609, time: 8, loss: 0.0005\n",
+ "episode: 619, time: 9, loss: 0.0004\n",
+ "episode: 629, time: 8, loss: 0.0007\n",
+ "episode: 639, time: 10, loss: 0.0014\n",
+ "episode: 649, time: 10, loss: 0.0004\n",
+ "episode: 659, time: 9, loss: 0.0008\n",
+ "episode: 669, time: 8, loss: 0.0005\n",
+ "episode: 679, time: 8, loss: 0.0003\n",
+ "episode: 689, time: 9, loss: 0.0009\n",
+ "episode: 699, time: 8, loss: 0.0004\n",
+ "episode: 709, time: 8, loss: 0.0013\n",
+ "episode: 719, time: 8, loss: 0.0006\n",
+ "episode: 729, time: 7, loss: 0.0021\n",
+ "episode: 739, time: 9, loss: 0.0023\n",
+ "episode: 749, time: 9, loss: 0.0039\n",
+ "episode: 759, time: 8, loss: 0.0030\n",
+ "episode: 769, time: 9, loss: 0.0016\n",
+ "episode: 779, time: 7, loss: 0.0041\n",
+ "episode: 789, time: 8, loss: 0.0050\n",
+ "episode: 799, time: 8, loss: 0.0041\n",
+ "episode: 809, time: 11, loss: 0.0053\n",
+ "episode: 819, time: 7, loss: 0.0018\n",
+ "episode: 829, time: 9, loss: 0.0019\n",
+ "episode: 839, time: 11, loss: 0.0017\n",
+ "episode: 849, time: 8, loss: 0.0029\n",
+ "episode: 859, time: 9, loss: 0.0012\n",
+ "episode: 869, time: 9, loss: 0.0036\n",
+ "episode: 879, time: 7, loss: 0.0017\n",
+ "episode: 889, time: 9, loss: 0.0016\n",
+ "episode: 899, time: 10, loss: 0.0023\n",
+ "episode: 909, time: 8, loss: 0.0032\n",
+ "episode: 919, time: 8, loss: 0.0015\n",
+ "episode: 929, time: 9, loss: 0.0021\n",
+ "episode: 939, time: 9, loss: 0.0015\n",
+ "episode: 949, time: 9, loss: 0.0016\n",
+ "episode: 959, time: 9, loss: 0.0013\n",
+ "episode: 969, time: 12, loss: 0.0029\n",
+ "episode: 979, time: 7, loss: 0.0016\n",
+ "episode: 989, time: 7, loss: 0.0012\n",
+ "episode: 999, time: 9, loss: 0.0013\n"
+ ]
+ }
+ ],
+ "source": [
+ "deep_q_learning(1000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "anaconda-cloud": {},
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/tutorials/11 - Deep Q Network/Untitled.ipynb b/tutorials/11 - Deep Q Network/Untitled.ipynb
new file mode 100644
index 0000000..10b10f6
--- /dev/null
+++ b/tutorials/11 - Deep Q Network/Untitled.ipynb
@@ -0,0 +1,154 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "\n",
+ "import gym\n",
+ "import numpy as np\n",
+ "from matplotlib import pyplot as plt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2017-03-08 21:13:15,268] Making new env: Breakout-v0\n"
+ ]
+ },
+ {
+ "ename": "DependencyNotInstalled",
+ "evalue": "No module named 'atari_py'. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/atari_env.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0matari_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mImportError\u001b[0m: No module named 'atari_py'",
+ "\nDuring handling of the above exception, another exception occurred:\n",
+ "\u001b[0;31mDependencyNotInstalled\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Breakout-v0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Action space size: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action_meanings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mregistry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self, id)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Making new env: %s'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimestep_limit\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'vnc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrappers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime_limit\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTimeLimit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Attempting to make deprecated env {}. (HINT: is there a newer registered version of this env?)'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_entry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mentry_point\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpkg_resources\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEntryPoint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'x={}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mentry_point\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/setuptools-27.2.0-py3.5.egg/pkg_resources/__init__.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, require, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2256\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrequire\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2257\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2258\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresolve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2259\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2260\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mresolve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/setuptools-27.2.0-py3.5.egg/pkg_resources/__init__.py\u001b[0m in \u001b[0;36mresolve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 2262\u001b[0m \u001b[0mResolve\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mentry\u001b[0m \u001b[0mpoint\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mits\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mattrs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2263\u001b[0m \"\"\"\n\u001b[0;32m-> 2264\u001b[0;31m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__import__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfromlist\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'__name__'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2265\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2266\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari_env\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAtariEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/atari_env.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0matari_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDependencyNotInstalled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"{}. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mDependencyNotInstalled\u001b[0m: No module named 'atari_py'. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)"
+ ]
+ }
+ ],
+ "source": [
+ "env = gym.envs.make(\"Breakout-v0\")\n",
+ "\n",
+ "print(\"Action space size: {}\".format(env.action_space.n))\n",
+ "print(env.get_action_meanings())\n",
+ "\n",
+ "observation = env.reset()\n",
+ "print(\"Observation space shape: {}\".format(observation.shape))\n",
+ "\n",
+ "plt.figure()\n",
+ "plt.imshow(env.render(mode='rgb_array'))\n",
+ "\n",
+ "[env.step(2) for x in range(1)]\n",
+ "plt.figure()\n",
+ "plt.imshow(env.render(mode='rgb_array'))\n",
+ "\n",
+ "env.render(close=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2017-03-08 21:12:44,474] Making new env: CartPole-v0\n"
+ ]
+ },
+ {
+ "ename": "NameError",
+ "evalue": "name 'base' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mobservation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobservation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. (Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'human'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 288\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_close\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. (Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/classic_control/cartpole.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscreen_width\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscreen_height\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprefix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Error occured while running `from pyglet.gl import *`\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msuffix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \\\"-screen 0 1400x900x24\\\" python '\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/pyglet/gl/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcarbon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarbonConfig\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mConfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0;32mdel\u001b[0m \u001b[0mbase\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;31m# XXX remove\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'base' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "import gym\n",
+ "env = gym.make('CartPole-v0')\n",
+ "for i_episode in range(20):\n",
+ " observation = env.reset()\n",
+ " for t in range(100):\n",
+ " env.render()\n",
+ " print(observation)\n",
+ " action = env.action_space.sample()\n",
+ " observation, reward, done, info = env.step(action)\n",
+ " if done:\n",
+ " print(\"Episode finished after {} timesteps\".format(t+1))\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "anaconda-cloud": {},
+ "kernelspec": {
+ "display_name": "Python [conda root]",
+ "language": "python",
+ "name": "conda-root-py"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.5.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/tutorials/11 - Deep Q Network/dqn13.py b/tutorials/11 - Deep Q Network/dqn13.py
new file mode 100644
index 0000000..442b609
--- /dev/null
+++ b/tutorials/11 - Deep Q Network/dqn13.py
@@ -0,0 +1,124 @@
+%matplotlib inline
+
+import torch
+import torch.nn as nn
+import gym
+import random
+import numpy as np
+import torchvision.transforms as transforms
+import matplotlib.pyplot as plt
+from torch.autograd import Variable
+from collections import deque, namedtuple
+
+env = gym.envs.make("CartPole-v0")
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.fc1 = nn.Linear(4, 128)
+ self.tanh = nn.Tanh()
+ self.fc2 = nn.Linear(128, 2)
+ self.init_weights()
+
+ def init_weights(self):
+ self.fc1.weight.data.uniform_(-0.1, 0.1)
+ self.fc2.weight.data.uniform_(-0.1, 0.1)
+
+ def forward(self, x):
+ out = self.fc1(x)
+ out = self.tanh(out)
+ out = self.fc2(out)
+ return out
+
+def make_epsilon_greedy_policy(network, epsilon, nA):
+ def policy(state):
+ sample = random.random()
+ if sample < (1-epsilon) + (epsilon/nA):
+ q_values = network(state.view(1, -1))
+ action = q_values.data.max(1)[1][0, 0]
+ else:
+ action = random.randrange(nA)
+ return action
+ return policy
+
+class ReplayMemory(object):
+
+ def __init__(self, capacity):
+ self.memory = deque()
+ self.capacity = capacity
+
+ def push(self, transition):
+ if len(self.memory) > self.capacity:
+ self.memory.popleft()
+ self.memory.append(transition)
+
+ def sample(self, batch_size):
+ return random.sample(self.memory, batch_size)
+
+ def __len__(self):
+ return len(self.memory)
+
+def to_tensor(ndarray, volatile=False):
+ return Variable(torch.from_numpy(ndarray), volatile=volatile).float()
+
+def deep_q_learning(num_episodes=10, batch_size=100,
+ discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):
+
+ # Q-Network and memory
+ net = Net()
+ memory = ReplayMemory(10000)
+
+ # Loss and Optimizer
+ criterion = nn.MSELoss()
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
+
+ for i_episode in range(num_episodes):
+
+ # Set policy (TODO: decaying epsilon)
+ #if (i_episode+1) % 100 == 0:
+ # epsilon *= 0.9
+
+ policy = make_epsilon_greedy_policy(
+ net, epsilon, env.action_space.n)
+
+ # Start an episode
+ state = env.reset()
+
+ for t in range(10000):
+
+ # Sample action from epsilon greed policy
+ action = policy(to_tensor(state))
+ next_state, reward, done, _ = env.step(action)
+
+
+ # Restore transition in memory
+ memory.push([state, action, reward, next_state])
+
+
+ if len(memory) >= batch_size:
+ # Sample mini-batch transitions from memory
+ batch = memory.sample(batch_size)
+ state_batch = np.vstack([trans[0] for trans in batch])
+ action_batch =np.vstack([trans[1] for trans in batch])
+ reward_batch = np.vstack([trans[2] for trans in batch])
+ next_state_batch = np.vstack([trans[3] for trans in batch])
+
+ # Forward + Backward + Opimize
+ net.zero_grad()
+ q_values = net(to_tensor(state_batch))
+ next_q_values = net(to_tensor(next_state_batch, volatile=True))
+ next_q_values.volatile = False
+
+ td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]
+ loss = criterion(q_values.gather(1,
+ to_tensor(action_batch).long().view(-1, 1)), td_target)
+ loss.backward()
+ optimizer.step()
+
+ if done:
+ break
+
+ state = next_state
+
+ if len(memory) >= batch_size and (i_episode+1) % 10 == 0:
+ print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))
\ No newline at end of file