| """
 | |
| # Neural Network Model for Deep Q Network (DQN)
 | |
| """
 | |
| 
 | |
| import torch
 | |
| from torch import nn
 | |
| 
 | |
| from labml_helpers.module import Module
 | |
| 
 | |
| 
 | |
| class Model(Module):
 | |
|     """
 | |
|     ## Dueling Network ⚔️ Model for $Q$ Values
 | |
| 
 | |
|     We are using a [dueling network](https://arxiv.org/abs/1511.06581)
 | |
|      to calculate Q-values.
 | |
|     Intuition behind dueling network architecture is that in most states
 | |
|      the action doesn't matter,
 | |
|     and in some states the action is significant. Dueling network allows
 | |
|      this to be represented very well.
 | |
| 
 | |
|     \begin{align}
 | |
|         Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
 | |
|         \\
 | |
|         \mathop{\mathbb{E}}_{a \sim \pi(s)}
 | |
|          \Big[
 | |
|           A^\pi(s, a)
 | |
|          \Big]
 | |
|         &= 0
 | |
|     \end{align}
 | |
| 
 | |
|     So we create two networks for $V$ and $A$ and get $Q$ from them.
 | |
|     $$
 | |
|         Q(s, a) = V(s) +
 | |
|         \Big(
 | |
|             A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
 | |
|         \Big)
 | |
|     $$
 | |
|     We share the initial layers of the $V$ and $A$ networks.
 | |
|     """
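
    # A small worked example of the combination above (illustrative numbers only,
    # added for clarity): with $V(s) = 2.0$ and advantages $A(s, \cdot) = [1.0, 3.0, -2.0, 2.0]$,
    # the mean advantage is $1.0$, so the centered advantages are $[0.0, 2.0, -3.0, 1.0]$
    # and $Q(s, \cdot) = [2.0, 4.0, -1.0, 3.0]$.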

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes an
            # $84\times84$ frame and produces a $20\times20$ frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # $20\times20$ frame and produces a $9\times9$ frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # $9\times9$ frame and produces a $7\times7$ frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
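
        # Added note on the frame sizes above: they follow the standard convolution
        # arithmetic $\lfloor (\text{size} - \text{kernel}) / \text{stride} \rfloor + 1$;
        # that is, $(84 - 8)/4 + 1 = 20$, $(20 - 4)/2 + 1 = 9$ and $(9 - 3)/1 + 1 = 7$.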

        # A fully connected layer takes the flattened
        # frame from the third convolution layer, and outputs
        # $512$ features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

        # This head gives the state value $V$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # This head gives the action value $A$
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )
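
        # Added note: `out_features=4` is the size of the action space, so this head
        # produces one advantage estimate per available action in the environment used here.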

    def __call__(self, obs: torch.Tensor):
        # Convolution
        h = self.conv(obs)
        # Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))

        # Linear layer
        h = self.activation(self.lin(h))

        # $A$
        action_value = self.action_value(h)
        # $V$
        state_value = self.state_value(h)

        # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        # `state_value` has shape `(batch, 1)`, so it broadcasts across the action dimension
        q = state_value + action_score_centered

        return q
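

# A minimal smoke-test sketch, added for illustration; it is not part of the original
# experiment code. It assumes a batch of observations shaped `(batch, 4, 84, 84)`
# (4 stacked $84\times84$ frames), as the convolution stack above expects.
if __name__ == "__main__":
    model = Model()
    # A dummy batch of 8 observations
    obs = torch.zeros(8, 4, 84, 84)
    q = model(obs)
    # One $Q$-value per action for each observation in the batch
    assert q.shape == (8, 4)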
