__call__ -> forward

This commit is contained in:
Varuna Jayasiri
2021-08-19 15:51:41 +05:30
parent bac2bf85e2
commit 0a4b5b6822
5 changed files with 13 additions and 12 deletions

View File

@ -82,7 +82,7 @@ class Model(Module):
nn.Linear(in_features=256, out_features=4),
)
def __call__(self, obs: torch.Tensor):
def forward(self, obs: torch.Tensor):
# Convolution
h = self.conv(obs)
# Reshape for linear layers