🧹 generator input shape

commit ac6864b8fc
parent dff635566a
Author: Varuna Jayasiri
Date:   2020-10-28 11:28:45 +05:30


@@ -39,11 +39,8 @@ class GeneratorResNet(Module):
     The generator is a residual network.
     """
-    def __init__(self, input_shape: Tuple[int, int, int], n_residual_blocks: int):
+    def __init__(self, input_channels: int, n_residual_blocks: int):
         super().__init__()
-        # The number of channels in the input image, which is 3 for RGB images.
-        channels = input_shape[0]
         # This first block runs a $7\times7$ convolution and maps the image to
         # a feature map.
         # The output feature map has same height and width because we have
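The old signature accepted the full (channels, height, width) tuple but only ever read input_shape[0]: the generator is fully convolutional, so its layers depend on the channel count alone, and the height and width were dead weight. A minimal sketch of that property, assuming PyTorch and a 3-channel input (the sizes below are illustrative, not taken from this diff):

    import torch
    import torch.nn as nn

    # A fully convolutional block like the generator's first layer only needs
    # the channel count; the spatial size can vary freely at run time.
    first_block = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=7, padding=3, padding_mode='reflect'),
        nn.InstanceNorm2d(64),
        nn.ReLU(inplace=True),
    )

    # The same module accepts any height and width:
    print(first_block(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 64, 256, 256])
    print(first_block(torch.randn(1, 3, 128, 96)).shape)   # torch.Size([1, 64, 128, 96])

Hence the constructor now takes the channel count directly, as in the setup_models call sites further down in this commit.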
@@ -53,7 +50,7 @@ class GeneratorResNet(Module):
         # `inplace=True` in `ReLU` saves a little bit of memory.
         out_features = 64
         layers = [
-            nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
+            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
             nn.InstanceNorm2d(out_features),
             nn.ReLU(inplace=True),
         ]
@@ -88,7 +85,7 @@ class GeneratorResNet(Module):
             in_features = out_features
         # Finally we map the feature map to an RGB image
-        layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
+        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
         # Create a sequential module with the layers
         self.layers = nn.Sequential(*layers)
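For reference, the output block this hunk renames maps the final feature map back to an image; Tanh bounds the pixels to [-1, 1], the same range the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) preprocessing produces. A small self-contained sketch of just that step (the 64-channel, 256x256 sizes are illustrative):

    import torch
    import torch.nn as nn

    # Map a 64-channel feature map back to a 3-channel image; Tanh keeps
    # the output in [-1, 1], matching the normalized input range.
    to_rgb = nn.Sequential(
        nn.Conv2d(64, 3, 7, padding=3, padding_mode='reflect'),
        nn.Tanh(),
    )

    img = to_rgb(torch.randn(1, 64, 256, 256))
    print(img.shape)                                            # torch.Size([1, 3, 256, 256])
    print(img.min().item() >= -1.0, img.max().item() <= 1.0)    # True True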
@@ -555,8 +552,8 @@ def setup_models(self: Configs):
     input_shape = (self.img_channels, self.img_height, self.img_width)
     # Create the models
-    self.generator_xy = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
-    self.generator_yx = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
+    self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
+    self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
     self.discriminator_x = Discriminator(input_shape).to(self.device)
     self.discriminator_y = Discriminator(input_shape).to(self.device)
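Note that the discriminators still take the full input_shape. In a PatchGAN-style discriminator, the height and width are typically needed to compute the output patch shape used for the GAN loss targets, so the full tuple cannot be reduced to a channel count the way the generator's could. The Discriminator body is not part of this diff, so the sketch below is a hypothetical illustration of that pattern, not the repository's actual class:

    import torch.nn as nn

    class Discriminator(nn.Module):
        """Hypothetical PatchGAN-style discriminator, for illustration only."""

        def __init__(self, input_shape):
            super().__init__()
            channels, height, width = input_shape
            # Four stride-2 convolutions shrink the input by 2**4, so the
            # output patch shape depends on height and width. This is why
            # the discriminator, unlike the generator, needs the full shape.
            self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
            layers = []
            in_channels = channels
            for out_channels in (64, 128, 256, 512):
                layers += [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
                           nn.LeakyReLU(0.2, inplace=True)]
                in_channels = out_channels
            # Pad asymmetrically so the final 4x4 convolution yields exactly
            # the patch size recorded in output_shape.
            layers += [nn.ZeroPad2d((1, 0, 1, 0)),
                       nn.Conv2d(in_channels, 1, 4, padding=1)]
            self.layers = nn.Sequential(*layers)

        def forward(self, x):
            return self.layers(x)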
@@ -680,9 +677,6 @@ def sample():
     # Image transformations
     transforms_ = [
-        transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
-        transforms.RandomCrop((conf.img_height, conf.img_width)),
-        transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     ]
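The removed Resize/RandomCrop/RandomHorizontalFlip are random training-time augmentations, so dropping them from sample() makes sampling deterministic; only tensor conversion and normalization remain. A minimal sketch of the resulting inference-time pipeline, assuming torchvision and a hypothetical input image path:

    from PIL import Image
    import torchvision.transforms as transforms

    # Deterministic preprocessing for sampling: no random crops or flips.
    preprocess = transforms.Compose([
        transforms.ToTensor(),                                    # [0, 255] -> [0.0, 1.0]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),   # [0, 1] -> [-1, 1]
    ])

    x = preprocess(Image.open('input.jpg'))  # 'input.jpg' is a hypothetical path
    # x is now in [-1, 1], the same range as the generator's Tanh output.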