mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-28 20:53:44 +08:00
add dqn
This commit is contained in:
359
tutorials/11 - Deep Q Network/ReplayMemory.ipynb
Normal file
359
tutorials/11 - Deep Q Network/ReplayMemory.ipynb
Normal file
@ -0,0 +1,359 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PyTorch DQN Implemenation\n",
|
||||
"\n",
|
||||
"<br/>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import gym\n",
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from torch.autograd import Variable\n",
|
||||
"from collections import deque, namedtuple"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2017-03-09 21:31:48,174] Making new env: CartPole-v0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym.envs.make(\"CartPole-v0\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Net(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Net, self).__init__()\n",
|
||||
" self.fc1 = nn.Linear(4, 128)\n",
|
||||
" self.tanh = nn.Tanh()\n",
|
||||
" self.fc2 = nn.Linear(128, 2)\n",
|
||||
" self.init_weights()\n",
|
||||
" \n",
|
||||
" def init_weights(self):\n",
|
||||
" self.fc1.weight.data.uniform_(-0.1, 0.1)\n",
|
||||
" self.fc2.weight.data.uniform_(-0.1, 0.1)\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" out = self.fc1(x)\n",
|
||||
" out = self.tanh(out)\n",
|
||||
" out = self.fc2(out)\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def make_epsilon_greedy_policy(network, epsilon, nA):\n",
|
||||
" def policy(state):\n",
|
||||
" sample = random.random()\n",
|
||||
" if sample < (1-epsilon) + (epsilon/nA):\n",
|
||||
" q_values = network(state.view(1, -1))\n",
|
||||
" action = q_values.data.max(1)[1][0, 0]\n",
|
||||
" else:\n",
|
||||
" action = random.randrange(nA)\n",
|
||||
" return action\n",
|
||||
" return policy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ReplayMemory(object):\n",
|
||||
" \n",
|
||||
" def __init__(self, capacity):\n",
|
||||
" self.memory = deque()\n",
|
||||
" self.capacity = capacity\n",
|
||||
" \n",
|
||||
" def push(self, transition):\n",
|
||||
" if len(self.memory) > self.capacity:\n",
|
||||
" self.memory.popleft()\n",
|
||||
" self.memory.append(transition)\n",
|
||||
" \n",
|
||||
" def sample(self, batch_size):\n",
|
||||
" return random.sample(self.memory, batch_size)\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def to_tensor(ndarray, volatile=False):\n",
|
||||
" return Variable(torch.from_numpy(ndarray), volatile=volatile).float()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def deep_q_learning(num_episodes=10, batch_size=100, \n",
|
||||
" discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):\n",
|
||||
"\n",
|
||||
" # Q-Network and memory \n",
|
||||
" net = Net()\n",
|
||||
" memory = ReplayMemory(10000)\n",
|
||||
" \n",
|
||||
" # Loss and Optimizer\n",
|
||||
" criterion = nn.MSELoss()\n",
|
||||
" optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
|
||||
" \n",
|
||||
" for i_episode in range(num_episodes):\n",
|
||||
" \n",
|
||||
" # Set policy (TODO: decaying epsilon)\n",
|
||||
" #if (i_episode+1) % 100 == 0:\n",
|
||||
" # epsilon *= 0.9\n",
|
||||
" \n",
|
||||
" policy = make_epsilon_greedy_policy(\n",
|
||||
" net, epsilon, env.action_space.n)\n",
|
||||
" \n",
|
||||
" # Start an episode\n",
|
||||
" state = env.reset()\n",
|
||||
" \n",
|
||||
" for t in range(10000):\n",
|
||||
" \n",
|
||||
" # Sample action from epsilon greed policy\n",
|
||||
" action = policy(to_tensor(state)) \n",
|
||||
" next_state, reward, done, _ = env.step(action)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # Restore transition in memory\n",
|
||||
" memory.push([state, action, reward, next_state])\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" if len(memory) >= batch_size:\n",
|
||||
" # Sample mini-batch transitions from memory\n",
|
||||
" batch = memory.sample(batch_size)\n",
|
||||
" state_batch = np.vstack([trans[0] for trans in batch])\n",
|
||||
" action_batch =np.vstack([trans[1] for trans in batch]) \n",
|
||||
" reward_batch = np.vstack([trans[2] for trans in batch])\n",
|
||||
" next_state_batch = np.vstack([trans[3] for trans in batch])\n",
|
||||
" \n",
|
||||
" # Forward + Backward + Opimize\n",
|
||||
" net.zero_grad()\n",
|
||||
" q_values = net(to_tensor(state_batch))\n",
|
||||
" next_q_values = net(to_tensor(next_state_batch, volatile=True))\n",
|
||||
" next_q_values.volatile = False\n",
|
||||
" \n",
|
||||
" td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]\n",
|
||||
" loss = criterion(q_values.gather(1, \n",
|
||||
" to_tensor(action_batch).long().view(-1, 1)), td_target)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" \n",
|
||||
" if done:\n",
|
||||
" break\n",
|
||||
" \n",
|
||||
" state = next_state\n",
|
||||
" \n",
|
||||
" if len(memory) >= batch_size and (i_episode+1) % 10 == 0:\n",
|
||||
" print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"episode: 9, time: 9, loss: 0.9945\n",
|
||||
"episode: 19, time: 9, loss: 1.8221\n",
|
||||
"episode: 29, time: 9, loss: 4.3124\n",
|
||||
"episode: 39, time: 8, loss: 6.9764\n",
|
||||
"episode: 49, time: 9, loss: 6.8300\n",
|
||||
"episode: 59, time: 8, loss: 5.5186\n",
|
||||
"episode: 69, time: 9, loss: 4.1160\n",
|
||||
"episode: 79, time: 9, loss: 2.4802\n",
|
||||
"episode: 89, time: 13, loss: 0.7890\n",
|
||||
"episode: 99, time: 10, loss: 0.2805\n",
|
||||
"episode: 109, time: 12, loss: 0.1323\n",
|
||||
"episode: 119, time: 13, loss: 0.0519\n",
|
||||
"episode: 129, time: 18, loss: 0.0176\n",
|
||||
"episode: 139, time: 22, loss: 0.0067\n",
|
||||
"episode: 149, time: 17, loss: 0.0114\n",
|
||||
"episode: 159, time: 26, loss: 0.0017\n",
|
||||
"episode: 169, time: 23, loss: 0.0018\n",
|
||||
"episode: 179, time: 21, loss: 0.0023\n",
|
||||
"episode: 189, time: 11, loss: 0.0024\n",
|
||||
"episode: 199, time: 7, loss: 0.0040\n",
|
||||
"episode: 209, time: 8, loss: 0.0030\n",
|
||||
"episode: 219, time: 7, loss: 0.0070\n",
|
||||
"episode: 229, time: 9, loss: 0.0031\n",
|
||||
"episode: 239, time: 9, loss: 0.0029\n",
|
||||
"episode: 249, time: 8, loss: 0.0046\n",
|
||||
"episode: 259, time: 8, loss: 0.0009\n",
|
||||
"episode: 269, time: 10, loss: 0.0020\n",
|
||||
"episode: 279, time: 9, loss: 0.0025\n",
|
||||
"episode: 289, time: 8, loss: 0.0015\n",
|
||||
"episode: 299, time: 10, loss: 0.0009\n",
|
||||
"episode: 309, time: 8, loss: 0.0012\n",
|
||||
"episode: 319, time: 8, loss: 0.0034\n",
|
||||
"episode: 329, time: 8, loss: 0.0008\n",
|
||||
"episode: 339, time: 9, loss: 0.0021\n",
|
||||
"episode: 349, time: 8, loss: 0.0018\n",
|
||||
"episode: 359, time: 9, loss: 0.0017\n",
|
||||
"episode: 369, time: 9, loss: 0.0006\n",
|
||||
"episode: 379, time: 9, loss: 0.0023\n",
|
||||
"episode: 389, time: 10, loss: 0.0017\n",
|
||||
"episode: 399, time: 8, loss: 0.0018\n",
|
||||
"episode: 409, time: 8, loss: 0.0023\n",
|
||||
"episode: 419, time: 9, loss: 0.0020\n",
|
||||
"episode: 429, time: 9, loss: 0.0006\n",
|
||||
"episode: 439, time: 10, loss: 0.0006\n",
|
||||
"episode: 449, time: 10, loss: 0.0025\n",
|
||||
"episode: 459, time: 9, loss: 0.0013\n",
|
||||
"episode: 469, time: 8, loss: 0.0011\n",
|
||||
"episode: 479, time: 8, loss: 0.0005\n",
|
||||
"episode: 489, time: 8, loss: 0.0004\n",
|
||||
"episode: 499, time: 7, loss: 0.0017\n",
|
||||
"episode: 509, time: 7, loss: 0.0004\n",
|
||||
"episode: 519, time: 10, loss: 0.0008\n",
|
||||
"episode: 529, time: 11, loss: 0.0006\n",
|
||||
"episode: 539, time: 9, loss: 0.0010\n",
|
||||
"episode: 549, time: 8, loss: 0.0006\n",
|
||||
"episode: 559, time: 8, loss: 0.0012\n",
|
||||
"episode: 569, time: 9, loss: 0.0011\n",
|
||||
"episode: 579, time: 8, loss: 0.0010\n",
|
||||
"episode: 589, time: 8, loss: 0.0008\n",
|
||||
"episode: 599, time: 10, loss: 0.0010\n",
|
||||
"episode: 609, time: 8, loss: 0.0005\n",
|
||||
"episode: 619, time: 9, loss: 0.0004\n",
|
||||
"episode: 629, time: 8, loss: 0.0007\n",
|
||||
"episode: 639, time: 10, loss: 0.0014\n",
|
||||
"episode: 649, time: 10, loss: 0.0004\n",
|
||||
"episode: 659, time: 9, loss: 0.0008\n",
|
||||
"episode: 669, time: 8, loss: 0.0005\n",
|
||||
"episode: 679, time: 8, loss: 0.0003\n",
|
||||
"episode: 689, time: 9, loss: 0.0009\n",
|
||||
"episode: 699, time: 8, loss: 0.0004\n",
|
||||
"episode: 709, time: 8, loss: 0.0013\n",
|
||||
"episode: 719, time: 8, loss: 0.0006\n",
|
||||
"episode: 729, time: 7, loss: 0.0021\n",
|
||||
"episode: 739, time: 9, loss: 0.0023\n",
|
||||
"episode: 749, time: 9, loss: 0.0039\n",
|
||||
"episode: 759, time: 8, loss: 0.0030\n",
|
||||
"episode: 769, time: 9, loss: 0.0016\n",
|
||||
"episode: 779, time: 7, loss: 0.0041\n",
|
||||
"episode: 789, time: 8, loss: 0.0050\n",
|
||||
"episode: 799, time: 8, loss: 0.0041\n",
|
||||
"episode: 809, time: 11, loss: 0.0053\n",
|
||||
"episode: 819, time: 7, loss: 0.0018\n",
|
||||
"episode: 829, time: 9, loss: 0.0019\n",
|
||||
"episode: 839, time: 11, loss: 0.0017\n",
|
||||
"episode: 849, time: 8, loss: 0.0029\n",
|
||||
"episode: 859, time: 9, loss: 0.0012\n",
|
||||
"episode: 869, time: 9, loss: 0.0036\n",
|
||||
"episode: 879, time: 7, loss: 0.0017\n",
|
||||
"episode: 889, time: 9, loss: 0.0016\n",
|
||||
"episode: 899, time: 10, loss: 0.0023\n",
|
||||
"episode: 909, time: 8, loss: 0.0032\n",
|
||||
"episode: 919, time: 8, loss: 0.0015\n",
|
||||
"episode: 929, time: 9, loss: 0.0021\n",
|
||||
"episode: 939, time: 9, loss: 0.0015\n",
|
||||
"episode: 949, time: 9, loss: 0.0016\n",
|
||||
"episode: 959, time: 9, loss: 0.0013\n",
|
||||
"episode: 969, time: 12, loss: 0.0029\n",
|
||||
"episode: 979, time: 7, loss: 0.0016\n",
|
||||
"episode: 989, time: 7, loss: 0.0012\n",
|
||||
"episode: 999, time: 9, loss: 0.0013\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"deep_q_learning(1000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"anaconda-cloud": {},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 2",
|
||||
"language": "python",
|
||||
"name": "python2"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
154
tutorials/11 - Deep Q Network/Untitled.ipynb
Normal file
154
tutorials/11 - Deep Q Network/Untitled.ipynb
Normal file
@ -0,0 +1,154 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"from matplotlib import pyplot as plt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2017-03-08 21:13:15,268] Making new env: Breakout-v0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "DependencyNotInstalled",
|
||||
"evalue": "No module named 'atari_py'. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/atari_env.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0matari_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mImportError\u001b[0m: No module named 'atari_py'",
|
||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||
"\u001b[0;31mDependencyNotInstalled\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-6-fd0311e5e366>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Breakout-v0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Action space size: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action_meanings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mregistry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self, id)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Making new env: %s'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimestep_limit\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'vnc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrappers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime_limit\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTimeLimit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Attempting to make deprecated env {}. (HINT: is there a newer registered version of this env?)'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_entry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mentry_point\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpkg_resources\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEntryPoint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'x={}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mentry_point\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/setuptools-27.2.0-py3.5.egg/pkg_resources/__init__.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, require, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2256\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrequire\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2257\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2258\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresolve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2259\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2260\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mresolve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/setuptools-27.2.0-py3.5.egg/pkg_resources/__init__.py\u001b[0m in \u001b[0;36mresolve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 2262\u001b[0m \u001b[0mResolve\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mentry\u001b[0m \u001b[0mpoint\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mits\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mattrs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2263\u001b[0m \"\"\"\n\u001b[0;32m-> 2264\u001b[0;31m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__import__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfromlist\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'__name__'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2265\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2266\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari_env\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAtariEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/atari/atari_env.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0matari_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDependencyNotInstalled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"{}. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mDependencyNotInstalled\u001b[0m: No module named 'atari_py'. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym.envs.make(\"Breakout-v0\")\n",
|
||||
"\n",
|
||||
"print(\"Action space size: {}\".format(env.action_space.n))\n",
|
||||
"print(env.get_action_meanings())\n",
|
||||
"\n",
|
||||
"observation = env.reset()\n",
|
||||
"print(\"Observation space shape: {}\".format(observation.shape))\n",
|
||||
"\n",
|
||||
"plt.figure()\n",
|
||||
"plt.imshow(env.render(mode='rgb_array'))\n",
|
||||
"\n",
|
||||
"[env.step(2) for x in range(1)]\n",
|
||||
"plt.figure()\n",
|
||||
"plt.imshow(env.render(mode='rgb_array'))\n",
|
||||
"\n",
|
||||
"env.render(close=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2017-03-08 21:12:44,474] Making new env: CartPole-v0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'base' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-5-3093026983cb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mobservation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobservation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. (Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'human'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 288\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_close\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. (Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/classic_control/cartpole.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscreen_width\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscreen_height\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprefix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Error occured while running `from pyglet.gl import *`\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msuffix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \\\"-screen 0 1400x900x24\\\" python <your_script.py>'\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/pyglet/gl/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcarbon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarbonConfig\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mConfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0;32mdel\u001b[0m \u001b[0mbase\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;31m# XXX remove\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'base' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"env = gym.make('CartPole-v0')\n",
|
||||
"for i_episode in range(20):\n",
|
||||
" observation = env.reset()\n",
|
||||
" for t in range(100):\n",
|
||||
" env.render()\n",
|
||||
" print(observation)\n",
|
||||
" action = env.action_space.sample()\n",
|
||||
" observation, reward, done, info = env.step(action)\n",
|
||||
" if done:\n",
|
||||
" print(\"Episode finished after {} timesteps\".format(t+1))\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"anaconda-cloud": {},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda root]",
|
||||
"language": "python",
|
||||
"name": "conda-root-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
124
tutorials/11 - Deep Q Network/dqn13.py
Normal file
124
tutorials/11 - Deep Q Network/dqn13.py
Normal file
@ -0,0 +1,124 @@
|
||||
%matplotlib inline
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import gym
|
||||
import random
|
||||
import numpy as np
|
||||
import torchvision.transforms as transforms
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.autograd import Variable
|
||||
from collections import deque, namedtuple
|
||||
|
||||
env = gym.envs.make("CartPole-v0")
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc1 = nn.Linear(4, 128)
|
||||
self.tanh = nn.Tanh()
|
||||
self.fc2 = nn.Linear(128, 2)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
self.fc1.weight.data.uniform_(-0.1, 0.1)
|
||||
self.fc2.weight.data.uniform_(-0.1, 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.tanh(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
def make_epsilon_greedy_policy(network, epsilon, nA):
|
||||
def policy(state):
|
||||
sample = random.random()
|
||||
if sample < (1-epsilon) + (epsilon/nA):
|
||||
q_values = network(state.view(1, -1))
|
||||
action = q_values.data.max(1)[1][0, 0]
|
||||
else:
|
||||
action = random.randrange(nA)
|
||||
return action
|
||||
return policy
|
||||
|
||||
class ReplayMemory(object):
|
||||
|
||||
def __init__(self, capacity):
|
||||
self.memory = deque()
|
||||
self.capacity = capacity
|
||||
|
||||
def push(self, transition):
|
||||
if len(self.memory) > self.capacity:
|
||||
self.memory.popleft()
|
||||
self.memory.append(transition)
|
||||
|
||||
def sample(self, batch_size):
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
def to_tensor(ndarray, volatile=False):
|
||||
return Variable(torch.from_numpy(ndarray), volatile=volatile).float()
|
||||
|
||||
def deep_q_learning(num_episodes=10, batch_size=100,
|
||||
discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):
|
||||
|
||||
# Q-Network and memory
|
||||
net = Net()
|
||||
memory = ReplayMemory(10000)
|
||||
|
||||
# Loss and Optimizer
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
|
||||
|
||||
for i_episode in range(num_episodes):
|
||||
|
||||
# Set policy (TODO: decaying epsilon)
|
||||
#if (i_episode+1) % 100 == 0:
|
||||
# epsilon *= 0.9
|
||||
|
||||
policy = make_epsilon_greedy_policy(
|
||||
net, epsilon, env.action_space.n)
|
||||
|
||||
# Start an episode
|
||||
state = env.reset()
|
||||
|
||||
for t in range(10000):
|
||||
|
||||
# Sample action from epsilon greed policy
|
||||
action = policy(to_tensor(state))
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
|
||||
|
||||
# Restore transition in memory
|
||||
memory.push([state, action, reward, next_state])
|
||||
|
||||
|
||||
if len(memory) >= batch_size:
|
||||
# Sample mini-batch transitions from memory
|
||||
batch = memory.sample(batch_size)
|
||||
state_batch = np.vstack([trans[0] for trans in batch])
|
||||
action_batch =np.vstack([trans[1] for trans in batch])
|
||||
reward_batch = np.vstack([trans[2] for trans in batch])
|
||||
next_state_batch = np.vstack([trans[3] for trans in batch])
|
||||
|
||||
# Forward + Backward + Opimize
|
||||
net.zero_grad()
|
||||
q_values = net(to_tensor(state_batch))
|
||||
next_q_values = net(to_tensor(next_state_batch, volatile=True))
|
||||
next_q_values.volatile = False
|
||||
|
||||
td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]
|
||||
loss = criterion(q_values.gather(1,
|
||||
to_tensor(action_batch).long().view(-1, 1)), td_target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
state = next_state
|
||||
|
||||
if len(memory) >= batch_size and (i_episode+1) % 10 == 0:
|
||||
print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))
|
Reference in New Issue
Block a user