mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-07 16:50:09 +08:00
Used Black to reformat the code in the repository.
This commit is contained in:
@ -12,59 +12,137 @@ import math
|
||||
sizes of convolutional neural networks
|
||||
"""
|
||||
|
||||
|
||||
def num2tuple(num):
    """Return ``num`` as a tuple, duplicating a non-tuple value into a pair.

    Args:
        num: A tuple, which is returned unchanged whatever its length, or
            any other value, which is repeated to form ``(num, num)``.

    Returns:
        tuple: ``num`` itself when it is already a tuple, else ``(num, num)``.
    """
    return num if isinstance(num, tuple) else (num, num)
|
||||
|
||||
|
||||
def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """Compute the ``(height, width)`` produced by a 2D convolution.

    Each argument may be a scalar, which is applied to both axes, or a
    two-element tuple with one entry per axis. Each per-axis entry of ``pad``
    may itself be a pair of per-side amounts; a scalar entry is used for
    both sides.

    Per axis the result is::

        floor((size + total_pad - dilation * (kernel - 1) - 1) / stride + 1)

    Args:
        h_w: Input height and width.
        kernel_size: Kernel size.
        stride: Convolution stride.
        pad: Padding, as described above.
        dilation: Kernel dilation.

    Returns:
        tuple: ``(h, w)`` output height and width as integers.
    """
    h_w, kernel_size, stride, pad, dilation = (
        num2tuple(h_w),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(pad),
        num2tuple(dilation),
    )
    # Expand each axis' padding into a pair so that sum(pad[axis]) is the
    # total padding applied along that axis.
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = math.floor(
        (h_w[0] + sum(pad[0]) - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1
    )
    w = math.floor(
        (h_w[1] + sum(pad[1]) - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1
    )

    return h, w
|
||||
|
||||
def convtransp2d_output_shape(
    h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0
):
    """Compute the ``(height, width)`` produced by a 2D transposed convolution.

    Each argument may be a scalar, which is applied to both axes, or a
    two-element tuple with one entry per axis. Each per-axis entry of ``pad``
    may itself be a pair of per-side amounts; a scalar entry is used for
    both sides.

    Per axis the result is::

        (size - 1) * stride - total_pad + dilation * (kernel - 1) + out_pad + 1

    Args:
        h_w: Input height and width.
        kernel_size: Kernel size.
        stride: Stride of the transposed convolution.
        pad: Padding, as described above.
        dilation: Kernel dilation.
        out_pad: Extra size added to the output.

    Returns:
        tuple: ``(h, w)`` output height and width.
    """
    h_w, kernel_size, stride, pad, dilation, out_pad = (
        num2tuple(h_w),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(pad),
        num2tuple(dilation),
        num2tuple(out_pad),
    )
    # Expand each axis' padding into a pair so that sum(pad[axis]) is the
    # total padding removed along that axis.
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    # The local is named `dilation`; the misspelling `dialation` is not
    # defined anywhere and would raise NameError on every call.
    h = (
        (h_w[0] - 1) * stride[0]
        - sum(pad[0])
        + dilation[0] * (kernel_size[0] - 1)
        + out_pad[0]
        + 1
    )
    w = (
        (h_w[1] - 1) * stride[1]
        - sum(pad[1])
        + dilation[1] * (kernel_size[1] - 1)
        + out_pad[1]
        + 1
    )

    return h, w
|
||||
|
||||
|
||||
def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    """Compute the padding a 2D convolution needs to reach a target size.

    Each argument may be a scalar, which is applied to both axes, or a
    two-element tuple with one entry per axis.

    The total padding per axis is::

        (out - 1) * stride - in + dilation * (kernel - 1) + 1

    It is split into two per-side amounts; when the total is odd the second
    amount is the larger one.

    Args:
        h_w_in: Input height and width.
        h_w_out: Desired output height and width.
        kernel_size: Kernel size.
        stride: Convolution stride.
        dilation: Kernel dilation.

    Returns:
        tuple: ``((h_first, h_second), (w_first, w_second))`` where each pair
        sums to the total padding for that axis.
    """
    h_w_in, h_w_out, kernel_size, stride, dilation = (
        num2tuple(h_w_in),
        num2tuple(h_w_out),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(dilation),
    )

    # Total padding along each axis.
    p_h = (
        (h_w_out[0] - 1) * stride[0]
        - h_w_in[0]
        + dilation[0] * (kernel_size[0] - 1)
        + 1
    )
    p_w = (
        (h_w_out[1] - 1) * stride[1]
        - h_w_in[1]
        + dilation[1] * (kernel_size[1] - 1)
        + 1
    )

    # floor/ceil halves always add back up to the total.
    return (math.floor(p_h / 2), math.ceil(p_h / 2)), (
        math.floor(p_w / 2),
        math.ceil(p_w / 2),
    )
|
||||
|
||||
|
||||
def convtransp2d_get_padding(
    h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0
):
    """Compute the padding a 2D transposed convolution needs for a target size.

    Each argument may be a scalar, which is applied to both axes, or a
    two-element tuple with one entry per axis.

    The total padding per axis is::

        (in - 1) * stride + dilation * (kernel - 1) + out_pad + 1 - out

    It is split into two per-side amounts; when the total is odd the second
    amount is the larger one.

    Args:
        h_w_in: Input height and width.
        h_w_out: Desired output height and width.
        kernel_size: Kernel size.
        stride: Stride of the transposed convolution.
        dilation: Kernel dilation.
        out_pad: Extra size added to the output.

    Returns:
        tuple: ``((h_first, h_second), (w_first, w_second))`` where each pair
        sums to the total padding for that axis.
    """
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = (
        num2tuple(h_w_in),
        num2tuple(h_w_out),
        num2tuple(kernel_size),
        num2tuple(stride),
        num2tuple(dilation),
        num2tuple(out_pad),
    )

    # Total padding along each axis. It must NOT be halved here: the return
    # statement already splits it into two sides, and halving twice yields
    # pairs that sum to only half of the padding actually required.
    p_h = (
        (h_w_in[0] - 1) * stride[0]
        + dilation[0] * (kernel_size[0] - 1)
        + out_pad[0]
        + 1
        - h_w_out[0]
    )
    p_w = (
        (h_w_in[1] - 1) * stride[1]
        + dilation[1] * (kernel_size[1] - 1)
        + out_pad[1]
        + 1
        - h_w_out[1]
    )

    # floor/ceil halves always add back up to the total.
    return (math.floor(p_h / 2), math.ceil(p_h / 2)), (
        math.floor(p_w / 2),
        math.ceil(p_w / 2),
    )
|
||||
|
||||
def load_dataset(train=True, digit=None):
|
||||
# Transforms images to a PyTorch Tensor
|
||||
tensor_transform = transforms.Compose([
|
||||
transforms.Pad(2),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
tensor_transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
|
||||
|
||||
# Download the MNIST Dataset
|
||||
dataset = datasets.MNIST(root = "./data",
|
||||
train = train,
|
||||
download = True,
|
||||
transform = tensor_transform)
|
||||
dataset = datasets.MNIST(
|
||||
root="./data", train=train, download=True, transform=tensor_transform
|
||||
)
|
||||
# Load specific image
|
||||
if not digit is None:
|
||||
idx = dataset.train_labels == digit
|
||||
@ -73,12 +151,14 @@ def load_dataset(train=True, digit=None):
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def load_vae_from_path(path, latent_dim):
    """Build a ``VAE`` and load its weights from a saved state dict.

    Args:
        path: Location of the file passed to ``torch.load``.
        latent_dim: Latent dimension passed to the ``VAE`` constructor.

    Returns:
        VAE: The model with the loaded state dict applied.
    """
    model = VAE(latent_dim)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(path))

    return model
|
||||
|
||||
|
||||
# Creating a PyTorch class
|
||||
# 28*28 ==> 9 ==> 28*28
|
||||
class VAE(torch.nn.Module):
|
||||
@ -96,13 +176,15 @@ class VAE(torch.nn.Module):
|
||||
setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
|
||||
setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
|
||||
h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
|
||||
out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
|
||||
out_sizes.append(
|
||||
conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1)
|
||||
)
|
||||
inputs = self.d * mul
|
||||
mul *= 2
|
||||
|
||||
self.d_max = inputs
|
||||
self.last_size = out_sizes[-1][-1]
|
||||
self.num_linear = self.last_size ** 2 * self.d_max
|
||||
self.num_linear = self.last_size**2 * self.d_max
|
||||
# Encoder linear layers
|
||||
self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
|
||||
self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
|
||||
@ -112,12 +194,20 @@ class VAE(torch.nn.Module):
|
||||
mul = inputs // self.d // 2
|
||||
|
||||
for i in range(1, self.layer_count):
|
||||
setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
|
||||
setattr(
|
||||
self,
|
||||
"deconv%d" % (i + 1),
|
||||
nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1),
|
||||
)
|
||||
setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
|
||||
inputs = self.d * mul
|
||||
mul //= 2
|
||||
|
||||
setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))
|
||||
setattr(
|
||||
self,
|
||||
"deconv%d" % (self.layer_count + 1),
|
||||
nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1),
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
if len(x.shape) < 3:
|
||||
@ -127,7 +217,11 @@ class VAE(torch.nn.Module):
|
||||
batch_size = x.shape[0]
|
||||
|
||||
for i in range(self.layer_count):
|
||||
x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))
|
||||
x = F.relu(
|
||||
getattr(self, "conv%d_bn" % (i + 1))(
|
||||
getattr(self, "conv%d" % (i + 1))(x)
|
||||
)
|
||||
)
|
||||
|
||||
x = x.view(batch_size, -1)
|
||||
|
||||
@ -140,15 +234,20 @@ class VAE(torch.nn.Module):
|
||||
x = x.view(x.shape[0], self.latent_dim)
|
||||
x = self.decoder_linear(x)
|
||||
x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
|
||||
#x = self.deconv1_bn(x)
|
||||
# x = self.deconv1_bn(x)
|
||||
x = F.leaky_relu(x, 0.2)
|
||||
|
||||
for i in range(1, self.layer_count):
|
||||
x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
|
||||
x = F.leaky_relu(
|
||||
getattr(self, "deconv%d_bn" % (i + 1))(
|
||||
getattr(self, "deconv%d" % (i + 1))(x)
|
||||
),
|
||||
0.2,
|
||||
)
|
||||
x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
mean, logvar = self.encode(x)
|
||||
@ -157,26 +256,25 @@ class VAE(torch.nn.Module):
|
||||
reconstructed = self.decode(z)
|
||||
return mean, logvar, reconstructed, x
|
||||
|
||||
|
||||
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
|
||||
dataset = load_dataset(train=True, digit=digit)
|
||||
# DataLoader is used to load the dataset
|
||||
# DataLoader is used to load the dataset
|
||||
# for training
|
||||
loader = torch.utils.data.DataLoader(dataset = dataset,
|
||||
batch_size = 32,
|
||||
shuffle = True)
|
||||
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)
|
||||
# Model Initialization
|
||||
model = VAE(latent_dim=latent_dim)
|
||||
# Training loss: MSE reconstruction error plus a KL term weighted by kl_beta
|
||||
def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
    """Return the MSE reconstruction error plus a weighted KL term.

    The KL term is ``-0.5 * sum(1 + log_var - mean**2 - exp(log_var))``
    summed over dimension 1 and then averaged over dimension 0.

    Args:
        mean: Tensor of latent means; must have at least two dimensions.
        log_var: Tensor of latent log-variances, same shape as ``mean``.
        reconstructed: Model output compared against ``original``.
        original: Target tensor for the reconstruction.
        kl_beta: Weight applied to the KL term.

    Returns:
        torch.Tensor: ``kl_beta * kl + mse(reconstructed, original)``.
    """
    kl = torch.mean(
        -0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp(), dim=1), dim=0
    )
    recon = torch.nn.functional.mse_loss(reconstructed, original)
    return kl_beta * kl + recon
|
||||
|
||||
# Using an Adam Optimizer with lr = 1e-4
|
||||
optimizer = torch.optim.Adam(model.parameters(),
|
||||
lr = 1e-4,
|
||||
weight_decay = 0e-8)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0e-8)
|
||||
|
||||
outputs = []
|
||||
losses = []
|
||||
@ -198,22 +296,24 @@ def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
|
||||
losses.append(loss.detach().cpu())
|
||||
outputs.append((epochs, image, reconstructed))
|
||||
|
||||
torch.save(model.state_dict(),
|
||||
torch.save(
|
||||
model.state_dict(),
|
||||
os.path.join(
|
||||
os.environ["PROJECT_ROOT"],
|
||||
f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth"
|
||||
)
|
||||
os.environ["PROJECT_ROOT"],
|
||||
f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth",
|
||||
),
|
||||
)
|
||||
|
||||
if plot:
|
||||
# Defining the Plot Style
|
||||
plt.style.use('fivethirtyeight')
|
||||
plt.xlabel('Iterations')
|
||||
plt.ylabel('Loss')
|
||||
|
||||
plt.style.use("fivethirtyeight")
|
||||
plt.xlabel("Iterations")
|
||||
plt.ylabel("Loss")
|
||||
|
||||
# Plotting all recorded loss values
|
||||
plt.plot(losses)
|
||||
plt.show()
|
||||
|
||||
|
||||
# Script entry point: train a small VAE on a single MNIST digit.
if __name__ == "__main__":
    train_model(latent_dim=2, digit=2, epochs=40)
|
||||
|
Reference in New Issue
Block a user