ManimML/examples/disentanglement/disentanglement.py

"""This module is dedicated to visualizing VAE disentanglement"""
from pathlib import Path
from manim import *
from manim_ml.neural_network.layers import FeedForwardLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
import pickle
ROOT_DIR = Path(__file__).parents[2]
def construct_image_mobject(input_image, height=2.3):
    """Constructs an ImageMobject from a numpy grayscale image"""
    # Convert a 2D grayscale image to an (H, W, 3) RGB array
    if len(input_image.shape) == 2:
        input_image = np.repeat(input_image[np.newaxis, :, :], 3, axis=0)
        input_image = np.rollaxis(input_image, 0, start=3)
    # Make the ImageMobject and keep the pixels sharp with nearest-neighbour resampling
    image_mobject = ImageMobject(input_image, image_mode="RGB")
    image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
    image_mobject.height = height
    return image_mobject
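
# A minimal usage sketch (hypothetical input; the 28x28 shape is only an assumed
# example of a small grayscale image, e.g. an MNIST-style digit):
#   img = np.zeros((28, 28), dtype=np.uint8)
#   image_mobject = construct_image_mobject(img, height=1.0)
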

class DisentanglementVisualization(VGroup):
    def __init__(
        self,
        model_path=ROOT_DIR
        / "examples/variational_autoencoder/autoencoder_models/saved_models/model_dim2.pth",
        image_height=0.35,
    ):
        super().__init__()
        self.model_path = model_path
        self.image_height = image_height
        # Load the precomputed disentanglement images; the pickle is expected to
        # hold a dict with "images" and "bin_indices" entries, used below
        with open(
            ROOT_DIR
            / "examples/variational_autoencoder/autoencoder_models/disentanglement.pkl",
            "rb",
        ) as f:
            self.image_handler = pickle.load(f)
    def make_disentanglement_generation_animation(self):
        animation_list = []
        for image_index, image in enumerate(self.image_handler["images"]):
            image_mobject = construct_image_mobject(image, height=self.image_height)
            r, c = self.image_handler["bin_indices"][image_index]
            # Move the image to its location in the disentanglement grid
            r_offset = -1.2
            c_offset = 0.25
            image_location = [
                c_offset + c * self.image_height,
                r_offset + r * self.image_height,
                0,
            ]
            image_mobject.move_to(image_location)
            animation_list.append(FadeIn(image_mobject))
        generation_animation = AnimationGroup(*animation_list[::-1], lag_ratio=1.0)
        return generation_animation

config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0

class DisentanglementScene(Scene):
    """Disentanglement Scene Object"""

    def _construct_embedding(self, point_color=BLUE, dot_radius=0.05):
        """Makes a Gaussian-like embedding"""
        embedding = VGroup()
        # Sample points from a Gaussian
        num_points = 200
        standard_deviation = [0.6, 0.8]
        mean = [0, 0]
        points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
        # Make an axes
        embedding.axes = Axes(
            x_range=[-3, 3],
            y_range=[-3, 3],
            x_length=2.2,
            y_length=2.2,
            tips=False,
        )
        # Add each point to the axes
        self.point_dots = VGroup()
        for point in points:
            point_location = embedding.axes.coords_to_point(*point)
            dot = Dot(point_location, color=point_color, radius=dot_radius / 2)
            self.point_dots.add(dot)
        embedding.add(self.point_dots)
        return embedding

    def construct(self):
        # Make the VAE decoder
        vae_decoder = NeuralNetwork(
            [
                FeedForwardLayer(3),
                FeedForwardLayer(5),
            ],
            layer_spacing=0.55,
        )
        vae_decoder.shift([-0.55, 0, 0])
        self.play(Create(vae_decoder), run_time=1)
        # Make the embedding
        embedding = self._construct_embedding()
        embedding.scale(0.9)
        embedding.move_to(vae_decoder.get_left())
        embedding.shift([-0.85, 0, 0])
        self.play(Create(embedding))
        # Make the disentanglement visualization
        disentanglement = DisentanglementVisualization()
        disentanglement_animation = (
            disentanglement.make_disentanglement_generation_animation()
        )
        self.play(disentanglement_animation, run_time=3)
        self.play(Wait(2))
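
# To render this scene with the Manim Community CLI (assuming manim and
# manim_ml are installed), an invocation like the following should work:
#   manim -pql disentanglement.py DisentanglementScene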