import math
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm

"""
These are utility functions that help to calculate the input and output sizes
of convolutional neural networks.
"""


def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)


def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    h_w, kernel_size, stride, pad, dilation = (
        num2tuple(h_w),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(pad),
        num2tuple(dilation),
    )
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = math.floor(
        (h_w[0] + sum(pad[0]) - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1
    )
    w = math.floor(
        (h_w[1] + sum(pad[1]) - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1
    )

    return h, w


def convtransp2d_output_shape(
    h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0
):
    h_w, kernel_size, stride, pad, dilation, out_pad = (
        num2tuple(h_w),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(pad),
        num2tuple(dilation),
        num2tuple(out_pad),
    )
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = (
        (h_w[0] - 1) * stride[0]
        - sum(pad[0])
        + dilation[0] * (kernel_size[0] - 1)
        + out_pad[0]
        + 1
    )
    w = (
        (h_w[1] - 1) * stride[1]
        - sum(pad[1])
        + dilation[1] * (kernel_size[1] - 1)
        + out_pad[1]
        + 1
    )

    return h, w


def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    h_w_in, h_w_out, kernel_size, stride, dilation = (
        num2tuple(h_w_in),
        num2tuple(h_w_out),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(dilation),
    )

    p_h = (
        (h_w_out[0] - 1) * stride[0]
        - h_w_in[0]
        + dilation[0] * (kernel_size[0] - 1)
        + 1
    )
    p_w = (
        (h_w_out[1] - 1) * stride[1]
        - h_w_in[1]
        + dilation[1] * (kernel_size[1] - 1)
        + 1
    )

    return (math.floor(p_h / 2), math.ceil(p_h / 2)), (
        math.floor(p_w / 2),
        math.ceil(p_w / 2),
    )


def convtransp2d_get_padding(
    h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0
):
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = (
        num2tuple(h_w_in),
        num2tuple(h_w_out),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(dilation),
        num2tuple(out_pad),
    )

    p_h = (
        -(
            h_w_out[0]
            - 1
            - out_pad[0]
            - dilation[0] * (kernel_size[0] - 1)
            - (h_w_in[0] - 1) * stride[0]
        )
        / 2
    )
    p_w = (
        -(
            h_w_out[1]
            - 1
            - out_pad[1]
            - dilation[1] * (kernel_size[1] - 1)
            - (h_w_in[1] - 1) * stride[1]
        )
        / 2
    )

    return (math.floor(p_h), math.ceil(p_h)), (math.floor(p_w), math.ceil(p_w))


def load_dataset(train=True, digit=None):
    # Pad the 28x28 MNIST images to 32x32 and convert them to tensors
    tensor_transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])

    # Download the MNIST dataset
    dataset = datasets.MNIST(
        root="./data", train=train, download=True, transform=tensor_transform
    )

    # Optionally keep only the images of a single digit
    if digit is not None:
        idx = dataset.targets == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]
    return dataset


def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))
    return model


# Convolutional VAE: 32x32 image ==> latent vector ==> 32x32 reconstruction
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            # Each encoder block halves the spatial resolution (kernel 4, stride 2, pad 1)
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
"conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul)) h_w = (out_sizes[-1][-1], out_sizes[-1][-1]) out_sizes.append( conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1) ) inputs = self.d * mul mul *= 2 self.d_max = inputs self.last_size = out_sizes[-1][-1] self.num_linear = self.last_size**2 * self.d_max # Encoder linear layers self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim) self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim) # Decoder linear layer self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear) mul = inputs // self.d // 2 for i in range(1, self.layer_count): setattr( self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1), ) setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul)) inputs = self.d * mul mul //= 2 setattr( self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1), ) def encode(self, x): if len(x.shape) < 3: x = x.unsqueeze(0) if len(x.shape) < 4: x = x.unsqueeze(1) batch_size = x.shape[0] for i in range(self.layer_count): x = F.relu( getattr(self, "conv%d_bn" % (i + 1))( getattr(self, "conv%d" % (i + 1))(x) ) ) x = x.view(batch_size, -1) mean = self.encoder_mean_linear(x) logvar = self.encoder_logvar_linear(x) return mean, logvar def decode(self, x): x = x.view(x.shape[0], self.latent_dim) x = self.decoder_linear(x) x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size) # x = self.deconv1_bn(x) x = F.leaky_relu(x, 0.2) for i in range(1, self.layer_count): x = F.leaky_relu( getattr(self, "deconv%d_bn" % (i + 1))( getattr(self, "deconv%d" % (i + 1))(x) ), 0.2, ) x = getattr(self, "deconv%d" % (self.layer_count + 1))(x) x = torch.sigmoid(x) return x def forward(self, x): batch_size = x.shape[0] mean, logvar = self.encode(x) eps = torch.randn(batch_size, self.latent_dim) z = mean + torch.exp(logvar / 2) * eps reconstructed = self.decode(z) return mean, logvar, reconstructed, x def train_model(latent_dim=16, plot=True, digit=1, epochs=200): dataset = load_dataset(train=True, digit=digit) # DataLoader is used to load the dataset # for training loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True) # Model Initialization model = VAE(latent_dim=latent_dim) # Validation using MSE Loss function def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001): kl = torch.mean( -0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp(), dim=1), dim=0 ) recon = torch.nn.functional.mse_loss(reconstructed, original) # print(f"KL Error {kl}, Recon Error {recon}") return kl_beta * kl + recon # Using an Adam Optimizer with lr = 0.1 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0e-8) outputs = [] losses = [] for epoch in tqdm(range(epochs)): for (image, _) in loader: # Output of Autoencoder mean, log_var, reconstructed, image = model(image) # Calculating the loss function loss = loss_function(mean, log_var, reconstructed, image) # The gradients are set to zero, # the the gradient is computed and stored. 
            # .step() performs the parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Abort if training has diverged
            if torch.isnan(loss):
                raise RuntimeError("Loss became NaN during training")

            # Storing the losses in a list for plotting
            losses.append(loss.detach().cpu())
            outputs.append((epoch, image, reconstructed))

    torch.save(
        model.state_dict(),
        os.path.join(
            os.environ["PROJECT_ROOT"],
            f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth",
        ),
    )

    if plot:
        # Defining the plot style
        plt.style.use("fivethirtyeight")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")

        # Plotting the losses
        plt.plot(losses)
        plt.show()


if __name__ == "__main__":
    train_model(latent_dim=2, digit=2, epochs=40)
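

# A minimal sketch of how a trained checkpoint could be used to generate new
# digits: draw latent vectors from the standard normal prior and decode them.
# The checkpoint path and latent_dim passed in are assumptions and must match
# whatever train_model() actually saved.
def sample_digits(path, latent_dim=2, n_samples=8):
    model = load_vae_from_path(path, latent_dim)
    model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim)
        images = model.decode(z)  # shape (n_samples, 1, 32, 32), values in [0, 1]
    for i in range(n_samples):
        plt.subplot(1, n_samples, i + 1)
        plt.imshow(images[i, 0].cpu().numpy(), cmap="gray")
        plt.axis("off")
    plt.show()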