mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 17:57:14 +08:00
🧹 generator input shape
This commit is contained in:
@ -39,11 +39,8 @@ class GeneratorResNet(Module):
|
||||
The generator is a residual network.
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape: Tuple[int, int, int], n_residual_blocks: int):
|
||||
def __init__(self, input_channels: int, n_residual_blocks: int):
|
||||
super().__init__()
|
||||
# The number of channels in the input image, which is 3 for RGB images.
|
||||
channels = input_shape[0]
|
||||
|
||||
# This first block runs a $7\times7$ convolution and maps the image to
|
||||
# a feature map.
|
||||
# The output feature map has same height and width because we have
|
||||
@ -53,7 +50,7 @@ class GeneratorResNet(Module):
|
||||
# `inplace=True` in `ReLU` saves a little bit of memory.
|
||||
out_features = 64
|
||||
layers = [
|
||||
nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
|
||||
nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
|
||||
nn.InstanceNorm2d(out_features),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
@ -88,7 +85,7 @@ class GeneratorResNet(Module):
|
||||
in_features = out_features
|
||||
|
||||
# Finally we map the feature map to an RGB image
|
||||
layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
|
||||
layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
|
||||
|
||||
# Create a sequential module with the layers
|
||||
self.layers = nn.Sequential(*layers)
|
||||
@ -555,8 +552,8 @@ def setup_models(self: Configs):
|
||||
input_shape = (self.img_channels, self.img_height, self.img_width)
|
||||
|
||||
# Create the models
|
||||
self.generator_xy = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
|
||||
self.generator_yx = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
|
||||
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||
self.discriminator_x = Discriminator(input_shape).to(self.device)
|
||||
self.discriminator_y = Discriminator(input_shape).to(self.device)
|
||||
|
||||
@ -680,9 +677,6 @@ def sample():
|
||||
|
||||
# Image transformations
|
||||
transforms_ = [
|
||||
transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
|
||||
transforms.RandomCrop((conf.img_height, conf.img_width)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user