mirror of https://github.com/helblazer811/ManimML.git

Commit 92d8a3d59a: Changed examples
Author: Alec Helbling
Parent commit: fe7089abbf
Makefile (2 changes)

--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 video:
 	manim -pqh src/vae.py VAEScene --media_dir media
-	cp media/videos/vae/1080p60/VAEScene.mp4 final_videos
+	cp media/videos/vae/720p60/VAEScene.mp4 examples
 train:
 	cd src/autoencoder_models
 	python vanilla_autoencoder.py
examples/TestNeuralNetworkScene.mp4 (new binary file)
Binary file not shown.

examples/VAEScene.mp4 (new binary file)
Binary file not shown.
(interpolation generation script; file name not captured by the mirror)

@@ -37,15 +37,14 @@ save_object["interpolation_path"] = interpolation_path
 for i in range(num_images):
     # Generate
     z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
-    gen_image = vae.decode(z)
+    gen_image = vae.decode(z).detach().numpy()
+    gen_image = np.reshape(gen_image, (28, 28)) * 255
     save_object["interpolation_images"].append(gen_image)

 fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
 image_pairs = []
 for i in range(num_images):
-    im = save_object["interpolation_images"][i]
-    im = im.detach().numpy()
-    recon_image = np.reshape(im, (28, 28)) * 255
+    recon_image = save_object["interpolation_images"][i]
     # Add to plot
     axs[i].imshow(recon_image)

(additional changed binary file; name not captured by the mirror)
Binary file not shown.
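For context on the decode change: a torch tensor that is still attached to the autograd graph cannot be converted to numpy (torch raises a RuntimeError), so the conversion now happens once at generation time instead of later at plot time. A minimal runnable sketch of the same pipeline, using a stand-in decoder since the real vae is defined elsewhere:

    import numpy as np
    import torch

    # Stand-in for vae.decode; it returns a (1, 784) tensor that requires grad,
    # just like the output of a real decoder would during training.
    def decode(z):
        return torch.sigmoid(torch.randn(1, 28 * 28, requires_grad=True) + z.sum())

    z = torch.zeros(1, 32)                             # one latent vector, shape (1, latent_dim)
    gen_image = decode(z).detach().numpy()             # detach from autograd, then convert
    gen_image = np.reshape(gen_image, (28, 28)) * 255  # 28x28 grayscale scaled to [0, 255]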
src/vae.py (132 changes)

--- a/src/vae.py
+++ b/src/vae.py
@@ -5,11 +5,13 @@ and Traditional Autoencoders.

 """
 from configparser import Interpolation
+from random import sample
 from typing_extensions import runtime
 from manim import *
 import pickle
 import numpy as np
 import neural_network
+from scipy.interpolate import make_interp_spline

 class VariationalAutoencoder(Group):
     """Variational Autoencoder Manim Visualization"""
@@ -76,24 +78,27 @@ class VariationalAutoencoder(Group):

         embedding.add(self.point_dots)
         return embedding

+    def _construct_image_mobject(self, input_image, height=2):
+        """Constructs an ImageMobject from a numpy grayscale image"""
+        # Convert image to rgb
+        input_image = np.repeat(input_image, 3, axis=0)
+        input_image = np.rollaxis(input_image, 0, start=3)
+        # Make the ImageMobject
+        image_mobject = ImageMobject(input_image, image_mode="RGB")
+        image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+        image_mobject.height = height
+
+        return image_mobject
+
     def _construct_input_output_images(self, image_pair):
         """Places the input and output images for the AE"""
         # Takes the image pair
         # image_pair is assumed to be [2, x, y]
         input_image = image_pair[0][None, :, :]
         recon_image = image_pair[1][None, :, :]
-        # Convert images to rgb
-        input_image = np.repeat(input_image, 3, axis=0)
-        input_image = np.rollaxis(input_image, 0, start=3)
-        recon_image = np.repeat(recon_image, 3, axis=0)
-        recon_image = np.rollaxis(recon_image, 0, start=3)
-        # Make the image objects
-        input_image_object = ImageMobject(input_image, image_mode="RGB")
-        input_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        input_image_object.height = 2
-        recon_image_object = ImageMobject(recon_image, image_mode="RGB")
-        recon_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        recon_image_object.height = 2
-
+        # Make the image mobjects
+        input_image_object = self._construct_image_mobject(input_image)
+        recon_image_object = self._construct_image_mobject(recon_image)

         return input_image_object, recon_image_object
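The new _construct_image_mobject helper centralizes the grayscale-to-RGB conversion that was previously duplicated for the input and reconstruction images. A standalone shape walk-through of the two numpy calls (illustration only, not part of the commit):

    import numpy as np

    gray = np.random.rand(1, 28, 28)     # (1, H, W): a single grayscale channel
    rgb = np.repeat(gray, 3, axis=0)     # (3, 28, 28): the channel copied three times
    rgb = np.rollaxis(rgb, 0, start=3)   # (28, 28, 3): channels-last, as ImageMobject expects
    assert rgb.shape == (28, 28, 3)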
@@ -129,6 +134,17 @@ class VariationalAutoencoder(Group):

         animation_group = AnimationGroup(*animations)
         return animation_group

+    def make_reset_vae_animation(self):
+        """Resets the VAE to just the neural network"""
+        animation_group = AnimationGroup(
+            FadeOut(self.input_image),
+            FadeOut(self.output_image),
+            FadeOut(self.distribution_objects),
+            run_time=1.0
+        )
+
+        return animation_group
+
     def make_forward_pass_animation(self, image_pair, run_time=1.5):
         """Overridden forward pass animation specific to a VAE"""
         per_unit_runtime = run_time
@@ -151,15 +167,22 @@ class VariationalAutoencoder(Group):

             dot_convergence_animation
         )
         # Make an ellipse centered at mean_point with std outline
-        center_dot = Dot(mean_point, radius=self.dot_radius, color=GREEN)
-        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5, stroke_width=self.ellipse_stroke_width)
+        center_dot = Dot(mean_point, radius=self.dot_radius, color=RED)
+        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width)
         ellipse.move_to(mean_point)
+        self.distribution_objects = VGroup(
+            center_dot,
+            ellipse
+        )
+        # Make ellipse animation
         ellipse_animation = AnimationGroup(
             GrowFromCenter(center_dot),
             GrowFromCenter(ellipse),
         )
         # Make the dot divergence animation
-        dot_divergence_animation = self.make_dot_divergence_animation(mean_point, run_time=per_unit_runtime)
+        sampled_point = [0.51, 1.0]
+        divergence_point = self.embedding.axes.coords_to_point(*sampled_point)
+        dot_divergence_animation = self.make_dot_divergence_animation(divergence_point, run_time=per_unit_runtime)
         # Make decoder forward pass
         decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
         # Add the animations to the group
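Collecting the dot and ellipse into self.distribution_objects is what lets make_reset_vae_animation (added above) fade the sampled distribution out as a single unit later. A minimal illustration of that pattern with hypothetical mobjects, not taken from the commit:

    from manim import RED, AnimationGroup, Dot, Ellipse, FadeOut, VGroup

    center_dot = Dot(color=RED)
    ellipse = Ellipse(width=1.0, height=0.5, color=RED, fill_opacity=0.3)
    distribution_objects = VGroup(center_dot, ellipse)
    # A single FadeOut on the VGroup animates every member at once.
    reset_animation = AnimationGroup(FadeOut(distribution_objects), run_time=1.0)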
@@ -175,34 +198,77 @@ class VariationalAutoencoder(Group):

         return animation_group

+    def make_interpolation_animation(self, interpolation_images, frame_rate=5):
+        """Makes an interpolation animation"""
+        num_images = len(interpolation_images)
+        # Make a made-up path
+        interpolation_latent_path = np.linspace([-0.5, -1], [1, 1.3], num=num_images)
+        # Make the path animation
+        first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
+        moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
+        animation_list = [GrowFromCenter(moving_dot)]
+        for image_index in range(num_images - 1):
+            next_index = image_index + 1
+            # Get path
+            next_point = interpolation_latent_path[next_index]
+            next_position = self.embedding.axes.coords_to_point(*next_point)
+            # Draw path from current point to next point
+            move_animation = moving_dot.animate.move_to(next_position)
+            animation_list.append(move_animation)
+
+        interpolation_animation = Succession(*animation_list, run_time=0.1 * num_images)
+        # Make the images animation
+        animation_list = []
+        for numpy_image in interpolation_images:
+            numpy_image = numpy_image[None, :, :]
+            manim_image = self._construct_image_mobject(numpy_image)
+            # Move the image to the correct location
+            manim_image.move_to(self.output_image)
+            # Add the image
+            animation_list.append(FadeIn(manim_image, run_time=0.1))
+            # Wait
+            animation_list.append(Wait(1 / frame_rate))
+            # Remove the image
+            animation_list.append(FadeOut(manim_image, run_time=0.1))
+        images_animation = Succession(*animation_list)
+        # Combine the two into an AnimationGroup
+        animation_group = AnimationGroup(
+            interpolation_animation,
+            images_animation
+        )
+
+        return animation_group
+
 class MNISTImageHandler():
     """Deals with loading serialized VAE mnist images from "autoencoder_models" """

     def __init__(
         self,
-        image_pairs_path="src/autoencoder_models/image_pairs.pkl",
-        interpolations_path="src/autoencoder_models/interpolations.pkl"
+        image_pairs_file_path="src/autoencoder_models/image_pairs.pkl",
+        interpolations_file_path="src/autoencoder_models/interpolations.pkl"
     ):
-        self.image_pairs_path = image_pairs_path
-        self.interpolations_path = interpolations_path
+        self.image_pairs_file_path = image_pairs_file_path
+        self.interpolations_file_path = interpolations_file_path

         self.image_pairs = []
-        self.interpolations = []
+        self.interpolation_images = []
+        self.interpolation_latent_path = []

         self.load_serialized_data()

     def load_serialized_data(self):
-        with open(self.image_pairs_path, "rb") as f:
+        with open(self.image_pairs_file_path, "rb") as f:
             self.image_pairs = pickle.load(f)

-        with open(self.interpolations_path, "rb") as f:
-            self.interpolations_path = pickle.load(f)
+        with open(self.interpolations_file_path, "rb") as f:
+            self.interpolation_dict = pickle.load(f)
+        self.interpolation_images = self.interpolation_dict["interpolation_images"]
+        self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]

 """
 The VAE Scene for the twitter video.
 """

 config.pixel_height = 720
 config.pixel_width = 1280
 config.frame_height = 10.0
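The image half of make_interpolation_animation is a flip book: each decoded frame gets a FadeIn, a Wait to hold it on screen, and a FadeOut, and Succession plays those triples back to back (AnimationGroup, used at the end, instead runs its children simultaneously, which is how the moving dot and the flip book play together). A schematic sketch of the pattern, assuming images is a list of already-constructed ImageMobjects:

    from manim import FadeIn, FadeOut, Succession, Wait

    def flip_book(images, frame_rate=5):
        """Fade each mobject in, hold it briefly, fade it out, in sequence."""
        animation_list = []
        for image in images:
            animation_list.append(FadeIn(image, run_time=0.1))
            animation_list.append(Wait(1 / frame_rate))  # hold the frame
            animation_list.append(FadeOut(image, run_time=0.1))
        return Succession(*animation_list)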
@@ -217,17 +283,17 @@ class VAEScene(Scene):

         # Set Scene config
         vae = VariationalAutoencoder()
         mnist_image_handler = MNISTImageHandler()
-        image_pair = mnist_image_handler.image_pairs[2]
+        image_pair = mnist_image_handler.image_pairs[3]
         vae.move_to(ORIGIN)
         vae.scale(1.2)
         self.add(vae)
+        # Make a forward pass animation
         forward_pass_animation = vae.make_forward_pass_animation(image_pair)
         self.play(forward_pass_animation)
-        """
-        autoencoder = Autoencoder()
-        autoencoder.move_to(ORIGIN)
-        # Make a forward pass animation
-        self.add(autoencoder)
-        forward_pass_animation = autoencoder.make_forward_pass_animation(run_time=1.5)
-        self.play(forward_pass_animation)
-        """
+        # Remove the input and output images
+        reset_animation = vae.make_reset_vae_animation()
+        self.play(reset_animation)
+        # Interpolation animation
+        interpolation_images = mnist_image_handler.interpolation_images
+        interpolation_animation = vae.make_interpolation_animation(interpolation_images)
+        self.play(interpolation_animation)
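Taken together, the scene's new flow is forward pass, reset, interpolation, with MNISTImageHandler supplying the data for the first and last steps. A compact sketch of that data flow, assuming the default pickle files exist (attribute names as defined in the diff above):

    # Illustrative driver, not part of the commit.
    handler = MNISTImageHandler()                    # loads image_pairs.pkl and interpolations.pkl
    image_pair = handler.image_pairs[3]              # [2, x, y]: input image and its reconstruction
    frames = handler.interpolation_images            # decoded frames along the latent path
    latent_path = handler.interpolation_latent_path  # matching latent-space coordinates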