mirror of https://github.com/helblazer811/ManimML.git
synced 2025-06-28 09:47:34 +08:00

Changed examples

Committed by Alec Helbling
parent fe7089abbf
commit 92d8a3d59a
2 Makefile
@@ -1,6 +1,6 @@
 video:
 	manim -pqh src/vae.py VAEScene --media_dir media
-	cp media/videos/vae/1080p60/VAEScene.mp4 final_videos
+	cp media/videos/vae/720p60/VAEScene.mp4 examples
 train:
 	cd src/autoencoder_models
 	python vanilla_autoencoder.py
BIN examples/TestNeuralNetworkScene.mp4 Normal file
Binary file not shown.
BIN examples/VAEScene.mp4 Normal file
Binary file not shown.
@@ -37,15 +37,14 @@ save_object["interpolation_path"] = interpolation_path
 for i in range(num_images):
     # Generate
     z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
-    gen_image = vae.decode(z)
+    gen_image = vae.decode(z).detach().numpy()
+    gen_image = np.reshape(gen_image, (28, 28)) * 255
     save_object["interpolation_images"].append(gen_image)

 fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
 image_pairs = []
 for i in range(num_images):
-    im = save_object["interpolation_images"][i]
-    im = im.detach().numpy()
-    recon_image = np.reshape(im, (28, 28)) * 255
+    recon_image = save_object["interpolation_images"][i]
     # Add to plot
     axs[i].imshow(recon_image)
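This hunk moves the tensor-to-numpy conversion to generation time, so the pickled frames are already plain arrays when the plotting loop reads them back. A minimal sketch of the surrounding generation script, assuming `vae` is a trained model exposing decode(z) (e.g. the one trained by vanilla_autoencoder.py); the endpoints, num_images, and output path are illustrative:

    import pickle
    import numpy as np
    import torch

    # `vae` is assumed: a trained autoencoder with a decode(z) method.
    num_images = 10
    # Hypothetical straight-line path between two 2D latent points.
    interpolation_path = np.linspace([-0.5, -1.0], [1.0, 1.3], num=num_images)

    save_object = {"interpolation_path": interpolation_path, "interpolation_images": []}
    for i in range(num_images):
        # Decode each latent point into a 28x28 grayscale frame scaled to [0, 255].
        z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
        gen_image = vae.decode(z).detach().numpy()
        gen_image = np.reshape(gen_image, (28, 28)) * 255
        save_object["interpolation_images"].append(gen_image)

    # Serialize under the keys MNISTImageHandler reads back in src/vae.py.
    with open("src/autoencoder_models/interpolations.pkl", "wb") as f:
        pickle.dump(save_object, f)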
Binary file not shown.
132 src/vae.py
@@ -5,11 +5,13 @@ and Traditional Autoencoders.

 """
 from configparser import Interpolation
 from random import sample
 from typing_extensions import runtime
 from manim import *
 import pickle
 import numpy as np
 import neural_network
+from scipy.interpolate import make_interp_spline

 class VariationalAutoencoder(Group):
     """Variational Autoencoder Manim Visualization"""
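Of the new imports, make_interp_spline is not used by any line shown in this diff; presumably it is headed for smoothing the latent interpolation path. A hedged sketch of that use — every name and value here is illustrative, not from the repo:

    import numpy as np
    from scipy.interpolate import make_interp_spline

    # Smooth a coarse 2D latent path into a denser one for animation.
    coarse_path = np.array([[-0.5, -1.0], [0.2, 0.4], [1.0, 1.3]])
    t = np.arange(len(coarse_path))
    spline = make_interp_spline(t, coarse_path, k=2)  # quadratic B-spline
    dense_path = spline(np.linspace(0, len(coarse_path) - 1, num=50))  # (50, 2)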
@@ -76,24 +78,27 @@
         embedding.add(self.point_dots)
         return embedding

+    def _construct_image_mobject(self, input_image, height=2):
+        """Constructs an ImageMobject from a numpy grayscale image"""
+        # Convert image to rgb
+        input_image = np.repeat(input_image, 3, axis=0)
+        input_image = np.rollaxis(input_image, 0, start=3)
+        # Make the ImageMobject
+        image_mobject = ImageMobject(input_image, image_mode="RGB")
+        image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+        image_mobject.height = height
+
+        return image_mobject
+
     def _construct_input_output_images(self, image_pair):
         """Places the input and output images for the AE"""
         # Takes the image pair
         # image_pair is assumed to be [2, x, y]
         input_image = image_pair[0][None, :, :]
         recon_image = image_pair[1][None, :, :]
-        # Convert images to rgb
-        input_image = np.repeat(input_image, 3, axis=0)
-        input_image = np.rollaxis(input_image, 0, start=3)
-        recon_image = np.repeat(recon_image, 3, axis=0)
-        recon_image = np.rollaxis(recon_image, 0, start=3)
-        # Make an image objects
-        input_image_object = ImageMobject(input_image, image_mode="RGB")
-        input_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        input_image_object.height = 2
-        recon_image_object = ImageMobject(recon_image, image_mode="RGB")
-        recon_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        recon_image_object.height = 2
+        # Make the image mobjects
+        input_image_object = self._construct_image_mobject(input_image)
+        recon_image_object = self._construct_image_mobject(recon_image)

         return input_image_object, recon_image_object
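The refactor extracts the grayscale-to-RGB channel trick into _construct_image_mobject. The array manipulation in isolation, with random data standing in for a decoded frame (shapes taken from the diff):

    import numpy as np

    # A (1, 28, 28) grayscale image, shaped like the decoder output above.
    gray = np.random.rand(1, 28, 28) * 255

    # Repeat the single channel three times -> (3, 28, 28), then move the
    # channel axis to the end -> (28, 28, 3), the layout ImageMobject expects.
    rgb = np.repeat(gray, 3, axis=0)
    rgb = np.rollaxis(rgb, 0, start=3)
    assert rgb.shape == (28, 28, 3)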
@@ -129,6 +134,17 @@
         animation_group = AnimationGroup(*animations)
         return animation_group

+    def make_reset_vae_animation(self):
+        """Resets the VAE to just the neural network"""
+        animation_group = AnimationGroup(
+            FadeOut(self.input_image),
+            FadeOut(self.output_image),
+            FadeOut(self.distribution_objects),
+            run_time=1.0
+        )
+
+        return animation_group
+
     def make_forward_pass_animation(self, image_pair, run_time=1.5):
         """Overridden forward pass animation specific to a VAE"""
         per_unit_runtime = run_time
@@ -151,15 +167,22 @@
             dot_convergence_animation
         )
         # Make an ellipse centered at mean_point with std outline
-        center_dot = Dot(mean_point, radius=self.dot_radius, color=GREEN)
-        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5, stroke_width=self.ellipse_stroke_width)
+        center_dot = Dot(mean_point, radius=self.dot_radius, color=RED)
+        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width)
         ellipse.move_to(mean_point)
+        self.distribution_objects = VGroup(
+            center_dot,
+            ellipse
+        )
         # Make ellipse animation
         ellipse_animation = AnimationGroup(
             GrowFromCenter(center_dot),
             GrowFromCenter(ellipse),
         )
         # Make the dot divergence animation
-        dot_divergence_animation = self.make_dot_divergence_animation(mean_point, run_time=per_unit_runtime)
+        sampled_point = [0.51, 1.0]
+        divergence_point = self.embedding.axes.coords_to_point(*sampled_point)
+        dot_divergence_animation = self.make_dot_divergence_animation(divergence_point, run_time=per_unit_runtime)
         # Make decoder forward pass
         decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
         # Add the animations to the group
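The divergence animation now starts from an explicit sampled latent point mapped through the embedding's axes. A minimal standalone scene showing just that coords_to_point mapping (the axes ranges and sizes are assumptions, not values from the repo):

    from manim import *

    class SampledPointScene(Scene):
        def construct(self):
            # Axes standing in for the VAE's 2D latent embedding.
            axes = Axes(x_range=[-2, 2], y_range=[-2, 2], x_length=4, y_length=4)
            self.add(axes)
            # Map latent coordinates to scene coordinates, as the diff does with
            # self.embedding.axes.coords_to_point(*sampled_point).
            sampled_point = [0.51, 1.0]
            dot = Dot(axes.coords_to_point(*sampled_point), color=RED)
            self.play(GrowFromCenter(dot))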
@@ -175,34 +198,77 @@

         return animation_group

+    def make_interpolation_animation(self, interpolation_images, frame_rate=5):
+        """Makes an animation interpolation"""
+        num_images = len(interpolation_images)
+        # Make a made-up path
+        interpolation_latent_path = np.linspace([-0.5, -1], [1, 1.3], num=num_images)
+        # Make the path animation
+        first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
+        moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
+        animation_list = [GrowFromCenter(moving_dot)]
+        for image_index in range(num_images - 1):
+            next_index = image_index + 1
+            # Get path
+            next_point = interpolation_latent_path[next_index]
+            next_position = self.embedding.axes.coords_to_point(*next_point)
+            # Draw path from current point to next point
+            move_animation = moving_dot.animate.move_to(next_position)
+            animation_list.append(move_animation)
+
+        interpolation_animation = Succession(*animation_list, run_time=0.1*num_images)
+        # Make the images animation
+        animation_list = []
+        for numpy_image in interpolation_images:
+            numpy_image = numpy_image[None, :, :]
+            manim_image = self._construct_image_mobject(numpy_image)
+            # Move the image to the correct location
+            manim_image.move_to(self.output_image)
+            # Add the image
+            animation_list.append(FadeIn(manim_image, run_time=0.1))
+            # Wait
+            animation_list.append(Wait(1 / frame_rate))
+            # Remove the image
+            animation_list.append(FadeOut(manim_image, run_time=0.1))
+        images_animation = Succession(*animation_list)
+        # Combine the two into an AnimationGroup
+        animation_group = AnimationGroup(
+            interpolation_animation,
+            images_animation
+        )
+
+        return animation_group
+
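make_interpolation_animation builds two Succession timelines (the moving dot and the image flip-book) and plays them in parallel through one AnimationGroup. A minimal scene isolating that pattern; the mobjects and timings are illustrative only:

    from manim import *

    class SuccessionVsGroupScene(Scene):
        def construct(self):
            dot = Dot(LEFT * 2, color=RED)
            image_stand_in = Square().shift(RIGHT * 2)
            # Succession plays its members one after another...
            dot_timeline = Succession(
                GrowFromCenter(dot),
                dot.animate.move_to(ORIGIN),
            )
            # ...while AnimationGroup starts its members together, which is how
            # the moving dot stays in sync with the flip-book of decoded images.
            self.play(AnimationGroup(dot_timeline, FadeIn(image_stand_in)))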
 class MNISTImageHandler():
     """Deals with loading serialized VAE mnist images from "autoencoder_models" """

     def __init__(
         self,
-        image_pairs_path="src/autoencoder_models/image_pairs.pkl",
-        interpolations_path="src/autoencoder_models/interpolations.pkl"
+        image_pairs_file_path="src/autoencoder_models/image_pairs.pkl",
+        interpolations_file_path="src/autoencoder_models/interpolations.pkl"
     ):
-        self.image_pairs_path = image_pairs_path
-        self.interpolations_path = interpolations_path
+        self.image_pairs_file_path = image_pairs_file_path
+        self.interpolations_file_path = interpolations_file_path

         self.image_pairs = []
-        self.interpolations = []
+        self.interpolation_images = []
+        self.interpolation_latent_path = []

         self.load_serialized_data()

     def load_serialized_data(self):
-        with open(self.image_pairs_path, "rb") as f:
+        with open(self.image_pairs_file_path, "rb") as f:
             self.image_pairs = pickle.load(f)

-        with open(self.interpolations_path, "rb") as f:
-            self.interpolations_path = pickle.load(f)
+        with open(self.interpolations_file_path, "rb") as f:
+            self.interpolation_dict = pickle.load(f)
+            self.interpolation_images = self.interpolation_dict["interpolation_images"]
+            self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]

 """
 The VAE Scene for the twitter video.
 """

 config.pixel_height = 720
 config.pixel_width = 1280
 config.frame_height = 10.0
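With the renamed attributes, downstream code reads the handler as below; a hypothetical smoke test, assuming both pickle files exist at their default paths:

    handler = MNISTImageHandler()
    input_image, recon_image = handler.image_pairs[3]  # a [2, x, y] pair
    print(len(handler.interpolation_images))           # decoded frames along the path
    print(handler.interpolation_latent_path[0])        # first 2D latent point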
@@ -217,17 +283,17 @@ class VAEScene(Scene):
         # Set Scene config
         vae = VariationalAutoencoder()
         mnist_image_handler = MNISTImageHandler()
-        image_pair = mnist_image_handler.image_pairs[2]
+        image_pair = mnist_image_handler.image_pairs[3]
         vae.move_to(ORIGIN)
         vae.scale(1.2)
         self.add(vae)
         # Make a forward pass animation
         forward_pass_animation = vae.make_forward_pass_animation(image_pair)
         self.play(forward_pass_animation)
         """
         autoencoder = Autoencoder()
         autoencoder.move_to(ORIGIN)
         # Make a forward pass animation
         self.add(autoencoder)
         forward_pass_animation = autoencoder.make_forward_pass_animation(run_time=1.5)
         self.play(forward_pass_animation)
         """
+        # Remove the input and output images
+        reset_animation = vae.make_reset_vae_animation()
+        self.play(reset_animation)
+        # Interpolation animation
+        interpolation_images = mnist_image_handler.interpolation_images
+        interpolation_animation = vae.make_interpolation_animation(interpolation_images)
+        self.play(interpolation_animation)