__call__ -> forward

Varuna Jayasiri
2021-08-19 15:51:41 +05:30
parent bac2bf85e2
commit 0a4b5b6822
5 changed files with 13 additions and 12 deletions
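The rename matters because `torch.nn.Module` defines `__call__` itself: it runs registered hooks and other bookkeeping, then dispatches to `forward`. A module that overrides `__call__` directly bypasses that machinery, so forward hooks never fire. A minimal sketch illustrating the difference (the `Scale` module below is hypothetical, not part of this repository):

import torch
import torch.nn as nn

class Scale(nn.Module):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Invoked indirectly through nn.Module.__call__
        return self.factor * x

m = Scale(2.0)
# Forward hooks are dispatched by nn.Module.__call__; they would be
# skipped entirely if the module defined __call__ instead of forward.
m.register_forward_hook(lambda module, inputs, output: print('hook fired'))
y = m(torch.ones(3))  # prints 'hook fired'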

View File

@@ -83,7 +83,7 @@ class ShortcutProjection(Module):
         # Paper suggests adding batch normalization after each convolution operation
         self.bn = nn.BatchNorm2d(out_channels)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Convolution and batch normalization
         return self.bn(self.conv(x))
@@ -140,7 +140,7 @@ class ResidualBlock(Module):
         # Second activation function (ReLU) (after adding the shortcut)
         self.act2 = nn.ReLU()
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
@@ -221,7 +221,7 @@ class BottleneckResidualBlock(Module):
         # Second activation function (ReLU) (after adding the shortcut)
         self.act3 = nn.ReLU()
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
@@ -310,7 +310,7 @@ class ResNetBase(Module):
         # Stack the blocks
         self.blocks = nn.Sequential(*blocks)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
        """
        * `x` has shape `[batch_size, img_channels, height, width]`
        """

View File

@@ -106,9 +106,9 @@ class QFuncLoss(Module):
         self.gamma = gamma
         self.huber_loss = nn.SmoothL1Loss(reduction='none')
 
-    def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
-                 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
-                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
+                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
+                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        * `q` - $Q(s;\theta_i)$
        * `action` - $a$

View File

@@ -82,7 +82,7 @@ class Model(Module):
             nn.Linear(in_features=256, out_features=4),
         )
 
-    def __call__(self, obs: torch.Tensor):
+    def forward(self, obs: torch.Tensor):
         # Convolution
         h = self.conv(obs)
         # Reshape for linear layers

View File

@@ -136,8 +136,8 @@ class ClippedPPOLoss(Module):
     def __init__(self):
         super().__init__()
 
-    def __call__(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
-                 advantage: torch.Tensor, clip: float) -> torch.Tensor:
+    def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
+                advantage: torch.Tensor, clip: float) -> torch.Tensor:
         # ratio $r_t(\theta) = \frac{\pi_\theta (a_t|s_t)}{\pi_{\theta_{OLD}} (a_t|s_t)}$;
         # *this is different from rewards* $r_t$.
         ratio = torch.exp(log_pi - sampled_log_pi)
@@ -200,7 +200,8 @@ class ClippedValueFunctionLoss(Module):
     significantly from $V_{\theta_{OLD}}$.
     """
 
-    def __call__(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
+    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
         clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
         vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
         return 0.5 * vf_loss.mean()
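After the rename, these loss modules are invoked like any other `nn.Module`: calling the instance routes through `nn.Module.__call__` to `forward`. A hedged usage sketch for `ClippedValueFunctionLoss` (the shapes and values below are illustrative assumptions, and the class is assumed importable from the PPO module):

import torch

value_loss = ClippedValueFunctionLoss()
value = torch.randn(8)           # value estimates from the current policy
sampled_value = torch.randn(8)   # value estimates recorded during sampling
sampled_return = torch.randn(8)  # empirical returns
loss = value_loss(value, sampled_value, sampled_return, clip=0.1)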

View File

@@ -69,7 +69,7 @@ class Model(Module):
         #
         self.activation = nn.ReLU()
 
-    def __call__(self, obs: torch.Tensor):
+    def forward(self, obs: torch.Tensor):
         h = self.activation(self.conv1(obs))
         h = self.activation(self.conv2(h))
         h = self.activation(self.conv3(h))