change file name

This commit is contained in:
yunjey
2017-03-26 17:44:09 +09:00
parent ea06532eaa
commit a438f8e6fc
2 changed files with 0 additions and 483 deletions

View File

@@ -1,359 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch DQN Implemenation\n",
"\n",
"<br/>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import gym\n",
"import random\n",
"import numpy as np\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"from torch.autograd import Variable\n",
"from collections import deque, namedtuple"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-03-09 21:31:48,174] Making new env: CartPole-v0\n"
]
}
],
"source": [
"env = gym.envs.make(\"CartPole-v0\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.fc1 = nn.Linear(4, 128)\n",
" self.tanh = nn.Tanh()\n",
" self.fc2 = nn.Linear(128, 2)\n",
" self.init_weights()\n",
" \n",
" def init_weights(self):\n",
" self.fc1.weight.data.uniform_(-0.1, 0.1)\n",
" self.fc2.weight.data.uniform_(-0.1, 0.1)\n",
" \n",
" def forward(self, x):\n",
" out = self.fc1(x)\n",
" out = self.tanh(out)\n",
" out = self.fc2(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def make_epsilon_greedy_policy(network, epsilon, nA):\n",
" def policy(state):\n",
" sample = random.random()\n",
" if sample < (1-epsilon) + (epsilon/nA):\n",
" q_values = network(state.view(1, -1))\n",
" action = q_values.data.max(1)[1][0, 0]\n",
" else:\n",
" action = random.randrange(nA)\n",
" return action\n",
" return policy"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ReplayMemory(object):\n",
" \n",
" def __init__(self, capacity):\n",
" self.memory = deque()\n",
" self.capacity = capacity\n",
" \n",
" def push(self, transition):\n",
" if len(self.memory) > self.capacity:\n",
" self.memory.popleft()\n",
" self.memory.append(transition)\n",
" \n",
" def sample(self, batch_size):\n",
" return random.sample(self.memory, batch_size)\n",
" \n",
" def __len__(self):\n",
" return len(self.memory)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def to_tensor(ndarray, volatile=False):\n",
" return Variable(torch.from_numpy(ndarray), volatile=volatile).float()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def deep_q_learning(num_episodes=10, batch_size=100, \n",
" discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):\n",
"\n",
" # Q-Network and memory \n",
" net = Net()\n",
" memory = ReplayMemory(10000)\n",
" \n",
" # Loss and Optimizer\n",
" criterion = nn.MSELoss()\n",
" optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
" \n",
" for i_episode in range(num_episodes):\n",
" \n",
" # Set policy (TODO: decaying epsilon)\n",
" #if (i_episode+1) % 100 == 0:\n",
" # epsilon *= 0.9\n",
" \n",
" policy = make_epsilon_greedy_policy(\n",
" net, epsilon, env.action_space.n)\n",
" \n",
" # Start an episode\n",
" state = env.reset()\n",
" \n",
" for t in range(10000):\n",
" \n",
" # Sample action from epsilon greed policy\n",
" action = policy(to_tensor(state)) \n",
" next_state, reward, done, _ = env.step(action)\n",
" \n",
" \n",
" # Restore transition in memory\n",
" memory.push([state, action, reward, next_state])\n",
" \n",
" \n",
" if len(memory) >= batch_size:\n",
" # Sample mini-batch transitions from memory\n",
" batch = memory.sample(batch_size)\n",
" state_batch = np.vstack([trans[0] for trans in batch])\n",
" action_batch =np.vstack([trans[1] for trans in batch]) \n",
" reward_batch = np.vstack([trans[2] for trans in batch])\n",
" next_state_batch = np.vstack([trans[3] for trans in batch])\n",
" \n",
" # Forward + Backward + Opimize\n",
" net.zero_grad()\n",
" q_values = net(to_tensor(state_batch))\n",
" next_q_values = net(to_tensor(next_state_batch, volatile=True))\n",
" next_q_values.volatile = False\n",
" \n",
" td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]\n",
" loss = criterion(q_values.gather(1, \n",
" to_tensor(action_batch).long().view(-1, 1)), td_target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" if done:\n",
" break\n",
" \n",
" state = next_state\n",
" \n",
" if len(memory) >= batch_size and (i_episode+1) % 10 == 0:\n",
" print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"episode: 9, time: 9, loss: 0.9945\n",
"episode: 19, time: 9, loss: 1.8221\n",
"episode: 29, time: 9, loss: 4.3124\n",
"episode: 39, time: 8, loss: 6.9764\n",
"episode: 49, time: 9, loss: 6.8300\n",
"episode: 59, time: 8, loss: 5.5186\n",
"episode: 69, time: 9, loss: 4.1160\n",
"episode: 79, time: 9, loss: 2.4802\n",
"episode: 89, time: 13, loss: 0.7890\n",
"episode: 99, time: 10, loss: 0.2805\n",
"episode: 109, time: 12, loss: 0.1323\n",
"episode: 119, time: 13, loss: 0.0519\n",
"episode: 129, time: 18, loss: 0.0176\n",
"episode: 139, time: 22, loss: 0.0067\n",
"episode: 149, time: 17, loss: 0.0114\n",
"episode: 159, time: 26, loss: 0.0017\n",
"episode: 169, time: 23, loss: 0.0018\n",
"episode: 179, time: 21, loss: 0.0023\n",
"episode: 189, time: 11, loss: 0.0024\n",
"episode: 199, time: 7, loss: 0.0040\n",
"episode: 209, time: 8, loss: 0.0030\n",
"episode: 219, time: 7, loss: 0.0070\n",
"episode: 229, time: 9, loss: 0.0031\n",
"episode: 239, time: 9, loss: 0.0029\n",
"episode: 249, time: 8, loss: 0.0046\n",
"episode: 259, time: 8, loss: 0.0009\n",
"episode: 269, time: 10, loss: 0.0020\n",
"episode: 279, time: 9, loss: 0.0025\n",
"episode: 289, time: 8, loss: 0.0015\n",
"episode: 299, time: 10, loss: 0.0009\n",
"episode: 309, time: 8, loss: 0.0012\n",
"episode: 319, time: 8, loss: 0.0034\n",
"episode: 329, time: 8, loss: 0.0008\n",
"episode: 339, time: 9, loss: 0.0021\n",
"episode: 349, time: 8, loss: 0.0018\n",
"episode: 359, time: 9, loss: 0.0017\n",
"episode: 369, time: 9, loss: 0.0006\n",
"episode: 379, time: 9, loss: 0.0023\n",
"episode: 389, time: 10, loss: 0.0017\n",
"episode: 399, time: 8, loss: 0.0018\n",
"episode: 409, time: 8, loss: 0.0023\n",
"episode: 419, time: 9, loss: 0.0020\n",
"episode: 429, time: 9, loss: 0.0006\n",
"episode: 439, time: 10, loss: 0.0006\n",
"episode: 449, time: 10, loss: 0.0025\n",
"episode: 459, time: 9, loss: 0.0013\n",
"episode: 469, time: 8, loss: 0.0011\n",
"episode: 479, time: 8, loss: 0.0005\n",
"episode: 489, time: 8, loss: 0.0004\n",
"episode: 499, time: 7, loss: 0.0017\n",
"episode: 509, time: 7, loss: 0.0004\n",
"episode: 519, time: 10, loss: 0.0008\n",
"episode: 529, time: 11, loss: 0.0006\n",
"episode: 539, time: 9, loss: 0.0010\n",
"episode: 549, time: 8, loss: 0.0006\n",
"episode: 559, time: 8, loss: 0.0012\n",
"episode: 569, time: 9, loss: 0.0011\n",
"episode: 579, time: 8, loss: 0.0010\n",
"episode: 589, time: 8, loss: 0.0008\n",
"episode: 599, time: 10, loss: 0.0010\n",
"episode: 609, time: 8, loss: 0.0005\n",
"episode: 619, time: 9, loss: 0.0004\n",
"episode: 629, time: 8, loss: 0.0007\n",
"episode: 639, time: 10, loss: 0.0014\n",
"episode: 649, time: 10, loss: 0.0004\n",
"episode: 659, time: 9, loss: 0.0008\n",
"episode: 669, time: 8, loss: 0.0005\n",
"episode: 679, time: 8, loss: 0.0003\n",
"episode: 689, time: 9, loss: 0.0009\n",
"episode: 699, time: 8, loss: 0.0004\n",
"episode: 709, time: 8, loss: 0.0013\n",
"episode: 719, time: 8, loss: 0.0006\n",
"episode: 729, time: 7, loss: 0.0021\n",
"episode: 739, time: 9, loss: 0.0023\n",
"episode: 749, time: 9, loss: 0.0039\n",
"episode: 759, time: 8, loss: 0.0030\n",
"episode: 769, time: 9, loss: 0.0016\n",
"episode: 779, time: 7, loss: 0.0041\n",
"episode: 789, time: 8, loss: 0.0050\n",
"episode: 799, time: 8, loss: 0.0041\n",
"episode: 809, time: 11, loss: 0.0053\n",
"episode: 819, time: 7, loss: 0.0018\n",
"episode: 829, time: 9, loss: 0.0019\n",
"episode: 839, time: 11, loss: 0.0017\n",
"episode: 849, time: 8, loss: 0.0029\n",
"episode: 859, time: 9, loss: 0.0012\n",
"episode: 869, time: 9, loss: 0.0036\n",
"episode: 879, time: 7, loss: 0.0017\n",
"episode: 889, time: 9, loss: 0.0016\n",
"episode: 899, time: 10, loss: 0.0023\n",
"episode: 909, time: 8, loss: 0.0032\n",
"episode: 919, time: 8, loss: 0.0015\n",
"episode: 929, time: 9, loss: 0.0021\n",
"episode: 939, time: 9, loss: 0.0015\n",
"episode: 949, time: 9, loss: 0.0016\n",
"episode: 959, time: 9, loss: 0.0013\n",
"episode: 969, time: 12, loss: 0.0029\n",
"episode: 979, time: 7, loss: 0.0016\n",
"episode: 989, time: 7, loss: 0.0012\n",
"episode: 999, time: 9, loss: 0.0013\n"
]
}
],
"source": [
"deep_q_learning(1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View File

@@ -1,124 +0,0 @@
%matplotlib inline
import torch
import torch.nn as nn
import gym
import random
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.autograd import Variable
from collections import deque, namedtuple
env = gym.envs.make("CartPole-v0")
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(4, 128)
self.tanh = nn.Tanh()
self.fc2 = nn.Linear(128, 2)
self.init_weights()
def init_weights(self):
self.fc1.weight.data.uniform_(-0.1, 0.1)
self.fc2.weight.data.uniform_(-0.1, 0.1)
def forward(self, x):
out = self.fc1(x)
out = self.tanh(out)
out = self.fc2(out)
return out
def make_epsilon_greedy_policy(network, epsilon, nA):
def policy(state):
sample = random.random()
if sample < (1-epsilon) + (epsilon/nA):
q_values = network(state.view(1, -1))
action = q_values.data.max(1)[1][0, 0]
else:
action = random.randrange(nA)
return action
return policy
class ReplayMemory(object):
def __init__(self, capacity):
self.memory = deque()
self.capacity = capacity
def push(self, transition):
if len(self.memory) > self.capacity:
self.memory.popleft()
self.memory.append(transition)
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
def to_tensor(ndarray, volatile=False):
return Variable(torch.from_numpy(ndarray), volatile=volatile).float()
def deep_q_learning(num_episodes=10, batch_size=100,
discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):
# Q-Network and memory
net = Net()
memory = ReplayMemory(10000)
# Loss and Optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
for i_episode in range(num_episodes):
# Set policy (TODO: decaying epsilon)
#if (i_episode+1) % 100 == 0:
# epsilon *= 0.9
policy = make_epsilon_greedy_policy(
net, epsilon, env.action_space.n)
# Start an episode
state = env.reset()
for t in range(10000):
# Sample action from epsilon greed policy
action = policy(to_tensor(state))
next_state, reward, done, _ = env.step(action)
# Restore transition in memory
memory.push([state, action, reward, next_state])
if len(memory) >= batch_size:
# Sample mini-batch transitions from memory
batch = memory.sample(batch_size)
state_batch = np.vstack([trans[0] for trans in batch])
action_batch =np.vstack([trans[1] for trans in batch])
reward_batch = np.vstack([trans[2] for trans in batch])
next_state_batch = np.vstack([trans[3] for trans in batch])
# Forward + Backward + Opimize
net.zero_grad()
q_values = net(to_tensor(state_batch))
next_q_values = net(to_tensor(next_state_batch, volatile=True))
next_q_values.volatile = False
td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]
loss = criterion(q_values.gather(1,
to_tensor(action_batch).long().view(-1, 1)), td_target)
loss.backward()
optimizer.step()
if done:
break
state = next_state
if len(memory) >= batch_size and (i_episode+1) % 10 == 0:
print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))