From 0a4b5b68222c74e41c795190c47b85b38c3f0a0d Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Thu, 19 Aug 2021 15:51:41 +0530 Subject: [PATCH] __call__ -> forward --- labml_nn/resnet/__init__.py | 8 ++++---- labml_nn/rl/dqn/__init__.py | 6 +++--- labml_nn/rl/dqn/model.py | 2 +- labml_nn/rl/ppo/__init__.py | 7 ++++--- labml_nn/rl/ppo/experiment.py | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/labml_nn/resnet/__init__.py b/labml_nn/resnet/__init__.py index 924e23d5..9752d481 100644 --- a/labml_nn/resnet/__init__.py +++ b/labml_nn/resnet/__init__.py @@ -83,7 +83,7 @@ class ShortcutProjection(Module): # Paper suggests adding batch normalization after each convolution operation self.bn = nn.BatchNorm2d(out_channels) - def __call__(self, x: torch.Tensor): + def forward(self, x: torch.Tensor): # Convolution and batch normalization return self.bn(self.conv(x)) @@ -140,7 +140,7 @@ class ResidualBlock(Module): # Second activation function (ReLU) (after adding the shortcut) self.act2 = nn.ReLU() - def __call__(self, x: torch.Tensor): + def forward(self, x: torch.Tensor): """ * `x` is the input of shape `[batch_size, in_channels, height, width]` """ @@ -221,7 +221,7 @@ class BottleneckResidualBlock(Module): # Second activation function (ReLU) (after adding the shortcut) self.act3 = nn.ReLU() - def __call__(self, x: torch.Tensor): + def forward(self, x: torch.Tensor): """ * `x` is the input of shape `[batch_size, in_channels, height, width]` """ @@ -310,7 +310,7 @@ class ResNetBase(Module): # Stack the blocks self.blocks = nn.Sequential(*blocks) - def __call__(self, x: torch.Tensor): + def forward(self, x: torch.Tensor): """ * `x` has shape `[batch_size, img_channels, height, width]` """ diff --git a/labml_nn/rl/dqn/__init__.py b/labml_nn/rl/dqn/__init__.py index 61b99817..91ae9b0d 100644 --- a/labml_nn/rl/dqn/__init__.py +++ b/labml_nn/rl/dqn/__init__.py @@ -106,9 +106,9 @@ class QFuncLoss(Module): self.gamma = gamma self.huber_loss = nn.SmoothL1Loss(reduction='none') - def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor, - target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor, - weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor, + target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor, + weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ * `q` - $Q(s;\theta_i)$ * `action` - $a$ diff --git a/labml_nn/rl/dqn/model.py b/labml_nn/rl/dqn/model.py index 3901fe34..4bace418 100644 --- a/labml_nn/rl/dqn/model.py +++ b/labml_nn/rl/dqn/model.py @@ -82,7 +82,7 @@ class Model(Module): nn.Linear(in_features=256, out_features=4), ) - def __call__(self, obs: torch.Tensor): + def forward(self, obs: torch.Tensor): # Convolution h = self.conv(obs) # Reshape for linear layers diff --git a/labml_nn/rl/ppo/__init__.py b/labml_nn/rl/ppo/__init__.py index 2f119eb3..2fe73115 100644 --- a/labml_nn/rl/ppo/__init__.py +++ b/labml_nn/rl/ppo/__init__.py @@ -136,8 +136,8 @@ class ClippedPPOLoss(Module): def __init__(self): super().__init__() - def __call__(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor, - advantage: torch.Tensor, clip: float) -> torch.Tensor: + def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor, + advantage: torch.Tensor, clip: float) -> torch.Tensor: # ratio $r_t(\theta) = \frac{\pi_\theta (a_t|s_t)}{\pi_{\theta_{OLD}} (a_t|s_t)}$; # *this is different from rewards* $r_t$. ratio = torch.exp(log_pi - sampled_log_pi) @@ -200,7 +200,8 @@ class ClippedValueFunctionLoss(Module): significantly from $V_{\theta_{OLD}}$. """ - def __call__(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float): + + def forward(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float): clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip) vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2) return 0.5 * vf_loss.mean() diff --git a/labml_nn/rl/ppo/experiment.py b/labml_nn/rl/ppo/experiment.py index f9c2ab55..f35702e1 100644 --- a/labml_nn/rl/ppo/experiment.py +++ b/labml_nn/rl/ppo/experiment.py @@ -69,7 +69,7 @@ class Model(Module): # self.activation = nn.ReLU() - def __call__(self, obs: torch.Tensor): + def forward(self, obs: torch.Tensor): h = self.activation(self.conv1(obs)) h = self.activation(self.conv2(h)) h = self.activation(self.conv3(h))