__call__ -> forward

Author: Varuna Jayasiri
Date:   2021-08-19 15:47:25 +05:30
Parent: eaa248c9e6
Commit: 02309fa4cc

5 changed files with 9 additions and 9 deletions
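
In PyTorch, subclasses of `nn.Module` are expected to define `forward` rather than override `__call__`: `nn.Module.__call__` runs any registered hooks and then dispatches to `forward`, so the `model(x)` call syntax is unchanged while hooks keep working. A minimal sketch of the convention this commit adopts (the `Scale` module below is a made-up example, not code from this repository):

import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical module, used only to illustrate the `forward` convention."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    # Define `forward`, not `__call__`: `nn.Module.__call__` runs registered
    # hooks and then calls `forward`, so `model(x)` still works as before.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


m = Scale(2.0)
y = m(torch.ones(3))  # dispatches through nn.Module.__call__ -> forward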


@@ -112,7 +112,7 @@ class GeneratorResNet(Module):
         # Initialize weights to $\mathcal{N}(0, 0.2)$
         self.apply(weights_init_normal)
-    def __call__(self, x):
+    def forward(self, x):
         return self.layers(x)
@@ -132,7 +132,7 @@ class ResidualBlock(Module):
             nn.ReLU(inplace=True),
         )
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         return x + self.block(x)
@@ -184,7 +184,7 @@ class DiscriminatorBlock(Module):
         layers.append(nn.LeakyReLU(0.2, inplace=True))
         self.layers = nn.Sequential(*layers)
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         return self.layers(x)
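
The `self.apply(weights_init_normal)` call above initializes weights from $\mathcal{N}(0, 0.2)$, as the comment says. A rough sketch of what such an init function can look like (the repository's actual helper may cover different layer types or use a different constant):

from torch import nn


def weights_init_normal(m: nn.Module):
    # Sketch only: draw convolution and batch-norm weights from a normal
    # distribution with std 0.2, matching the comment in the diff above.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.2)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.2)
        nn.init.zeros_(m.bias)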


@@ -53,7 +53,7 @@ class Generator(Module):
         self.apply(_weights_init)
-    def __call__(self, x):
+    def forward(self, x):
         # Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
         x = x.unsqueeze(-1).unsqueeze(-1)
         x = self.layers(x)
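
The two `unsqueeze` calls above add the spatial dimensions a convolutional generator expects; a quick shape check:

import torch

z = torch.randn(4, 100)            # [batch_size, 100]
z = z.unsqueeze(-1).unsqueeze(-1)  # [batch_size, 100, 1, 1]
assert z.shape == (4, 100, 1, 1)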


@@ -76,7 +76,7 @@ class DiscriminatorLogitsLoss(Module):
         self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
         self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
-    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
+    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         """
         `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
         `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
@@ -111,7 +111,7 @@ class GeneratorLogitsLoss(Module):
         # the above gradient.
         self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
-    def __call__(self, logits: torch.Tensor):
+    def forward(self, logits: torch.Tensor):
         if len(logits) > len(self.fake_labels):
             self.register_buffer("fake_labels",
                                  _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
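
The label buffers registered above serve as smoothed targets for a binary cross-entropy loss on the discriminator logits. A minimal sketch of that idea, assuming `_create_labels(n, low, high)` samples `n` labels uniformly from `[low, high]` (an assumption about its interface, not confirmed by this diff):

import torch
import torch.nn.functional as F


def _create_labels(n, low, high, device=None):
    # Assumed behaviour: n smoothed labels drawn uniformly from [low, high].
    return torch.empty(n, device=device).uniform_(low, high)


def discriminator_loss(logits_true, logits_false, smoothing=0.2):
    # Real samples target labels near 1, generated samples target labels near 0.
    labels_true = _create_labels(len(logits_true), 1.0 - smoothing, 1.0, logits_true.device)
    labels_false = _create_labels(len(logits_false), 0.0, smoothing, logits_false.device)
    return (F.binary_cross_entropy_with_logits(logits_true, labels_true),
            F.binary_cross_entropy_with_logits(logits_false, labels_false))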


@@ -101,7 +101,7 @@ class DiscriminatorLoss(Module):
     \frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$
     """
-    def __call__(self, f_real: torch.Tensor, f_fake: torch.Tensor):
+    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
         """
         * `f_real` is $f_w(x)$
         * `f_fake` is $f_w(g_\theta(z))$
@@ -127,7 +127,7 @@ class GeneratorLoss(Module):
     """
-    def __call__(self, f_fake: torch.Tensor):
+    def forward(self, f_fake: torch.Tensor):
         """
         * `f_fake` is $f_w(g_\theta(z))$
         """


@@ -54,7 +54,7 @@ class GradientPenalty(Module):
     ## Gradient Penalty
     """
-    def __call__(self, x: torch.Tensor, f: torch.Tensor):
+    def forward(self, x: torch.Tensor, f: torch.Tensor):
         """
         * `x` is $x \sim \mathbb{P}_r$
         * `f` is $D(x)$
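
`GradientPenalty.forward` takes `x` sampled from $\mathbb{P}_r$ and `f = D(x)`; penalties of this kind are computed by differentiating `f` with respect to `x` and pushing the per-sample gradient norm towards 1. A rough sketch of that computation (an illustration of the technique, not the repository's exact implementation; `x` must have `requires_grad=True` when `D(x)` is evaluated):

import torch


def gradient_penalty(x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    # Gradient of the critic output w.r.t. its input, keeping the graph so
    # the penalty term itself can be backpropagated.
    grad, = torch.autograd.grad(outputs=f,
                                inputs=x,
                                grad_outputs=torch.ones_like(f),
                                create_graph=True)
    # Per-sample L2 norm of the gradient, penalised towards 1.
    norm = grad.reshape(grad.shape[0], -1).norm(2, dim=-1)
    return ((norm - 1) ** 2).mean()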