Mirror of https://github.com/helblazer811/ManimML.git (synced 2025-07-08 00:54:46 +08:00)
Finished VAE Disentanglement video

committed by Alec Helbling

parent a7c43d0f8b
commit ca0e93d657

8  Makefile
@@ -1,3 +1,6 @@
setup:
	conda activate manim
	export PROJECT_ROOT=$(pwd)
video:
	manim -pqh src/vae.py VAEScene --media_dir media
	cp media/videos/vae/720p60/VAEScene.mp4 examples
@@ -5,6 +8,11 @@ train:
	cd src/autoencoder_models
	python vanilla_autoencoder.py
	python variational_autoencoder.py
generate_visualizations:
	cd src/autoencoder_models
	python generate_images.py
	python generate_interpolation.py
	python generate_disentanglement.py
checkstyle:
	pycodestyle src
	pydocstyle src

BIN  examples/DisentanglementScene.gif  Normal file
Binary file not shown. After: 148 KiB
BIN  examples/DisentanglementScene.mp4  Normal file
Binary file not shown.
BIN  examples/TestNeuralNetworkScene.gif  Normal file
Binary file not shown. After: 137 KiB
BIN  examples/VAEScene.gif  Normal file
Binary file not shown. After: 991 KiB
BIN  src/autoencoder_models/.DS_Store  vendored  Normal file
Binary file not shown.
BIN  src/autoencoder_models/disentanglement.pkl  Normal file
Binary file not shown.

99  src/autoencoder_models/generate_disentanglement.py  Normal file
@@ -0,0 +1,99 @@
import pickle
import sys
import os
sys.path.append(os.environ["PROJECT_ROOT"])
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2

def binned_images(model_path, num_x_bins=10, plot=False):
    latent_dim = 2
    model = load_vae_from_path(model_path, latent_dim)
    image_dataset = load_dataset(digit=2)
    # Compute embedding
    num_images = 500
    embedding = []
    images = []
    for i in range(num_images):
        image, _ = image_dataset[i]
        mean, _, recon, _ = model.forward(image)
        mean = mean.detach().numpy()
        recon = recon.detach().numpy()
        recon = recon.reshape(32, 32)
        images.append(recon.squeeze())
        if latent_dim > 2:
            mean = mean[:2]
        embedding.append(mean)
    images = np.stack(images)
    tsne_points = np.array(embedding)
    tsne_points = (tsne_points - tsne_points.mean(axis=0))/(tsne_points.std(axis=0))
    # make vis
    num_points = np.shape(tsne_points)[0]
    x_min = np.amin(tsne_points.T[0])
    y_min = np.amin(tsne_points.T[1])
    y_max = np.amax(tsne_points.T[1])
    x_max = np.amax(tsne_points.T[0])
    # make the bins from the ranges
    # to keep it square the same width is used for x and y dim
    x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
    x_bins = x_bins.astype(float)
    num_y_bins = np.absolute(np.ceil((y_max - y_min)/step)).astype(int)
    y_bins = np.linspace(y_min, y_max, num_y_bins)
    # sort the tsne_points into a 2d histogram
    tsne_points = tsne_points.squeeze()
    hist_obj = scipy.stats.binned_statistic_dd(tsne_points, np.arange(num_points), statistic='count', bins=[x_bins, y_bins], expand_binnumbers=True)
    # sample one point from each bucket
    binnumbers = hist_obj.binnumber
    num_x_bins = np.amax(binnumbers[0]) + 1
    num_y_bins = np.amax(binnumbers[1]) + 1
    binnumbers = binnumbers.T
    # some places have no value in a region
    used_mask = np.zeros((num_y_bins, num_x_bins))
    image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
    for i, bin_num in enumerate(list(binnumbers)):
        used_mask[bin_num[1], bin_num[0]] = 1
        image_bins[bin_num[1], bin_num[0]] = images[i]
    # plot a grid of the images
    fig, axs = plt.subplots(nrows=np.shape(y_bins)[0], ncols=np.shape(x_bins)[0], constrained_layout=False, dpi=50)
    images = []
    bin_indices = []
    for y in range(num_y_bins):
        for x in range(num_x_bins):
            if used_mask[y, x] > 0.0:
                image = np.uint8(image_bins[y][x].squeeze()*255)
                image = np.rollaxis(image, 0, 3)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                axs[num_y_bins - 1 - y][x].imshow(image)
                images.append(image)
                bin_indices.append((y, x))
            axs[y, x].axis('off')
    if plot:
        plt.axis('off')
        plt.show()
    else:
        return images, bin_indices

def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
    """Generates disentanglement visualization and serializes it"""
    # Disentanglement object
    disentanglement_object = {}
    # Make Disentanglement
    images, bin_indices = binned_images(model_path)
    disentanglement_object["images"] = images
    disentanglement_object["bin_indices"] = bin_indices
    # Serialize Images
    with open("disentanglement.pkl", "wb") as f:
        pickle.dump(disentanglement_object, f)

if __name__ == "__main__":
    plot = False
    if plot:
        model_path = "saved_models/model_dim2.pth"
        #uniform_image_sample(model_path)
        binned_images(model_path)
    else:
        generate_disentanglement()

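For reference, the disentanglement.pkl written above can be read back in a few lines. This is a minimal sketch, not part of the commit: the "images" and "bin_indices" keys come from generate_disentanglement() above, everything else is illustrative.

import pickle

# Load the grid of decoded images serialized by generate_disentanglement()
with open("disentanglement.pkl", "rb") as f:
    disentanglement_object = pickle.load(f)

# Each decoded image is paired with the (row, column) bin it occupies in the latent grid;
# the images are 32x32 RGB uint8 arrays after the rollaxis/cvtColor steps above.
for image, (row, col) in zip(disentanglement_object["images"],
                             disentanglement_object["bin_indices"]):
    print(row, col, image.shape)
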
src/autoencoder_models/generate_interpolation.py
@@ -1,5 +1,5 @@
import torch
from variational_autoencoder import VAE
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
@@ -10,13 +10,7 @@ import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
                         train = True,
                         download = True,
                         transform = tensor_transform)
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
@@ -24,8 +18,8 @@ save_object = {"interpolation_path":[], "interpolation_images":[]}

# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
image_a = image_a.view(28*28)
image_b = image_b.view(28*28)
image_a = image_a.view(32*32)
image_b = image_b.view(32*32)
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
@@ -38,7 +32,7 @@ for i in range(num_images):
    # Generate
    z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
    gen_image = vae.decode(z).detach().numpy()
    gen_image = np.reshape(gen_image, (28, 28)) * 255
    gen_image = np.reshape(gen_image, (32, 32)) * 255
    save_object["interpolation_images"].append(gen_image)

fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))

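The construction of interpolation_path is not part of this hunk. Presumably it is a straight-line walk between the two encoded digits; a hypothetical reconstruction, with stand-in endpoints, might look like the sketch below (only num_images=50 and latent_dim=16 are taken from the file, the rest is illustrative).

import numpy as np

# Stand-ins for the encoded endpoints z_a and z_b computed above
latent_dim = 16
z_a = np.zeros(latent_dim)
z_b = np.ones(latent_dim)

# One latent point per generated frame, walked linearly from z_a to z_b
num_images = 50
interpolation_path = np.linspace(z_a, z_b, num_images)
print(interpolation_path.shape)  # (50, 16)
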
0  src/autoencoder_models/saved_models/model.pth → src/autoencoder_models/saved_models/model_dim16.pth
BIN  src/autoencoder_models/saved_models/model_dim2.pth  Normal file
Binary file not shown.
BIN  src/autoencoder_models/saved_models/model_dim2_cnn.pth  Normal file
Binary file not shown.
BIN  src/autoencoder_models/saved_models/model_dim5.pth  Normal file
Binary file not shown.

src/autoencoder_models/variational_autoencoder.py
@@ -1,81 +1,173 @@
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

"""
|
||||
These are utility functions that help to calculate the input and output
|
||||
sizes of convolutional neural networks
|
||||
"""
|
||||
|
||||
def num2tuple(num):
|
||||
return num if isinstance(num, tuple) else (num, num)
|
||||
|
||||
def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
|
||||
h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
|
||||
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
|
||||
pad = num2tuple(pad[0]), num2tuple(pad[1])
|
||||
|
||||
h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
|
||||
w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
|
||||
|
||||
return h, w
|
||||
|
||||
def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
|
||||
h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
|
||||
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
|
||||
pad = num2tuple(pad[0]), num2tuple(pad[1])
|
||||
|
||||
h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dialation[0]*(kernel_size[0]-1) + out_pad[0] + 1
|
||||
w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dialation[1]*(kernel_size[1]-1) + out_pad[1] + 1
|
||||
|
||||
return h, w
|
||||
|
||||
def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
|
||||
h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
|
||||
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)
|
||||
|
||||
p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
|
||||
p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)
|
||||
|
||||
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
|
||||
|
||||
def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
|
||||
h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
|
||||
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)
|
||||
|
||||
p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0]) / 2
|
||||
p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1]) / 2
|
||||
|
||||
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
|
||||
|
||||
def load_dataset(train=True, digit=None):
    # Transforms images to a PyTorch Tensor
    tensor_transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor()
    ])

    # Download the MNIST Dataset
    dataset = datasets.MNIST(root = "./data",
                             train = train,
                             download = True,
                             transform = tensor_transform)
    # Load specific image
    if digit is not None:
        idx = dataset.train_labels == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]

    return dataset

def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))

    return model

# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()

# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
                         train = True,
                         download = True,
                         transform = tensor_transform)

# DataLoader is used to load the dataset
# for training
loader = torch.utils.data.DataLoader(dataset = dataset,
                                     batch_size = 32,
                                     shuffle = True)
# Creating a PyTorch class
# 28*28 ==> 9 ==> 28*28
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=5):
    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        # Building a linear encoder with Linear
        # layer followed by Relu activation function
        # 784 ==> 9
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
        )
        self.mean_embedding = torch.nn.Linear(18, self.latent_dim)
        self.logvar_embedding = torch.nn.Linear(18, self.latent_dim)
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
            out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
            inputs = self.d * mul
            mul *= 2

        # Building a linear decoder with Linear
        # layer followed by Relu activation function
        # The Sigmoid activation function
        # outputs the value between 0 and 1
        # 9 ==> 784
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid()
        )
        self.d_max = inputs
        self.last_size = out_sizes[-1][-1]
        self.num_linear = self.last_size ** 2 * self.d_max
        # Encoder linear layers
        self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
        self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
        # Decoder linear layer
        self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)

    def decode(self, z):
        return self.decoder(z)
        mul = inputs // self.d // 2

        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            inputs = self.d * mul
            mul //= 2

        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))

    def encode(self, x):
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        if len(x.shape) < 4:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]

        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))

        x = x.view(batch_size, -1)

        mean = self.encoder_mean_linear(x)
        logvar = self.encoder_logvar_linear(x)

        return mean, logvar

    def decode(self, x):
        x = x.view(x.shape[0], self.latent_dim)
        x = self.decoder_linear(x)
        x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
        #x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)

        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
        x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
        x = torch.sigmoid(x)
        return x

    def forward(self, x):
        encoded = self.encoder(x)
        mean = self.mean_embedding(encoded)
        logvar = self.logvar_embedding(encoded)
        batch_size = x.shape[0]
        mean, logvar = self.encode(x)
        eps = torch.randn(batch_size, self.latent_dim)
        z = mean + torch.exp(logvar / 2) * eps
        reconstructed = self.decoder(z)
        reconstructed = self.decode(z)
        return mean, logvar, reconstructed, x

def train_model():
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
    dataset = load_dataset(train=True, digit=digit)
    # DataLoader is used to load the dataset
    # for training
    loader = torch.utils.data.DataLoader(dataset = dataset,
                                         batch_size = 32,
                                         shuffle = True)
    # Model Initialization
    model = VAE(latent_dim=16)
    model = VAE(latent_dim=latent_dim)
    # Validation using MSE Loss function
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.001):
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0)
        recon = torch.nn.functional.mse_loss(reconstructed, original)
        # print(f"KL Error {kl}, Recon Error {recon}")
@@ -83,16 +175,13 @@ def train_model():

    # Using an Adam Optimizer with lr = 0.1
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr = 1e-3,
                                 weight_decay = 1e-8)
                                 lr = 1e-4,
                                 weight_decay = 0e-8)

    epochs = 100
    outputs = []
    losses = []
    for epoch in tqdm(range(epochs)):
        for (image, _) in loader:
            # Reshaping the image to (-1, 784)
            image = image.reshape(-1, 28*28)
            # Output of Autoencoder
            mean, log_var, reconstructed, image = model(image)
            # Calculating the loss function
@@ -109,16 +198,17 @@ def train_model():
            losses.append(loss.detach().cpu())
            outputs.append((epochs, image, reconstructed))

    torch.save(model.state_dict(), "saved_models/model.pth")
    torch.save(model.state_dict(), f"saved_models/model_dim{latent_dim}.pth")

    # Defining the Plot Style
    plt.style.use('fivethirtyeight')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')

    # Plotting the last 100 values
    plt.plot(losses)
    plt.show()
    if plot:
        # Defining the Plot Style
        plt.style.use('fivethirtyeight')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')

        # Plotting the last 100 values
        plt.plot(losses)
        plt.show()

if __name__ == "__main__":
    train_model()
    train_model(latent_dim=2, digit=1, epochs=40)

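As a quick sanity check on the convolution-size helpers above (a sketch, assuming PROJECT_ROOT is on sys.path as in generate_disentanglement.py): each kernel_size=4, stride=2, pad=1 layer halves the spatial size, so the padded 32x32 MNIST input reaches 2x2 after the four conv layers built in VAE.__init__, which is where num_linear = last_size ** 2 * d_max comes from.

from autoencoder_models.variational_autoencoder import conv2d_output_shape

# Trace the spatial size through the four stride-2 convolutions
h_w = (32, 32)
for layer in range(4):
    h_w = conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1)
    print(layer + 1, h_w)  # 1 (16, 16), 2 (8, 8), 3 (4, 4), 4 (2, 2)
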
93  src/disentanglement.py  Normal file
@@ -0,0 +1,93 @@
"""This module is dedicated to visualizing VAE disentanglement"""
from manim import *
from neural_network import NeuralNetwork
import util
import pickle

class VAEDecoder(VGroup):
    """Just shows the VAE decoder"""

    def __init__(self):
        super(VGroup, self).__init__()
        # Setup the Neural Network
        node_counts = [3, 5]
        self.neural_network = NeuralNetwork(node_counts, layer_spacing=0.55)
        self.add(self.neural_network)

    def make_encoding_animation(self):
        pass

class DisentanglementVisualization(VGroup):

    def __init__(self, model_path="autoencoder_models/saved_models/model_dim2.pth", image_height=0.2):
        self.model_path = model_path
        self.image_height = image_height
        # Load disentanglement image objects
        with open("autoencoder_models/disentanglement.pkl", "rb") as f:
            self.image_handler = pickle.load(f)

    def make_disentanglement_generation_animation(self):
        animation_list = []
        for image_index, image in enumerate(self.image_handler["images"]):
            image_mobject = util.construct_image_mobject(image, height=self.image_height)
            r, c = self.image_handler["bin_indices"][image_index]
            # Move the image to the correct location
            r_offset = -1.2
            c_offset = 0.2
            image_location = [c_offset + c*self.image_height, r_offset + r*self.image_height, 0]
            image_mobject.move_to(image_location)
            animation_list.append(FadeIn(image_mobject))

        generation_animation = AnimationGroup(*animation_list[::-1], lag_ratio=1.0)
        return generation_animation

config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0

class DisentanglementScene(Scene):
    """Disentanglement Scene Object"""

    def _construct_embedding(self, point_color=BLUE, dot_radius=0.05):
        """Makes a Gaussian-like embedding"""
        embedding = VGroup()
        # Sample points from a Gaussian
        num_points = 200
        standard_deviation = [0.6, 1.0]
        mean = [0, 0]
        points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
        # Make an axes
        embedding.axes = Axes(
            x_range=[-3, 3],
            y_range=[-3, 3],
            x_length=2.2,
            y_length=2.2,
            tips=False,
        )
        # Add each point to the axes
        self.point_dots = VGroup()
        for point in points:
            point_location = embedding.axes.coords_to_point(*point)
            dot = Dot(point_location, color=point_color, radius=dot_radius/2)
            self.point_dots.add(dot)

        embedding.add(self.point_dots)
        return embedding

    def construct(self):
        # Make the VAE decoder
        vae_decoder = VAEDecoder()
        vae_decoder.shift([-0.65, 0, 0])
        self.play(Create(vae_decoder), run_time=1)
        # Make the embedding
        embedding = self._construct_embedding()
        embedding.scale(0.8)
        embedding.move_to(vae_decoder.get_left())
        embedding.shift([-0.7, 0, 0])
        self.play(Create(embedding))
        # Make disentanglement visualization
        disentanglement = DisentanglementVisualization()
        disentanglement_animation = disentanglement.make_disentanglement_generation_animation()
        self.play(disentanglement_animation, run_time=3)
        self.play(Wait(2))

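For intuition, the placement arithmetic in make_disentanglement_generation_animation maps a bin index to scene coordinates as below. This is a worked sketch using the defaults from the class above; the sample bin index is made up.

# With image_height=0.2, r_offset=-1.2, c_offset=0.2 as in DisentanglementVisualization,
# bin (r, c) = (3, 5) lands at x = 0.2 + 5*0.2 and y = -1.2 + 3*0.2.
image_height, r_offset, c_offset = 0.2, -1.2, 0.2
r, c = 3, 5
image_location = [c_offset + c * image_height, r_offset + r * image_height, 0]
print(image_location)  # approximately [1.2, -0.6, 0]
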
src/neural_network.py
@@ -1,11 +1,11 @@
"""Neural Network Manim Visualization

This module is responsible for generating a neural network visualization with
manim, specifically a fully connected neural network diagram.
manim, specifically a fully connected neural network diagram.

Example:
    # Specify how many nodes are in each node layer
    layer_node_count = [5, 3, 5]
    layer_node_count = [5, 3, 5]
    # Create the object with default style settings
    NeuralNetwork(layer_node_count)
"""
@@ -15,7 +15,7 @@ class NeuralNetworkLayer(VGroup):
    """Handles rendering a layer for a neural network"""

    def __init__(
            self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
            self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
            node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE,
            node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0,
            rectangle_stroke_width=2.0):
@@ -48,7 +48,7 @@ class NeuralNetworkLayer(VGroup):
            node_object.move_to([0, location, 0])
        # Create Surrounding Rectangle
        surrounding_rectangle = SurroundingRectangle(
            self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color,
            self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color,
            fill_opacity=1.0, buff=self.layer_buffer, stroke_width=self.rectangle_stroke_width
        )
        # Add the objects to the class
@@ -57,7 +57,7 @@ class NeuralNetworkLayer(VGroup):
class NeuralNetwork(VGroup):

    def __init__(
            self, layer_node_count, layer_width=0.6, node_radius=1.0,
            self, layer_node_count, layer_width=0.6, node_radius=1.0,
            node_color=BLUE, edge_color=WHITE, layer_spacing=0.8,
            animation_dot_color=RED, edge_width=2.0, dot_radius=0.05):
        super(VGroup, self).__init__()
@@ -73,7 +73,6 @@ class NeuralNetwork(VGroup):

        # TODO take layer_node_count [0, (1, 2), 0]
        # and make it have explicit distinct subspaces

        self.layers = self._construct_layers()
        self.edge_layers = self._construct_edges()

@@ -99,7 +98,7 @@ class NeuralNetwork(VGroup):
        edge_layers = VGroup()
        for layer_index in range(len(self.layer_node_count) - 1):
            current_layer = self.layers[layer_index]
            next_layer = self.layers[layer_index + 1]
            next_layer = self.layers[layer_index + 1]
            edge_layer = VGroup()
            # Go through each node in the two layers and make a connecting line
            for node_i in current_layer.node_group:
@@ -134,8 +133,8 @@ class NeuralNetwork(VGroup):

        return animation_group

config.pixel_height = 720
config.pixel_width = 1280
config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 6.0
config.frame_width = 6.0

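Building on the module docstring's example, a minimal scene that renders this class might look like the sketch below. The scene name and run_time are illustrative; NeuralNetwork([5, 3, 5]) and Create(...) follow the usage shown in the docstring and in DisentanglementScene above.

from manim import *
from neural_network import NeuralNetwork

class ExampleNeuralNetworkScene(Scene):
    def construct(self):
        # Three fully connected layers with 5, 3, and 5 nodes
        network = NeuralNetwork([5, 3, 5])
        self.play(Create(network), run_time=2)
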
15  src/util.py  Normal file
@@ -0,0 +1,15 @@
from manim import *
import numpy as np

def construct_image_mobject(input_image, height=2.3):
    """Constructs an ImageMobject from a numpy grayscale image"""
    # Convert image to rgb
    if len(input_image.shape) == 2:
        input_image = np.repeat(input_image, 3, axis=0)
        input_image = np.rollaxis(input_image, 0, start=3)
    # Make the ImageMobject
    image_mobject = ImageMobject(input_image, image_mode="RGB")
    image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
    image_mobject.height = height

    return image_mobject