# Mirror of https://github.com/helblazer811/ManimML.git (synced 2025-07-15)
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm

"""
These are utility functions that help to calculate the input and output
sizes of convolutional layers.
"""

def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)

def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    h_w, kernel_size, stride, pad, dilation = (
        num2tuple(h_w),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(pad),
        num2tuple(dilation),
    )
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = math.floor((h_w[0] + sum(pad[0]) - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    w = math.floor((h_w[1] + sum(pad[1]) - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

    return h, w

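# Illustrative check (not in the original file): with the 4x4 kernel,
# stride 2, padding 1 configuration used by the VAE below, each conv layer
# halves the spatial size, e.g.
#   conv2d_output_shape(32, kernel_size=4, stride=2, pad=1) == (16, 16)
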
def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
    h_w, kernel_size, stride, pad, dilation, out_pad = (
        num2tuple(h_w),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(pad),
        num2tuple(dilation),
        num2tuple(out_pad),
    )
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = (h_w[0] - 1) * stride[0] - sum(pad[0]) + dilation[0] * (kernel_size[0] - 1) + out_pad[0] + 1
    w = (h_w[1] - 1) * stride[1] - sum(pad[1]) + dilation[1] * (kernel_size[1] - 1) + out_pad[1] + 1

    return h, w

def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    h_w_in, h_w_out, kernel_size, stride, dilation = (
        num2tuple(h_w_in),
        num2tuple(h_w_out),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(dilation),
    )

    p_h = (h_w_out[0] - 1) * stride[0] - h_w_in[0] + dilation[0] * (kernel_size[0] - 1) + 1
    p_w = (h_w_out[1] - 1) * stride[1] - h_w_in[1] + dilation[1] * (kernel_size[1] - 1) + 1

    return (math.floor(p_h / 2), math.ceil(p_h / 2)), (math.floor(p_w / 2), math.ceil(p_w / 2))

def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = (
        num2tuple(h_w_in),
        num2tuple(h_w_out),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(dilation),
        num2tuple(out_pad),
    )

    p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0] * (kernel_size[0] - 1) - (h_w_in[0] - 1) * stride[0]) / 2
    p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1] * (kernel_size[1] - 1) - (h_w_in[1] - 1) * stride[1]) / 2

    return (math.floor(p_h / 2), math.ceil(p_h / 2)), (math.floor(p_w / 2), math.ceil(p_w / 2))

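# Illustrative check (not in the original file): to keep a 32x32 input at
# 32x32 with a 3x3 kernel and stride 1, one pixel of padding is needed on
# each side:
#   conv2d_get_padding(32, 32, kernel_size=3, stride=1) == ((1, 1), (1, 1))
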
def load_dataset(train=True, digit=None):
    # Pad the 28x28 MNIST images to 32x32 and convert them to PyTorch tensors
    tensor_transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
    ])

    # Download the MNIST dataset
    dataset = datasets.MNIST(
        root="./data",
        train=train,
        download=True,
        transform=tensor_transform,
    )

    # Optionally keep only the images of a single digit
    if digit is not None:
        idx = dataset.targets == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]

    return dataset

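# For example, load_dataset(train=True, digit=3) returns only the "3" images,
# each padded to a 32x32 single-channel tensor.
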
def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))

    return model

# Convolutional variational autoencoder for 32x32 single-channel images:
# the encoder produces a latent mean and log-variance, and the decoder
# reconstructs the image from a latent sample.
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128

        # Encoder: a stack of stride-2 convolutions, each halving the spatial
        # size while widening the channel dimension
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
            out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
            inputs = self.d * mul
            mul *= 2

        # With the defaults (in_shape=32, layer_count=4, d=128) the spatial
        # size shrinks 32 -> 16 -> 8 -> 4 -> 2, so last_size = 2,
        # d_max = 1024 and num_linear = 2 * 2 * 1024 = 4096
        self.d_max = inputs
        self.last_size = out_sizes[-1][-1]
        self.num_linear = self.last_size ** 2 * self.d_max

        # Encoder linear layers producing the latent mean and log-variance
        self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
        self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)

        # Decoder linear layer mapping a latent vector back to a flattened
        # feature map
        self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)

        # Decoder: a mirrored stack of stride-2 transposed convolutions that
        # upsamples back to the input resolution
        mul = inputs // self.d // 2
        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            inputs = self.d * mul
            mul //= 2

        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))

    def encode(self, x):
        # Accept single images (H, W) or (C, H, W) by adding the missing
        # batch and channel dimensions
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        if len(x.shape) < 4:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]

        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))

        x = x.view(batch_size, -1)

        mean = self.encoder_mean_linear(x)
        logvar = self.encoder_logvar_linear(x)

        return mean, logvar

    def decode(self, x):
        x = x.view(x.shape[0], self.latent_dim)
        x = self.decoder_linear(x)
        x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
        # x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)

        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)

        x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
        x = torch.sigmoid(x)
        return x

    def forward(self, x):
        batch_size = x.shape[0]
        mean, logvar = self.encode(x)
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I),
        # where sigma = exp(logvar / 2)
        eps = torch.randn(batch_size, self.latent_dim)
        z = mean + torch.exp(logvar / 2) * eps
        reconstructed = self.decode(z)
        return mean, logvar, reconstructed, x

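# Minimal usage sketch (illustrative only, not part of the training script):
#
#   model = VAE(latent_dim=2)
#   images = torch.randn(8, 1, 32, 32)          # a dummy batch of 32x32 images
#   mean, logvar, reconstructed, _ = model(images)
#   # mean, logvar: (8, 2); reconstructed: (8, 1, 32, 32)
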
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
    dataset = load_dataset(train=True, digit=digit)
    # DataLoader is used to load the dataset for training
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)

    # Model initialization
    model = VAE(latent_dim=latent_dim)

    # VAE loss: MSE reconstruction error plus a KL term weighted by kl_beta
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1), dim=0)
        recon = torch.nn.functional.mse_loss(reconstructed, original)
        # print(f"KL Error {kl}, Recon Error {recon}")
        return kl_beta * kl + recon
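    # For reference: for a diagonal Gaussian q(z|x) = N(mean, sigma^2) with
    # log_var = log(sigma^2), the KL divergence from the standard normal
    # prior has the closed form used above:
    #   KL(q || N(0, I)) = -0.5 * sum(1 + log_var - mean^2 - exp(log_var))
    # summed over latent dimensions and averaged over the batch.
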
    # Using an Adam optimizer with lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0e-8)

    outputs = []
    losses = []
    for epoch in tqdm(range(epochs)):
        for (image, _) in loader:
            # Output of the autoencoder
            mean, log_var, reconstructed, image = model(image)

            # Calculating the loss function
            loss = loss_function(mean, log_var, reconstructed, image)

            # The gradients are set to zero, then the gradients are computed
            # and stored, and .step() performs the parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Storing the losses in a list for plotting
            if torch.isnan(loss):
                raise Exception("Loss became NaN during training")
            losses.append(loss.detach().cpu())
        outputs.append((epochs, image, reconstructed))

    torch.save(
        model.state_dict(),
        os.path.join(
            os.environ["PROJECT_ROOT"],
            f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth",
        ),
    )

    if plot:
        # Defining the plot style
        plt.style.use('fivethirtyeight')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')

        # Plotting the per-batch losses
        plt.plot(losses)
        plt.show()

if __name__ == "__main__":
    train_model(latent_dim=2, digit=2, epochs=40)