This commit is contained in:
yunjey
2017-03-11 15:16:41 +09:00
parent cd4f2d52f1
commit 440a18d146
3 changed files with 637 additions and 0 deletions

View File

@ -0,0 +1,359 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch DQN Implemenation\n",
"\n",
"<br/>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import gym\n",
"import random\n",
"import numpy as np\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"from torch.autograd import Variable\n",
"from collections import deque, namedtuple"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-03-09 21:31:48,174] Making new env: CartPole-v0\n"
]
}
],
"source": [
"env = gym.envs.make(\"CartPole-v0\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.fc1 = nn.Linear(4, 128)\n",
" self.tanh = nn.Tanh()\n",
" self.fc2 = nn.Linear(128, 2)\n",
" self.init_weights()\n",
" \n",
" def init_weights(self):\n",
" self.fc1.weight.data.uniform_(-0.1, 0.1)\n",
" self.fc2.weight.data.uniform_(-0.1, 0.1)\n",
" \n",
" def forward(self, x):\n",
" out = self.fc1(x)\n",
" out = self.tanh(out)\n",
" out = self.fc2(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def make_epsilon_greedy_policy(network, epsilon, nA):\n",
" def policy(state):\n",
" sample = random.random()\n",
" if sample < (1-epsilon) + (epsilon/nA):\n",
" q_values = network(state.view(1, -1))\n",
" action = q_values.data.max(1)[1][0, 0]\n",
" else:\n",
" action = random.randrange(nA)\n",
" return action\n",
" return policy"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ReplayMemory(object):\n",
" \n",
" def __init__(self, capacity):\n",
" self.memory = deque()\n",
" self.capacity = capacity\n",
" \n",
" def push(self, transition):\n",
" if len(self.memory) > self.capacity:\n",
" self.memory.popleft()\n",
" self.memory.append(transition)\n",
" \n",
" def sample(self, batch_size):\n",
" return random.sample(self.memory, batch_size)\n",
" \n",
" def __len__(self):\n",
" return len(self.memory)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def to_tensor(ndarray, volatile=False):\n",
" return Variable(torch.from_numpy(ndarray), volatile=volatile).float()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def deep_q_learning(num_episodes=10, batch_size=100, \n",
" discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):\n",
"\n",
" # Q-Network and memory \n",
" net = Net()\n",
" memory = ReplayMemory(10000)\n",
" \n",
" # Loss and Optimizer\n",
" criterion = nn.MSELoss()\n",
" optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
" \n",
" for i_episode in range(num_episodes):\n",
" \n",
" # Set policy (TODO: decaying epsilon)\n",
" #if (i_episode+1) % 100 == 0:\n",
" # epsilon *= 0.9\n",
" \n",
" policy = make_epsilon_greedy_policy(\n",
" net, epsilon, env.action_space.n)\n",
" \n",
" # Start an episode\n",
" state = env.reset()\n",
" \n",
" for t in range(10000):\n",
" \n",
" # Sample action from epsilon greed policy\n",
" action = policy(to_tensor(state)) \n",
" next_state, reward, done, _ = env.step(action)\n",
" \n",
" \n",
" # Restore transition in memory\n",
" memory.push([state, action, reward, next_state])\n",
" \n",
" \n",
" if len(memory) >= batch_size:\n",
" # Sample mini-batch transitions from memory\n",
" batch = memory.sample(batch_size)\n",
" state_batch = np.vstack([trans[0] for trans in batch])\n",
" action_batch =np.vstack([trans[1] for trans in batch]) \n",
" reward_batch = np.vstack([trans[2] for trans in batch])\n",
" next_state_batch = np.vstack([trans[3] for trans in batch])\n",
" \n",
" # Forward + Backward + Opimize\n",
" net.zero_grad()\n",
" q_values = net(to_tensor(state_batch))\n",
" next_q_values = net(to_tensor(next_state_batch, volatile=True))\n",
" next_q_values.volatile = False\n",
" \n",
" td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]\n",
" loss = criterion(q_values.gather(1, \n",
" to_tensor(action_batch).long().view(-1, 1)), td_target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" if done:\n",
" break\n",
" \n",
" state = next_state\n",
" \n",
" if len(memory) >= batch_size and (i_episode+1) % 10 == 0:\n",
" print ('episode: %d, time: %d, loss: %.4f' %(i_episode, t, loss.data[0]))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"episode: 9, time: 9, loss: 0.9945\n",
"episode: 19, time: 9, loss: 1.8221\n",
"episode: 29, time: 9, loss: 4.3124\n",
"episode: 39, time: 8, loss: 6.9764\n",
"episode: 49, time: 9, loss: 6.8300\n",
"episode: 59, time: 8, loss: 5.5186\n",
"episode: 69, time: 9, loss: 4.1160\n",
"episode: 79, time: 9, loss: 2.4802\n",
"episode: 89, time: 13, loss: 0.7890\n",
"episode: 99, time: 10, loss: 0.2805\n",
"episode: 109, time: 12, loss: 0.1323\n",
"episode: 119, time: 13, loss: 0.0519\n",
"episode: 129, time: 18, loss: 0.0176\n",
"episode: 139, time: 22, loss: 0.0067\n",
"episode: 149, time: 17, loss: 0.0114\n",
"episode: 159, time: 26, loss: 0.0017\n",
"episode: 169, time: 23, loss: 0.0018\n",
"episode: 179, time: 21, loss: 0.0023\n",
"episode: 189, time: 11, loss: 0.0024\n",
"episode: 199, time: 7, loss: 0.0040\n",
"episode: 209, time: 8, loss: 0.0030\n",
"episode: 219, time: 7, loss: 0.0070\n",
"episode: 229, time: 9, loss: 0.0031\n",
"episode: 239, time: 9, loss: 0.0029\n",
"episode: 249, time: 8, loss: 0.0046\n",
"episode: 259, time: 8, loss: 0.0009\n",
"episode: 269, time: 10, loss: 0.0020\n",
"episode: 279, time: 9, loss: 0.0025\n",
"episode: 289, time: 8, loss: 0.0015\n",
"episode: 299, time: 10, loss: 0.0009\n",
"episode: 309, time: 8, loss: 0.0012\n",
"episode: 319, time: 8, loss: 0.0034\n",
"episode: 329, time: 8, loss: 0.0008\n",
"episode: 339, time: 9, loss: 0.0021\n",
"episode: 349, time: 8, loss: 0.0018\n",
"episode: 359, time: 9, loss: 0.0017\n",
"episode: 369, time: 9, loss: 0.0006\n",
"episode: 379, time: 9, loss: 0.0023\n",
"episode: 389, time: 10, loss: 0.0017\n",
"episode: 399, time: 8, loss: 0.0018\n",
"episode: 409, time: 8, loss: 0.0023\n",
"episode: 419, time: 9, loss: 0.0020\n",
"episode: 429, time: 9, loss: 0.0006\n",
"episode: 439, time: 10, loss: 0.0006\n",
"episode: 449, time: 10, loss: 0.0025\n",
"episode: 459, time: 9, loss: 0.0013\n",
"episode: 469, time: 8, loss: 0.0011\n",
"episode: 479, time: 8, loss: 0.0005\n",
"episode: 489, time: 8, loss: 0.0004\n",
"episode: 499, time: 7, loss: 0.0017\n",
"episode: 509, time: 7, loss: 0.0004\n",
"episode: 519, time: 10, loss: 0.0008\n",
"episode: 529, time: 11, loss: 0.0006\n",
"episode: 539, time: 9, loss: 0.0010\n",
"episode: 549, time: 8, loss: 0.0006\n",
"episode: 559, time: 8, loss: 0.0012\n",
"episode: 569, time: 9, loss: 0.0011\n",
"episode: 579, time: 8, loss: 0.0010\n",
"episode: 589, time: 8, loss: 0.0008\n",
"episode: 599, time: 10, loss: 0.0010\n",
"episode: 609, time: 8, loss: 0.0005\n",
"episode: 619, time: 9, loss: 0.0004\n",
"episode: 629, time: 8, loss: 0.0007\n",
"episode: 639, time: 10, loss: 0.0014\n",
"episode: 649, time: 10, loss: 0.0004\n",
"episode: 659, time: 9, loss: 0.0008\n",
"episode: 669, time: 8, loss: 0.0005\n",
"episode: 679, time: 8, loss: 0.0003\n",
"episode: 689, time: 9, loss: 0.0009\n",
"episode: 699, time: 8, loss: 0.0004\n",
"episode: 709, time: 8, loss: 0.0013\n",
"episode: 719, time: 8, loss: 0.0006\n",
"episode: 729, time: 7, loss: 0.0021\n",
"episode: 739, time: 9, loss: 0.0023\n",
"episode: 749, time: 9, loss: 0.0039\n",
"episode: 759, time: 8, loss: 0.0030\n",
"episode: 769, time: 9, loss: 0.0016\n",
"episode: 779, time: 7, loss: 0.0041\n",
"episode: 789, time: 8, loss: 0.0050\n",
"episode: 799, time: 8, loss: 0.0041\n",
"episode: 809, time: 11, loss: 0.0053\n",
"episode: 819, time: 7, loss: 0.0018\n",
"episode: 829, time: 9, loss: 0.0019\n",
"episode: 839, time: 11, loss: 0.0017\n",
"episode: 849, time: 8, loss: 0.0029\n",
"episode: 859, time: 9, loss: 0.0012\n",
"episode: 869, time: 9, loss: 0.0036\n",
"episode: 879, time: 7, loss: 0.0017\n",
"episode: 889, time: 9, loss: 0.0016\n",
"episode: 899, time: 10, loss: 0.0023\n",
"episode: 909, time: 8, loss: 0.0032\n",
"episode: 919, time: 8, loss: 0.0015\n",
"episode: 929, time: 9, loss: 0.0021\n",
"episode: 939, time: 9, loss: 0.0015\n",
"episode: 949, time: 9, loss: 0.0016\n",
"episode: 959, time: 9, loss: 0.0013\n",
"episode: 969, time: 12, loss: 0.0029\n",
"episode: 979, time: 7, loss: 0.0016\n",
"episode: 989, time: 7, loss: 0.0012\n",
"episode: 999, time: 9, loss: 0.0013\n"
]
}
],
"source": [
"deep_q_learning(1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}