Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-14 09:31:42 +08:00)
__call__ -> forward
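Every hunk below makes the same change: the modules stop overriding `__call__` and define `forward` instead, which is the standard `torch.nn.Module` contract. `nn.Module.__call__` is what dispatches to `forward` and runs any registered hooks, so overriding `__call__` directly would bypass that machinery. A minimal sketch of the pattern, assuming `Module` in these files behaves like (or wraps) `torch.nn.Module`; `TinyBlock` and its layer sizes are illustrative, not repository code:

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Define `forward`; do not override `__call__` on an nn.Module."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


block = TinyBlock()
x = torch.randn(4, 8)

# Calling the instance goes through `nn.Module.__call__`,
# which runs registered hooks and then dispatches to `forward`.
y = block(x)

# Forward hooks only fire when `forward` is the method being wrapped;
# an overridden `__call__` would silently skip them.
block.register_forward_hook(lambda module, inp, out: print(out.shape))
_ = block(x)
```

Call sites are unaffected by the rename: training code keeps invoking the module instance directly, e.g. `generator(z)` or `discriminator_loss(f_real, f_fake)`.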
@@ -112,7 +112,7 @@ class GeneratorResNet(Module):
         # Initialize weights to $\mathcal{N}(0, 0.2)$
         self.apply(weights_init_normal)
 
-    def __call__(self, x):
+    def forward(self, x):
         return self.layers(x)
 
 
@@ -132,7 +132,7 @@ class ResidualBlock(Module):
             nn.ReLU(inplace=True),
         )
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         return x + self.block(x)
 
 
@@ -184,7 +184,7 @@ class DiscriminatorBlock(Module):
         layers.append(nn.LeakyReLU(0.2, inplace=True))
         self.layers = nn.Sequential(*layers)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         return self.layers(x)
 
 
@@ -53,7 +53,7 @@ class Generator(Module):
 
         self.apply(_weights_init)
 
-    def __call__(self, x):
+    def forward(self, x):
         # Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
         x = x.unsqueeze(-1).unsqueeze(-1)
         x = self.layers(x)
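The context above shows the `Generator` reshaping the latent batch from `[batch_size, 100]` to `[batch_size, 100, 1, 1]` before passing it to `self.layers`. A standalone check of what the two `unsqueeze(-1)` calls do (plain PyTorch, not repository code):

```python
import torch

z = torch.randn(16, 100)           # a batch of 100-dimensional latent vectors
z = z.unsqueeze(-1).unsqueeze(-1)  # append two singleton dimensions
print(z.shape)                     # torch.Size([16, 100, 1, 1])

# A stack of 2-D (transposed) convolutions expects 4-D input [N, C, H, W],
# so each latent vector is treated as a 1x1 "image" with 100 channels.
```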
@@ -76,7 +76,7 @@ class DiscriminatorLogitsLoss(Module):
         self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
         self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
 
-    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
+    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         """
         `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
         `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
@@ -111,7 +111,7 @@ class GeneratorLogitsLoss(Module):
         # the above gradient.
         self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
 
-    def __call__(self, logits: torch.Tensor):
+    def forward(self, logits: torch.Tensor):
         if len(logits) > len(self.fake_labels):
             self.register_buffer("fake_labels",
                                  _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
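In the `GeneratorLogitsLoss` hunk above, `fake_labels` is a pre-allocated buffer of 256 smoothed labels that gets re-registered with a larger size whenever a bigger batch arrives; the trailing `False` is `register_buffer`'s `persistent` argument, which keeps the labels out of the checkpoint. A hedged sketch of that caching idea; `_create_labels` below is a stand-in with an assumed signature, and the BCE-with-logits loss at the end is illustrative rather than necessarily what the repository computes:

```python
import torch
from torch import nn


def _create_labels(n: int, low: float, high: float,
                   device: torch.device = None) -> torch.Tensor:
    # Assumed behaviour: n smoothed labels drawn uniformly from [low, high].
    return torch.empty(n, device=device).uniform_(low, high)


class SmoothedTargetLoss(nn.Module):
    """Illustrative only: cache smoothed targets in a non-persistent buffer."""

    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        self.smoothing = smoothing
        self.register_buffer('fake_labels',
                             _create_labels(256, 1.0 - smoothing, 1.0),
                             persistent=False)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Re-register a larger buffer if the batch outgrows the cached labels.
        if len(logits) > len(self.fake_labels):
            self.register_buffer('fake_labels',
                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0,
                                                logits.device),
                                 persistent=False)
        return nn.functional.binary_cross_entropy_with_logits(
            logits, self.fake_labels[:len(logits)])
```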
@@ -101,7 +101,7 @@ class DiscriminatorLoss(Module):
     \frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$
     """
 
-    def __call__(self, f_real: torch.Tensor, f_fake: torch.Tensor):
+    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
         """
         * `f_real` is $f_w(x)$
         * `f_fake` is $f_w(g_\theta(z))$
@@ -127,7 +127,7 @@ class GeneratorLoss(Module):
 
     """
 
-    def __call__(self, f_fake: torch.Tensor):
+    def forward(self, f_fake: torch.Tensor):
         """
         * `f_fake` is $f_w(g_\theta(z))$
         """
@@ -54,7 +54,7 @@ class GradientPenalty(Module):
     ## Gradient Penalty
     """
 
-    def __call__(self, x: torch.Tensor, f: torch.Tensor):
+    def forward(self, x: torch.Tensor, f: torch.Tensor):
         """
         * `x` is $x \sim \mathbb{P}_r$
         * `f` is $D(x)$
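For context on the last hunk: the renamed `forward` receives `x` with $x \sim \mathbb{P}_r$ and `f` $= D(x)$, and the class is a gradient penalty. A heavily hedged sketch of what a `forward` with that signature typically computes, the standard $\big(\Vert \nabla_x D(x) \Vert_2 - 1\big)^2$ penalty; this is illustrative and not the repository's exact implementation:

```python
import torch
from torch import nn


class GradientPenaltySketch(nn.Module):
    """Illustrative: push the gradient norm of D(x) w.r.t. x towards 1."""

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        # Gradients of D(x) with respect to x. `create_graph=True` lets the
        # penalty itself be back-propagated through when training D.
        gradients, = torch.autograd.grad(outputs=f, inputs=x,
                                         grad_outputs=f.new_ones(f.shape),
                                         create_graph=True)
        gradients = gradients.reshape(batch_size, -1)
        norm = gradients.norm(2, dim=-1)
        # (||grad||_2 - 1)^2, averaged over the batch.
        return torch.mean((norm - 1.0) ** 2)


# Usage sketch: `x` must require gradients for autograd.grad to work.
# x = real_samples.clone().requires_grad_()
# f = discriminator(x)
# penalty = GradientPenaltySketch()(x, f)
```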