"""This module is dedicated to visualizing VAE disentanglement"""
import pickle
from pathlib import Path

from manim import *

from manim_ml.neural_network.layers import FeedForwardLayer
from manim_ml.neural_network.neural_network import NeuralNetwork

# Repository root (two levels above this file's directory)
ROOT_DIR = Path(__file__).parents[2]


def construct_image_mobject(input_image, height=2.3):
    """Constructs an ImageMobject from a numpy grayscale image"""
    # Convert a 2D grayscale image (H, W) to an RGB array of shape (H, W, 3)
    if len(input_image.shape) == 2:
        # Add a channel axis, repeat it to 3 channels, then move channels last
        input_image = np.repeat(input_image[None, :, :], 3, axis=0)
        input_image = np.rollaxis(input_image, 0, start=3)
    # Make the ImageMobject, using nearest-neighbor resampling so pixels stay crisp
    image_mobject = ImageMobject(input_image, image_mode="RGB")
    image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
    image_mobject.height = height

    return image_mobject
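# Example (hypothetical data): a 28x28 grayscale array becomes a small RGB
# image mobject, e.g. for the grid of decoded images below.
#
#   digit = np.random.rand(28, 28)
#   digit_mobject = construct_image_mobject(digit, height=0.35)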


class DisentanglementVisualization(VGroup):
    def __init__(
        self,
        model_path=ROOT_DIR
        / "examples/variational_autoencoder/autoencoder_models/saved_models/model_dim2.pth",
        image_height=0.35,
    ):
        super().__init__()
        self.model_path = model_path
        self.image_height = image_height
        # Load disentanglement image objects
        with open(
            ROOT_DIR
            / "examples/variational_autoencoder/autoencoder_models/disentanglement.pkl",
            "rb",
        ) as f:
            self.image_handler = pickle.load(f)

    def make_disentanglement_generation_animation(self):
        animation_list = []
        for image_index, image in enumerate(self.image_handler["images"]):
            image_mobject = construct_image_mobject(image, height=self.image_height)
            r, c = self.image_handler["bin_indices"][image_index]
            # Move the image to the correct location in the grid
            r_offset = -1.2
            c_offset = 0.25
            image_location = [
                c_offset + c * self.image_height,
                r_offset + r * self.image_height,
                0,
            ]
            image_mobject.move_to(image_location)
            animation_list.append(FadeIn(image_mobject))

        # Play the fade-ins one after another, in reverse order
        generation_animation = AnimationGroup(*animation_list[::-1], lag_ratio=1.0)
        return generation_animation
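# The disentanglement.pkl file loaded above is assumed to be a dict with two
# keys: "images" (a list of grayscale arrays, one per decoded latent point)
# and "bin_indices" (a matching list of (row, column) grid positions). A
# minimal stand-in with the same structure might look like:
#
#   image_handler = {
#       "images": [np.random.rand(28, 28) for _ in range(9)],
#       "bin_indices": [(r, c) for r in range(3) for c in range(3)],
#   }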


# Render a 1280x720 video with a 5x5-unit Manim coordinate frame
config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0


class DisentanglementScene(Scene):
    """Disentanglement Scene Object"""

    def _construct_embedding(self, point_color=BLUE, dot_radius=0.05):
        """Makes a Gaussian-like embedding"""
        embedding = VGroup()
        # Sample points from a 2D Gaussian (per-axis mean and standard deviation)
        num_points = 200
        standard_deviation = [0.6, 0.8]
        mean = [0, 0]
        points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
        # Make an axes
        embedding.axes = Axes(
            x_range=[-3, 3],
            y_range=[-3, 3],
            x_length=2.2,
            y_length=2.2,
            tips=False,
        )
        # Add each point to the axes
        self.point_dots = VGroup()
        for point in points:
            point_location = embedding.axes.coords_to_point(*point)
            dot = Dot(point_location, color=point_color, radius=dot_radius / 2)
            self.point_dots.add(dot)

        embedding.add(self.point_dots)
        return embedding
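
    # Note: the Gaussian point cloud built above stands in for the VAE's 2D
    # latent space; construct() below places it at the decoder's input (left) side.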

    def construct(self):
        # Make the VAE decoder
        vae_decoder = NeuralNetwork(
            [
                FeedForwardLayer(3),
                FeedForwardLayer(5),
            ],
            layer_spacing=0.55,
        )

        vae_decoder.shift([-0.55, 0, 0])
        self.play(Create(vae_decoder), run_time=1)
        # Make the embedding to the left of the decoder
        embedding = self._construct_embedding()
        embedding.scale(0.9)
        embedding.move_to(vae_decoder.get_left())
        embedding.shift([-0.85, 0, 0])
        self.play(Create(embedding))
        # Make the disentanglement visualization
        disentanglement = DisentanglementVisualization()
        disentanglement_animation = (
            disentanglement.make_disentanglement_generation_animation()
        )
        self.play(disentanglement_animation, run_time=3)
        self.wait(2)