mirror of https://github.com/helblazer811/ManimML.git
synced 2025-06-29 18:39:39 +08:00

commit fe7089abbf (parent 8140aec3be), committed by Alec Helbling

    Got VAE forward pass animation finished. TODO: Interpolation animation
Makefile (2 changed lines)

@@ -1,6 +1,6 @@
 video:
 	manim -pqh src/vae.py VAEScene --media_dir media
-	cp media/videos/vae/720p60/VAEScene.mp4 final_videos
+	cp media/videos/vae/1080p60/VAEScene.mp4 final_videos
 train:
 	cd src/autoencoder_models
 	python vanilla_autoencoder.py
Changed binary file not shown.
src/__init__.py (new file, empty)

src/autoencoder_models/__init__.py (new file, empty)
src/autoencoder_models/generate_images.py (new file, 41 lines)

@@ -0,0 +1,41 @@
+import torch
+from variational_autoencoder import VAE
+import matplotlib.pyplot as plt
+from torchvision import datasets
+from torchvision import transforms
+from tqdm import tqdm
+import numpy as np
+import pickle
+
+# Load model
+vae = VAE(latent_dim=16)
+vae.load_state_dict(torch.load("saved_models/model.pth"))
+# Transforms images to a PyTorch Tensor
+tensor_transform = transforms.ToTensor()
+# Download the MNIST Dataset
+dataset = datasets.MNIST(root = "./data",
+                         train = True,
+                         download = True,
+                         transform = tensor_transform)
+# Generate reconstructions
+num_recons = 10
+fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
+image_pairs = []
+for i in range(num_recons):
+    base_image, _ = dataset[i]
+    base_image = base_image.reshape(-1, 28*28)
+    _, _, recon_image, _ = vae.forward(base_image)
+    base_image = base_image.detach().numpy()
+    base_image = np.reshape(base_image, (28, 28)) * 255
+    recon_image = recon_image.detach().numpy()
+    recon_image = np.reshape(recon_image, (28, 28)) * 255
+    # Add to plot
+    axs[i][0].imshow(base_image)
+    axs[i][1].imshow(recon_image)
+    # image pairs
+    image_pairs.append((base_image, recon_image))
+
+with open("image_pairs.pkl", "wb") as f:
+    pickle.dump(image_pairs, f)
+
+plt.show()
src/autoencoder_models/generate_interpolation.py (new file, 56 lines)

@@ -0,0 +1,56 @@
+import torch
+from variational_autoencoder import VAE
+import matplotlib.pyplot as plt
+from torchvision import datasets
+from torchvision import transforms
+from tqdm import tqdm
+import numpy as np
+import pickle
+
+# Load model
+vae = VAE(latent_dim=16)
+vae.load_state_dict(torch.load("saved_models/model.pth"))
+# Transforms images to a PyTorch Tensor
+tensor_transform = transforms.ToTensor()
+# Download the MNIST Dataset
+dataset = datasets.MNIST(root = "./data",
+                         train = True,
+                         download = True,
+                         transform = tensor_transform)
+# Generate reconstructions
+num_images = 50
+image_pairs = []
+save_object = {"interpolation_path": [], "interpolation_images": []}
+
+# Make interpolation path
+image_a, image_b = dataset[0][0], dataset[1][0]
+image_a = image_a.view(28*28)
+image_b = image_b.view(28*28)
+z_a, _, _, _ = vae.forward(image_a)
+z_a = z_a.detach().cpu().numpy()
+z_b, _, _, _ = vae.forward(image_b)
+z_b = z_b.detach().cpu().numpy()
+interpolation_path = np.linspace(z_a, z_b, num=num_images)
+# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images)
+save_object["interpolation_path"] = interpolation_path
+
+for i in range(num_images):
+    # Generate
+    z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
+    gen_image = vae.decode(z)
+    save_object["interpolation_images"].append(gen_image)
+
+fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
+image_pairs = []
+for i in range(num_images):
+    im = save_object["interpolation_images"][i]
+    im = im.detach().numpy()
+    recon_image = np.reshape(im, (28, 28)) * 255
+    # Add to plot
+    axs[i].imshow(recon_image)
+
+# Save the interpolations
+with open("interpolations.pkl", "wb") as f:
+    pickle.dump(save_object, f)
+
+plt.show()
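Note on the interpolation above: np.linspace applied to two vectors interpolates elementwise, so interpolation_path traces the straight line z_t = (1 - t) * z_a + t * z_b through the latent space. A minimal standalone check of that behavior (the 2-dimensional endpoints are made up for illustration):

import numpy as np

z_a = np.array([0.0, 2.0])
z_b = np.array([4.0, -2.0])
path = np.linspace(z_a, z_b, num=5)  # shape (5, 2): one latent point per row
# The midpoint of the path is the elementwise average of the endpoints
assert np.allclose(path[2], (z_a + z_b) / 2)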
src/autoencoder_models/image_pairs.pkl (new file; binary, not shown)
src/autoencoder_models/interpolations.pkl (new file; binary, not shown)
src/autoencoder_models/saved_models/model.pth (new file; binary, not shown)
Deleted file (name not shown)

@@ -1,112 +0,0 @@
-import torch
-from torchvision import datasets
-from torchvision import transforms
-import matplotlib.pyplot as plt
-from tqdm import tqdm
-
-# Transforms images to a PyTorch Tensor
-tensor_transform = transforms.ToTensor()
-
-# Download the MNIST Dataset
-dataset = datasets.MNIST(root = "./data",
-                         train = True,
-                         download = True,
-                         transform = tensor_transform)
-
-# DataLoader is used to load the dataset
-# for training
-loader = torch.utils.data.DataLoader(dataset = dataset,
-                                     batch_size = 32,
-                                     shuffle = True)
-
-# Creating a PyTorch class
-# 28*28 ==> 9 ==> 28*28
-class VAE(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        # Building a linear encoder with Linear
-        # layer followed by Relu activation function
-        # 784 ==> 9
-        self.encoder = torch.nn.Sequential(
-            torch.nn.Linear(28 * 28, 128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(128, 64),
-            torch.nn.ReLU(),
-            torch.nn.Linear(64, 36),
-            torch.nn.ReLU(),
-            torch.nn.Linear(36, 18),
-            torch.nn.ReLU(),
-        )
-        self.mean_embedding = torch.nn.Linear(18, 9)
-        self.logvar_embedding = torch.nn.Linear(18, 9)
-
-        # Building a linear decoder with Linear
-        # layer followed by Relu activation function
-        # The Sigmoid activation function
-        # outputs the value between 0 and 1
-        # 9 ==> 784
-        self.decoder = torch.nn.Sequential(
-            torch.nn.Linear(9, 18),
-            torch.nn.ReLU(),
-            torch.nn.Linear(18, 36),
-            torch.nn.ReLU(),
-            torch.nn.Linear(36, 64),
-            torch.nn.ReLU(),
-            torch.nn.Linear(64, 128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(128, 28 * 28),
-            torch.nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        encoded = self.encoder(x)
-        mean = self.mean_embedding(encoded)
-        logvar = self.logvar_embedding(encoded)
-        combined = torch.cat((mean, logvar), dim=1)
-        reconstructed = self.decoder(combined)
-        return mean, logvar, reconstructed, x
-
-# Model Initialization
-model = VAE()
-# Validation using MSE Loss function
-def loss_function(mean, log_var, reconstructed, original):
-    kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0)
-    recon = torch.nn.functional.mse_loss(reconstructed, original)
-
-    return kl + recon
-
-# Using an Adam Optimizer with lr = 0.1
-optimizer = torch.optim.Adam(model.parameters(),
-                             lr = 1e-1,
-                             weight_decay = 1e-8)
-
-epochs = 10
-outputs = []
-losses = []
-for epoch in tqdm(range(epochs)):
-    for (image, _) in loader:
-        # Reshaping the image to (-1, 784)
-        image = image.reshape(-1, 28*28)
-        # Output of Autoencoder
-        mean, log_var, reconstructed, image = model(image)
-        # Calculating the loss function
-        loss = loss_function(mean, log_var, reconstructed, image)
-        # The gradients are set to zero,
-        # then the gradient is computed and stored.
-        # .step() performs parameter update
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        # Storing the losses in a list for plotting
-        losses.append(loss.detach().cpu())
-        outputs.append((epochs, image, reconstructed))
-
-# Defining the Plot Style
-plt.style.use('fivethirtyeight')
-plt.xlabel('Iterations')
-plt.ylabel('Loss')
-
-# Plotting the last 100 values
-print(losses)
-plt.plot(losses[-100:])
-plt.show()
Changed file (name not shown)

@@ -18,13 +18,12 @@ dataset = datasets.MNIST(root = "./data",
 loader = torch.utils.data.DataLoader(dataset = dataset,
                                      batch_size = 32,
                                      shuffle = True)
-
 # Creating a PyTorch class
 # 28*28 ==> 9 ==> 28*28
-class AE(torch.nn.Module):
-    def __init__(self):
+class VAE(torch.nn.Module):
+    def __init__(self, latent_dim=5):
         super().__init__()
+        self.latent_dim = latent_dim
         # Building a linear encoder with Linear
         # layer followed by Relu activation function
         # 784 ==> 9
@@ -37,8 +36,9 @@ class AE(torch.nn.Module):
             torch.nn.ReLU(),
             torch.nn.Linear(36, 18),
             torch.nn.ReLU(),
-            torch.nn.Linear(18, 9)
         )
+        self.mean_embedding = torch.nn.Linear(18, self.latent_dim)
+        self.logvar_embedding = torch.nn.Linear(18, self.latent_dim)

         # Building a linear decoder with Linear
         # layer followed by Relu activation function
@@ -46,7 +46,7 @@ class AE(torch.nn.Module):
         # outputs the value between 0 and 1
         # 9 ==> 784
         self.decoder = torch.nn.Sequential(
-            torch.nn.Linear(9, 18),
+            torch.nn.Linear(self.latent_dim, 18),
             torch.nn.ReLU(),
             torch.nn.Linear(18, 36),
             torch.nn.ReLU(),
@@ -58,31 +58,45 @@ class AE(torch.nn.Module):
             torch.nn.Sigmoid()
         )

+    def decode(self, z):
+        return self.decoder(z)
+
     def forward(self, x):
         encoded = self.encoder(x)
-        decoded = self.decoder(encoded)
-        return decoded
+        mean = self.mean_embedding(encoded)
+        logvar = self.logvar_embedding(encoded)
+        batch_size = x.shape[0]
+        eps = torch.randn(batch_size, self.latent_dim)
+        z = mean + torch.exp(logvar / 2) * eps
+        reconstructed = self.decoder(z)
+        return mean, logvar, reconstructed, x

-# Model Initialization
-model = AE()
-# Validation using MSE Loss function
-loss_function = torch.nn.MSELoss()
-# Using an Adam Optimizer with lr = 0.1
-optimizer = torch.optim.Adam(model.parameters(),
-                             lr = 1e-1,
+def train_model():
+    # Model Initialization
+    model = VAE(latent_dim=16)
+    # Validation using MSE Loss function
+    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.001):
+        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0)
+        recon = torch.nn.functional.mse_loss(reconstructed, original)
+        # print(f"KL Error {kl}, Recon Error {recon}")
+        return kl_beta * kl + recon
+
+    # Using an Adam Optimizer with lr = 1e-3
+    optimizer = torch.optim.Adam(model.parameters(),
+                                 lr = 1e-3,
                              weight_decay = 1e-8)

-epochs = 10
+    epochs = 100
     outputs = []
     losses = []
     for epoch in tqdm(range(epochs)):
         for (image, _) in loader:
             # Reshaping the image to (-1, 784)
             image = image.reshape(-1, 28*28)
             # Output of Autoencoder
-            reconstructed = model(image)
+            mean, log_var, reconstructed, image = model(image)
             # Calculating the loss function
-            loss = loss_function(reconstructed, image)
+            loss = loss_function(mean, log_var, reconstructed, image)
             # The gradients are set to zero,
             # then the gradient is computed and stored.
             # .step() performs parameter update
@@ -90,13 +104,21 @@ for epoch in tqdm(range(epochs)):
             loss.backward()
             optimizer.step()
             # Storing the losses in a list for plotting
+            if torch.isnan(loss):
+                raise Exception()
             losses.append(loss.detach().cpu())
             outputs.append((epochs, image, reconstructed))

-# Defining the Plot Style
-plt.style.use('fivethirtyeight')
-plt.xlabel('Iterations')
-plt.ylabel('Loss')
-
-# Plotting the last 100 values
-plt.plot(losses[-100:])
+    torch.save(model.state_dict(), "saved_models/model.pth")
+
+    # Defining the Plot Style
+    plt.style.use('fivethirtyeight')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+
+    # Plotting the loss values
+    plt.plot(losses)
+    plt.show()
+
+if __name__ == "__main__":
+    train_model()
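Context for the forward-pass change above: this is the standard reparameterization trick. Rather than sampling z directly from N(mean, exp(logvar)), which is not differentiable, the model samples eps from N(0, I) and computes z = mean + exp(logvar / 2) * eps, so gradients flow through mean and logvar. A minimal sketch of that step and the KL term from the loss (which the commit scales by kl_beta=0.001), with illustrative batch and latent sizes:

import torch

def reparameterize(mean, logvar):
    # std = exp(logvar / 2), so z ~ N(mean, exp(logvar)) while staying differentiable
    eps = torch.randn_like(mean)
    return mean + torch.exp(logvar / 2) * eps

def kl_divergence(mean, logvar):
    # KL(N(mean, var) || N(0, I)), summed over latent dims, averaged over the batch
    return torch.mean(-0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1))

mean = torch.zeros(32, 16, requires_grad=True)
logvar = torch.zeros(32, 16, requires_grad=True)
z = reparameterize(mean, logvar)   # shape (32, 16)
kl = kl_divergence(mean, logvar)   # exactly 0 for mean=0, logvar=0
kl.backward()                      # gradients reach mean and logvar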
Changed file (name not shown)

@@ -15,15 +15,18 @@ class NeuralNetworkLayer(VGroup):
     """Handles rendering a layer for a neural network"""

     def __init__(
-            self, num_nodes, layer_width=0.2, node_radius=0.12,
+            self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
             node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE,
-            node_spacing=0.4, rectangle_fill_color=BLACK):
+            node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0,
+            rectangle_stroke_width=2.0):
         super(VGroup, self).__init__()
         self.num_nodes = num_nodes
-        self.layer_width = layer_width
+        self.layer_buffer = layer_buffer
         self.node_radius = node_radius
         self.node_color = node_color
+        self.node_stroke_width = node_stroke_width
         self.node_outline_color = node_outline_color
+        self.rectangle_stroke_width = rectangle_stroke_width
         self.rectangle_color = rectangle_color
         self.node_spacing = node_spacing
         self.rectangle_fill_color = rectangle_fill_color
@@ -36,7 +39,7 @@ class NeuralNetworkLayer(VGroup):
         """Creates the neural network layer"""
         # Add Nodes
         for node_number in range(self.num_nodes):
-            node_object = Circle(radius=self.node_radius, color=self.node_color)
+            node_object = Circle(radius=self.node_radius, color=self.node_color, stroke_width=self.node_stroke_width)
             self.node_group.add(node_object)
         # Space the nodes
         # Assumes Vertical orientation
@@ -45,24 +48,28 @@ class NeuralNetworkLayer(VGroup):
             node_object.move_to([0, location, 0])
         # Create Surrounding Rectangle
         surrounding_rectangle = SurroundingRectangle(
-            self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color, fill_opacity=1.0)
+            self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color,
+            fill_opacity=1.0, buff=self.layer_buffer, stroke_width=self.rectangle_stroke_width
+        )
         # Add the objects to the class
         self.add(surrounding_rectangle, self.node_group)

 class NeuralNetwork(VGroup):

     def __init__(
-            self, layer_node_count, layer_width=1.0, node_radius=1.0,
-            node_color=BLUE, edge_color=WHITE, layer_spacing=1.2,
-            animation_dot_color=RED):
+            self, layer_node_count, layer_width=0.6, node_radius=1.0,
+            node_color=BLUE, edge_color=WHITE, layer_spacing=0.8,
+            animation_dot_color=RED, edge_width=2.0, dot_radius=0.05):
         super(VGroup, self).__init__()
         self.layer_node_count = layer_node_count
         self.layer_width = layer_width
         self.node_radius = node_radius
+        self.edge_width = edge_width
         self.node_color = node_color
         self.edge_color = edge_color
         self.layer_spacing = layer_spacing
         self.animation_dot_color = animation_dot_color
+        self.dot_radius = dot_radius

         # TODO take layer_node_count [0, (1, 2), 0]
         # and make it have explicit distinct subspaces
@@ -97,7 +104,7 @@ class NeuralNetwork(VGroup):
         # Go through each node in the two layers and make a connecting line
         for node_i in current_layer.node_group:
             for node_j in next_layer.node_group:
-                line = Line(node_i.get_center(), node_j.get_center(), color=self.edge_color)
+                line = Line(node_i.get_center(), node_j.get_center(), color=self.edge_color, stroke_width=self.edge_width)
                 edge_layer.add(line)
         edge_layers.add(edge_layer)
         # Handle layering
@@ -112,7 +119,7 @@ class NeuralNetwork(VGroup):
         for edge_layer in self.edge_layers:
             path_animations = []
             for edge in edge_layer:
-                dot = Dot(color=self.animation_dot_color, fill_opacity=1.0, radius=0.06)
+                dot = Dot(color=self.animation_dot_color, fill_opacity=1.0, radius=self.dot_radius)
                 # Handle layering
                 dot.set_z_index(1)
                 # Add to dots group
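The stroke_width and dot_radius knobs above feed the forward-propagation animation, which moves a dot along each edge of the network. Outside the full classes, the underlying Manim pattern is roughly the following (a standalone sketch, not the repo's actual API):

from manim import *

class EdgeDotDemo(Scene):
    def construct(self):
        node_a, node_b = Dot(LEFT * 2), Dot(RIGHT * 2)
        edge = Line(node_a.get_center(), node_b.get_center(), stroke_width=2.0)
        self.add(node_a, node_b, edge)
        # The travelling dot is drawn above the edge (z_index=1), mirroring
        # how NeuralNetwork layers its forward-pass dots over the lines
        dot = Dot(color=RED, radius=0.05)
        dot.set_z_index(1)
        self.play(MoveAlongPath(dot, edge), run_time=1.5)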
src/vae.py (124 changed lines)

@@ -4,46 +4,49 @@ In this module I define Manim visualizations for Variational Autoencoders
 and Traditional Autoencoders.

 """
+from configparser import Interpolation
 from typing_extensions import runtime
 from manim import *
+import pickle
 import numpy as np
 import neural_network

-class Autoencoder(VGroup):
-    """Traditional Autoencoder Manim Visualization"""
+class VariationalAutoencoder(Group):
+    """Variational Autoencoder Manim Visualization"""

-    def __init__(self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE, dot_radius=0.06):
-        super(VGroup, self).__init__()
+    def __init__(
+        self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE,
+        dot_radius=0.05, ellipse_stroke_width=2.0
+    ):
+        super(Group, self).__init__()
         self.encoder_nodes_per_layer = encoder_nodes_per_layer
         self.decoder_nodes_per_layer = decoder_nodes_per_layer
         self.point_color = point_color
         self.dot_radius = dot_radius
+        self.ellipse_stroke_width = ellipse_stroke_width
         # Make the VMobjects
         self.encoder, self.decoder = self._construct_encoder_and_decoder()
         self.embedding = self._construct_embedding()
-        # self.input_image, self.output_image = self._construct_input_output_images()
         # Setup the relative locations
         self.embedding.move_to(self.encoder)
         self.embedding.shift([1.1 * self.encoder.width, 0, 0])
         self.decoder.move_to(self.embedding)
         self.decoder.shift([self.decoder.width * 1.1, 0, 0])
-        # self.embedding.shift(self.encoder.width * 1.5)
-        # self.decoder.move_to(self.embedding.get_center())
         # Add the objects to the VAE object
         self.add(self.encoder)
         self.add(self.decoder)
         self.add(self.embedding)
-        # self.add(self.input_image)
-        # self.add(self.output_image)

     def _construct_encoder_and_decoder(self):
         """Makes the VAE encoder and decoder"""
         # Make the encoder
         layer_node_count = self.encoder_nodes_per_layer
-        encoder = neural_network.NeuralNetwork(layer_node_count)
+        encoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius)
+        encoder.scale(1.2)
         # Make the decoder
         layer_node_count = self.decoder_nodes_per_layer
-        decoder = neural_network.NeuralNetwork(layer_node_count)
+        decoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius)
+        decoder.scale(1.2)

         return encoder, decoder

@@ -59,55 +62,40 @@ class Autoencoder(VGroup):
         embedding.axes = Axes(
             x_range=[-3, 3],
             y_range=[-3, 3],
-            x_length=2.5,
-            y_length=2.5,
+            x_length=2.2,
+            y_length=2.2,
             tips=False,
         )
         # Add each point to the axes
         self.point_dots = VGroup()
         for point in points:
             point_location = embedding.axes.coords_to_point(*point)
-            dot = Dot(point_location, color=self.point_color, radius=self.dot_radius / 2)
+            dot = Dot(point_location, color=self.point_color, radius=self.dot_radius/2)
             self.point_dots.add(dot)

         embedding.add(self.point_dots)
         return embedding

-    def _construct_input_output_images(self, input_output_image_pairs):
+    def _construct_input_output_images(self, image_pair):
         """Places the input and output images for the AE"""
-        pass
+        # Takes the image pair
+        # image_pair is assumed to be [2, x, y]
+        input_image = image_pair[0][None, :, :]
+        recon_image = image_pair[1][None, :, :]
+        # Convert images to rgb
+        input_image = np.repeat(input_image, 3, axis=0)
+        input_image = np.rollaxis(input_image, 0, start=3)
+        recon_image = np.repeat(recon_image, 3, axis=0)
+        recon_image = np.rollaxis(recon_image, 0, start=3)
+        # Make image objects
+        input_image_object = ImageMobject(input_image, image_mode="RGB")
+        input_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+        input_image_object.height = 2
+        recon_image_object = ImageMobject(recon_image, image_mode="RGB")
+        recon_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+        recon_image_object.height = 2
+
+        return input_image_object, recon_image_object

-    def make_forward_pass_animation(self, run_time=2):
-        """Makes an animation of a forward pass through the VAE"""
-        per_unit_runtime = run_time // 3
-        # Make encoder forward pass
-        encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime)
-        # Make red dot in embedding
-        location = [1.0, 1.5]
-        location_point = self.embedding.axes.coords_to_point(*location)
-        # dot = Dot(location_point, color=RED)
-        # create_dot_animation = Create(dot, run_time=per_unit_runtime)
-        # Make decoder forward pass
-        decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
-        # Add the animations to the group
-        animation_group = Succession(
-            encoder_forward_pass,
-            create_dot_animation,
-            decoder_forward_pass,
-            lag_ratio=1,
-        )
-
-        return animation_group
-
-    def make_interpolation_animation(self):
-        """Makes animation of interpolating in the latent space"""
-        pass
-
-class VariationalAutoencoder(Autoencoder):
-    """Variational Autoencoder Manim Visualization"""
-
-    def __init__(self):
-        super().__init__()

     def make_dot_convergence_animation(self, location, run_time=1.5):
         """Makes dots converge on a specific location"""
@@ -141,9 +129,15 @@ class VariationalAutoencoder(Autoencoder):
         animation_group = AnimationGroup(*animations)
         return animation_group

-    def make_forward_pass_animation(self, run_time=1.5):
+    def make_forward_pass_animation(self, image_pair, run_time=1.5):
         """Overridden forward pass animation specific to a VAE"""
         per_unit_runtime = run_time
+        # Setup images
+        self.input_image, self.output_image = self._construct_input_output_images(image_pair)
+        self.input_image.move_to(self.encoder.get_left())
+        self.input_image.shift(LEFT)
+        self.output_image.move_to(self.decoder.get_right())
+        self.output_image.shift(RIGHT * 1.2)
         # Make encoder forward pass
         encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime)
         # Make red dot in embedding
@@ -158,7 +152,7 @@ class VariationalAutoencoder(Autoencoder):
         )
         # Make an ellipse centered at mean_point with std outline
         center_dot = Dot(mean_point, radius=self.dot_radius, color=GREEN)
-        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5)
+        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5, stroke_width=self.ellipse_stroke_width)
         ellipse.move_to(mean_point)
         ellipse_animation = AnimationGroup(
             GrowFromCenter(center_dot),
@@ -170,21 +164,47 @@ class VariationalAutoencoder(Autoencoder):
         decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
         # Add the animations to the group
         animation_group = AnimationGroup(
+            FadeIn(self.input_image),
             encoding_succesion,
             ellipse_animation,
             dot_divergence_animation,
             decoder_forward_pass,
+            FadeIn(self.output_image),
             lag_ratio=1,
         )

         return animation_group

+
+class MNISTImageHandler():
+    """Deals with loading serialized VAE mnist images from "autoencoder_models" """
+
+    def __init__(
+        self,
+        image_pairs_path="src/autoencoder_models/image_pairs.pkl",
+        interpolations_path="src/autoencoder_models/interpolations.pkl"
+    ):
+        self.image_pairs_path = image_pairs_path
+        self.interpolations_path = interpolations_path
+
+        self.image_pairs = []
+        self.interpolations = []
+
+        self.load_serialized_data()
+
+    def load_serialized_data(self):
+        with open(self.image_pairs_path, "rb") as f:
+            self.image_pairs = pickle.load(f)
+
+        with open(self.interpolations_path, "rb") as f:
+            self.interpolations = pickle.load(f)
+
 """
 The VAE Scene for the twitter video.
 """

 config.pixel_height = 720
-config.pixel_width = 720
+config.pixel_width = 1280
 config.frame_height = 10.0
 config.frame_width = 10.0
 # Set random seed so point distribution is constant
@@ -196,10 +216,12 @@ class VAEScene(Scene):
     def construct(self):
         # Set Scene config
         vae = VariationalAutoencoder()
+        mnist_image_handler = MNISTImageHandler()
+        image_pair = mnist_image_handler.image_pairs[2]
         vae.move_to(ORIGIN)
         vae.scale(1.2)
         self.add(vae)
-        forward_pass_animation = vae.make_forward_pass_animation()
+        forward_pass_animation = vae.make_forward_pass_animation(image_pair)
         self.play(forward_pass_animation)
         """
         autoencoder = Autoencoder()
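A note on the AnimationGroup(..., lag_ratio=1) pattern used in make_forward_pass_animation: with lag_ratio=1 each member animation starts only when the previous one finishes, so the group plays as a strict sequence (input image, encoder pass, embedding ellipse, dot divergence, decoder pass, output image). A minimal sketch of the same idea:

from manim import *

class SequencedGroupDemo(Scene):
    def construct(self):
        square = Square()
        # lag_ratio=1 turns a simultaneous group into a sequence,
        # comparable in effect to Succession for these animations
        self.play(AnimationGroup(
            FadeIn(square),
            square.animate.shift(RIGHT),
            FadeOut(square),
            lag_ratio=1,
        ))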