Used Black to reformat the code in the repository.

This commit is contained in:
Alec Helbling
2023-01-01 23:24:59 -05:00
parent 334662e8c8
commit 3d6e8072e1
71 changed files with 1701 additions and 1135 deletions


@@ -12,59 +12,137 @@ import math
sizes of convolutional neural networks
"""
def num2tuple(num):
return num if isinstance(num, tuple) else (num, num)
def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
h_w, kernel_size, stride, pad, dilation = (
num2tuple(h_w),
num2tuple(kernel_size),
num2tuple(stride),
num2tuple(pad),
num2tuple(dilation),
)
pad = num2tuple(pad[0]), num2tuple(pad[1])
h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
h = math.floor(
(h_w[0] + sum(pad[0]) - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1
)
w = math.floor(
(h_w[1] + sum(pad[1]) - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1
)
return h, w
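As a quick, illustrative sanity check (not part of the commit): with the kernel 4, stride 2, padding 1 settings this VAE uses throughout, the formula above halves each spatial dimension.
# Assumes num2tuple and conv2d_output_shape above are in scope.
# floor((32 + 2*1 - 1*(4 - 1) - 1) / 2 + 1) = 16 per dimension.
print(conv2d_output_shape((32, 32), kernel_size=4, stride=2, pad=1))  # (16, 16)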
def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
def convtransp2d_output_shape(
h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0
):
h_w, kernel_size, stride, pad, dilation, out_pad = (
num2tuple(h_w),
num2tuple(kernel_size),
num2tuple(stride),
num2tuple(pad),
num2tuple(dilation),
num2tuple(out_pad),
)
pad = num2tuple(pad[0]), num2tuple(pad[1])
h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dilation[0]*(kernel_size[0]-1) + out_pad[0] + 1
w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dilation[1]*(kernel_size[1]-1) + out_pad[1] + 1
h = (
(h_w[0] - 1) * stride[0]
- sum(pad[0])
+ dilation[0] * (kernel_size[0] - 1)
+ out_pad[0]
+ 1
)
w = (
(h_w[1] - 1) * stride[1]
- sum(pad[1])
+ dilation[1] * (kernel_size[1] - 1)
+ out_pad[1]
+ 1
)
return h, w
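Conversely (again illustrative), the transposed-convolution formula with the same 4/2/1 settings doubles each dimension, inverting the example above.
# (16 - 1)*2 - 2*1 + 1*(4 - 1) + 0 + 1 = 32 per dimension.
print(convtransp2d_output_shape((16, 16), kernel_size=4, stride=2, pad=1))  # (32, 32)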
def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)
p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
h_w_in, h_w_out, kernel_size, stride, dilation = (
num2tuple(h_w_in),
num2tuple(h_w_out),
num2tuple(kernel_size),
num2tuple(stride),
num2tuple(dilation),
)
p_h = (
(h_w_out[0] - 1) * stride[0]
- h_w_in[0]
+ dilation[0] * (kernel_size[0] - 1)
+ 1
)
p_w = (
(h_w_out[1] - 1) * stride[1]
- h_w_in[1]
+ dilation[1] * (kernel_size[1] - 1)
+ 1
)
return (math.floor(p_h / 2), math.ceil(p_h / 2)), (
math.floor(p_w / 2),
math.ceil(p_w / 2),
)
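For example (illustrative): the padding that keeps a 32x32 map at 32x32 under a 3x3, stride-1 convolution works out to one pixel on each side.
# p_h = (32 - 1)*1 - 32 + 1*(3 - 1) + 1 = 2, split as (1, 1); same for p_w.
print(conv2d_get_padding((32, 32), (32, 32), kernel_size=3, stride=1))  # ((1, 1), (1, 1))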
def convtransp2d_get_padding(
h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0
):
h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = (
num2tuple(h_w_in),
num2tuple(h_w_out),
num2tuple(kernel_size),
num2tuple(stride),
num2tuple(dilation),
num2tuple(out_pad),
)
p_h = -(
h_w_out[0]
- 1
- out_pad[0]
- dilation[0] * (kernel_size[0] - 1)
- (h_w_in[0] - 1) * stride[0]
)
p_w = -(
h_w_out[1]
- 1
- out_pad[1]
- dilation[1] * (kernel_size[1] - 1)
- (h_w_in[1] - 1) * stride[1]
)
return (math.floor(p_h / 2), math.ceil(p_h / 2)), (
math.floor(p_w / 2),
math.ceil(p_w / 2),
)
def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)
p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0])
p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1])
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
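An illustrative check of the transposed variant: recovering the padding that lets a 4x4-kernel, stride-2 transposed convolution map 16x16 back to 32x32.
# p_h = -(32 - 1 - 0 - 1*(4 - 1) - (16 - 1)*2) = 2, split as (1, 1); same for p_w.
print(convtransp2d_get_padding((16, 16), (32, 32), kernel_size=4, stride=2))  # ((1, 1), (1, 1))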
def load_dataset(train=True, digit=None):
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.Compose([
transforms.Pad(2),
transforms.ToTensor()
])
tensor_transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
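# Pad(2) grows MNIST's 28x28 digits to 32x32 so repeated stride-2 convolutions halve the size evenly.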
# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
train = train,
download = True,
transform = tensor_transform)
dataset = datasets.MNIST(
root="./data", train=train, download=True, transform=tensor_transform
)
# Load specific image
if digit is not None:
idx = dataset.targets == digit
@@ -73,12 +151,14 @@ def load_dataset(train=True, digit=None):
return dataset
def load_vae_from_path(path, latent_dim):
model = VAE(latent_dim)
model.load_state_dict(torch.load(path))
return model
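Typical usage of this loader (illustrative; the path below is only an example of the checkpoints train_model writes):
# Restore a trained VAE from disk and switch it to inference mode.
model = load_vae_from_path("saved_models/model_dim16.pth", latent_dim=16)
model.eval()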
# Creating a PyTorch class
# 32*32 (padded MNIST) image ==> latent_dim ==> 32*32 image
class VAE(torch.nn.Module):
@@ -96,13 +176,15 @@ class VAE(torch.nn.Module):
setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
out_sizes.append(
conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1)
)
inputs = self.d * mul
mul *= 2
self.d_max = inputs
self.last_size = out_sizes[-1][-1]
self.num_linear = self.last_size ** 2 * self.d_max
self.num_linear = self.last_size**2 * self.d_max
# Encoder linear layers
self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
@@ -112,12 +194,20 @@ class VAE(torch.nn.Module):
mul = inputs // self.d // 2
for i in range(1, self.layer_count):
setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
setattr(
self,
"deconv%d" % (i + 1),
nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1),
)
setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
inputs = self.d * mul
mul //= 2
setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))
setattr(
self,
"deconv%d" % (self.layer_count + 1),
nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1),
)
def encode(self, x):
if len(x.shape) < 3:
@@ -127,7 +217,11 @@ class VAE(torch.nn.Module):
batch_size = x.shape[0]
for i in range(self.layer_count):
x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))
x = F.relu(
getattr(self, "conv%d_bn" % (i + 1))(
getattr(self, "conv%d" % (i + 1))(x)
)
)
x = x.view(batch_size, -1)
@@ -140,15 +234,20 @@ class VAE(torch.nn.Module):
x = x.view(x.shape[0], self.latent_dim)
x = self.decoder_linear(x)
x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
#x = self.deconv1_bn(x)
# x = self.deconv1_bn(x)
x = F.leaky_relu(x, 0.2)
for i in range(1, self.layer_count):
x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
x = F.leaky_relu(
getattr(self, "deconv%d_bn" % (i + 1))(
getattr(self, "deconv%d" % (i + 1))(x)
),
0.2,
)
x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
x = torch.sigmoid(x)
return x
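Because decode maps any latent vector to an image, sampling from the standard-normal prior generates new digits; a minimal sketch, assuming a trained model:
import torch

model = VAE(latent_dim=16)
z = torch.randn(1, 16)  # one sample from the prior
with torch.no_grad():
    img = model.decode(z)  # (batch, channels, height, width)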
def forward(self, x):
batch_size = x.shape[0]
mean, logvar = self.encode(x)
@@ -157,26 +256,25 @@ class VAE(torch.nn.Module):
reconstructed = self.decode(z)
return mean, logvar, reconstructed, x
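The sampling step between encode and decode falls outside this hunk; a sketch of the conventional reparameterization trick (an assumption, since the diff truncates the actual line):
import torch

def reparameterize(mean, logvar):  # hypothetical helper, for illustration
    # z = mean + sigma * eps with eps ~ N(0, I), keeping the sample
    # differentiable with respect to the encoder outputs.
    eps = torch.randn_like(logvar)
    return mean + torch.exp(0.5 * logvar) * eps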
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
dataset = load_dataset(train=True, digit=digit)
# DataLoader is used to load the dataset
# for training
loader = torch.utils.data.DataLoader(dataset = dataset,
batch_size = 32,
shuffle = True)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)
# Model Initialization
model = VAE(latent_dim=latent_dim)
# Validation using MSE Loss function
def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0)
kl = torch.mean(
-0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp(), dim=1), dim=0
)
recon = torch.nn.functional.mse_loss(reconstructed, original)
# print(f"KL Error {kl}, Recon Error {recon}")
return kl_beta * kl + recon
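A quick, illustrative check of the KL term: a posterior that already matches the prior (mean 0, logvar 0) contributes zero.
import torch

mean, log_var = torch.zeros(1, 16), torch.zeros(1, 16)
kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp(), dim=1), dim=0)
print(kl)  # tensor(0.)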
# Using an Adam optimizer with lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(),
lr = 1e-4,
weight_decay = 0e-8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0e-8)
outputs = []
losses = []
@@ -198,22 +296,24 @@ def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
losses.append(loss.detach().cpu())
outputs.append((epochs, image, reconstructed))
torch.save(model.state_dict(),
torch.save(
model.state_dict(),
os.path.join(
os.environ["PROJECT_ROOT"],
f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth"
)
os.environ["PROJECT_ROOT"],
f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth",
),
)
if plot:
# Defining the Plot Style
plt.style.use('fivethirtyeight')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.style.use("fivethirtyeight")
plt.xlabel("Iterations")
plt.ylabel("Loss")
# Plotting the loss curve
plt.plot(losses)
plt.show()
if __name__ == "__main__":
train_model(latent_dim=2, digit=2, epochs=40)
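Note that train_model saves checkpoints under os.environ["PROJECT_ROOT"], so that variable must be set before running; one way to do it from Python (illustrative):
import os

# Fall back to the current directory if PROJECT_ROOT is not already set.
os.environ.setdefault("PROJECT_ROOT", os.getcwd())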