mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
__call__ -> forward
This commit is contained in:
@ -112,7 +112,7 @@ class GeneratorResNet(Module):
|
|||||||
# Initialize weights to $\mathcal{N}(0, 0.2)$
|
# Initialize weights to $\mathcal{N}(0, 0.2)$
|
||||||
self.apply(weights_init_normal)
|
self.apply(weights_init_normal)
|
||||||
|
|
||||||
def __call__(self, x):
|
def forward(self, x):
|
||||||
return self.layers(x)
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ class ResidualBlock(Module):
|
|||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
return x + self.block(x)
|
return x + self.block(x)
|
||||||
|
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ class DiscriminatorBlock(Module):
|
|||||||
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
return self.layers(x)
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class Generator(Module):
|
|||||||
|
|
||||||
self.apply(_weights_init)
|
self.apply(_weights_init)
|
||||||
|
|
||||||
def __call__(self, x):
|
def forward(self, x):
|
||||||
# Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
|
# Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
|
||||||
x = x.unsqueeze(-1).unsqueeze(-1)
|
x = x.unsqueeze(-1).unsqueeze(-1)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
|
@ -76,7 +76,7 @@ class DiscriminatorLogitsLoss(Module):
|
|||||||
self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
|
self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
|
||||||
self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
|
self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
|
||||||
|
|
||||||
def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
|
def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
`logits_true` are logits from $D(\pmb{x}^{(i)})$ and
|
`logits_true` are logits from $D(\pmb{x}^{(i)})$ and
|
||||||
`logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
|
`logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
|
||||||
@ -111,7 +111,7 @@ class GeneratorLogitsLoss(Module):
|
|||||||
# the above gradient.
|
# the above gradient.
|
||||||
self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
|
self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
|
||||||
|
|
||||||
def __call__(self, logits: torch.Tensor):
|
def forward(self, logits: torch.Tensor):
|
||||||
if len(logits) > len(self.fake_labels):
|
if len(logits) > len(self.fake_labels):
|
||||||
self.register_buffer("fake_labels",
|
self.register_buffer("fake_labels",
|
||||||
_create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
|
_create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
|
||||||
|
@ -101,7 +101,7 @@ class DiscriminatorLoss(Module):
|
|||||||
\frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$
|
\frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, f_real: torch.Tensor, f_fake: torch.Tensor):
|
def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
* `f_real` is $f_w(x)$
|
* `f_real` is $f_w(x)$
|
||||||
* `f_fake` is $f_w(g_\theta(z))$
|
* `f_fake` is $f_w(g_\theta(z))$
|
||||||
@ -127,7 +127,7 @@ class GeneratorLoss(Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, f_fake: torch.Tensor):
|
def forward(self, f_fake: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
* `f_fake` is $f_w(g_\theta(z))$
|
* `f_fake` is $f_w(g_\theta(z))$
|
||||||
"""
|
"""
|
||||||
|
@ -54,7 +54,7 @@ class GradientPenalty(Module):
|
|||||||
## Gradient Penalty
|
## Gradient Penalty
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor, f: torch.Tensor):
|
def forward(self, x: torch.Tensor, f: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
* `x` is $x \sim \mathbb{P}_r$
|
* `x` is $x \sim \mathbb{P}_r$
|
||||||
* `f` is $D(x)$
|
* `f` is $D(x)$
|
||||||
|
Reference in New Issue
Block a user