Mirror of https://github.com/helblazer811/ManimML.git (synced 2025-07-15 07:57:41 +08:00)
[BUG] update most examples.
.idea/.gitignore (new file, generated, vendored)
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.idea/ManimML.iml (new file, generated)
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.10 (Viz)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
.idea/inspectionProfiles/Project_Default.xml (new file, generated)
@@ -0,0 +1,46 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="false" level="WARNING" enabled_by_default="false">
+      <option name="ignoredPackages">
+        <value>
+          <list size="5">
+            <item index="0" class="java.lang.String" itemvalue="jupyter-core" />
+            <item index="1" class="java.lang.String" itemvalue="wandb" />
+            <item index="2" class="java.lang.String" itemvalue="pytorch-lightning" />
+            <item index="3" class="java.lang.String" itemvalue="torch" />
+            <item index="4" class="java.lang.String" itemvalue="torchvision" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoredErrors">
+        <list>
+          <option value="E722" />
+          <option value="E266" />
+          <option value="W605" />
+        </list>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoredErrors">
+        <list>
+          <option value="N812" />
+          <option value="N803" />
+          <option value="N806" />
+          <option value="N802" />
+        </list>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredIdentifiers">
+        <list>
+          <option value="omegaconf.base.Container.*" />
+          <option value="pytorch_lightning.core.datamodule.LightningDataModule.train_dataset" />
+          <option value="torch.optim.lr_scheduler._LRScheduler" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml (new file, generated)
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml (new file, generated)
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Viz)" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml (new file, generated)
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/ManimML.iml" filepath="$PROJECT_DIR$/.idea/ManimML.iml" />
+    </modules>
+  </component>
+</project>
.idea/sonarlint/issuestore/index.pb (new file, generated, empty)
.idea/vcs.xml (new file, generated)
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
@@ -1,19 +1,23 @@
+from pathlib import Path
+
 from manim import *
 from PIL import Image
 
-from manim_ml.neural_network.layers.convolutional import ConvolutionalLayer
+from manim_ml.neural_network.layers import Convolutional3DLayer
 from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
 from manim_ml.neural_network.layers.image import ImageLayer
 from manim_ml.neural_network.neural_network import NeuralNetwork
 
+ROOT_DIR = Path(__file__).parents[2]
+
 def make_code_snippet():
     code_str = """
         # Make nn
         nn = NeuralNetwork([
             ImageLayer(numpy_image),
-            ConvolutionalLayer(3, 3, 3),
-            ConvolutionalLayer(5, 2, 2),
-            ConvolutionalLayer(10, 2, 1),
+            Convolutional3DLayer(3, 3, 3),
+            Convolutional3DLayer(5, 2, 2),
+            Convolutional3DLayer(10, 2, 1),
             FeedForwardLayer(3),
             FeedForwardLayer(1)
         ], layer_spacing=0.2)
@@ -46,14 +50,14 @@ config.frame_width = 12.0
 
 class CombinedScene(ThreeDScene, Scene):
     def construct(self):
-        image = Image.open('../../assets/mnist/digit.jpeg')
+        image = Image.open(ROOT_DIR / 'assets/mnist/digit.jpeg')
         numpy_image = np.asarray(image)
         # Make nn
         nn = NeuralNetwork([
             ImageLayer(numpy_image, height=3.5),
-            ConvolutionalLayer(3, 3, 3, filter_spacing=0.2),
-            ConvolutionalLayer(5, 2, 2, filter_spacing=0.2),
-            ConvolutionalLayer(10, 2, 1, filter_spacing=0.2),
+            Convolutional3DLayer(3, 3, 3, filter_spacing=0.2),
+            Convolutional3DLayer(5, 2, 2, filter_spacing=0.2),
+            Convolutional3DLayer(10, 2, 1, filter_spacing=0.2),
             FeedForwardLayer(3, rectangle_stroke_width=4, node_stroke_width=4).scale(2),
             FeedForwardLayer(1, rectangle_stroke_width=4, node_stroke_width=4).scale(2)
         ], layer_spacing=0.2)
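
For reference, the two hunks above consolidate to roughly the following updated example. This is a sketch reconstructed from the + lines, assuming the manim_ml API at this commit; the indentation and the final self.add call are inferred, since the diff does not show the rest of the scene.

from pathlib import Path

from manim import *
from PIL import Image

from manim_ml.neural_network.layers import Convolutional3DLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.neural_network import NeuralNetwork

# Anchor asset paths at the repository root rather than the working directory
ROOT_DIR = Path(__file__).parents[2]

class CombinedScene(ThreeDScene, Scene):
    def construct(self):
        image = Image.open(ROOT_DIR / 'assets/mnist/digit.jpeg')
        numpy_image = np.asarray(image)
        # Make nn with the renamed Convolutional3DLayer
        nn = NeuralNetwork([
            ImageLayer(numpy_image, height=3.5),
            Convolutional3DLayer(3, 3, 3, filter_spacing=0.2),
            Convolutional3DLayer(5, 2, 2, filter_spacing=0.2),
            Convolutional3DLayer(10, 2, 1, filter_spacing=0.2),
            FeedForwardLayer(3, rectangle_stroke_width=4, node_stroke_width=4).scale(2),
            FeedForwardLayer(1, rectangle_stroke_width=4, node_stroke_width=4).scale(2)
        ], layer_spacing=0.2)
        self.add(nn)  # hypothetical stand-in for the rest of the scene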
@@ -1,38 +1,42 @@
 """This module is dedicated to visualizing VAE disentanglement"""
-import sys
-import os
-sys.path.append(os.environ["PROJECT_ROOT"])
+from pathlib import Path
+
 from manim import *
-from manim_ml.neural_network import NeuralNetwork
-import manim_ml.util as util
+from manim_ml.neural_network.layers import FeedForwardLayer
+from manim_ml.neural_network.neural_network import NeuralNetwork
 import pickle
 
-class VAEDecoder(VGroup):
-    """Just shows the VAE encoder"""
-
-    def __init__(self):
-        super(VGroup, self).__init__()
-        # Setup the Neural Network
-        node_counts = [3, 5]
-        self.neural_network = NeuralNetwork(node_counts, layer_spacing=0.55)
-        self.add(self.neural_network)
-
-    def make_encoding_animation(self):
-        pass
+ROOT_DIR = Path(__file__).parents[2]
+
+def construct_image_mobject(input_image, height=2.3):
+    """Constructs an ImageMobject from a numpy grayscale image"""
+    # Convert image to rgb
+    if len(input_image.shape) == 2:
+        input_image = np.repeat(input_image, 3, axis=0)
+        input_image = np.rollaxis(input_image, 0, start=3)
+    # Make the ImageMobject
+    image_mobject = ImageMobject(input_image, image_mode="RGB")
+    image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+    image_mobject.height = height
+
+    return image_mobject
 
 class DisentanglementVisualization(VGroup):
 
-    def __init__(self, model_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/saved_models/model_dim2.pth"), image_height=0.35):
+    def __init__(self, model_path=ROOT_DIR / "examples/variational_autoencoder/autoencoder_models/saved_models/model_dim2.pth", image_height=0.35):
         self.model_path = model_path
         self.image_height = image_height
         # Load disentanglement image objects
-        with open(os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/disentanglement.pkl"), "rb") as f:
+        with open(ROOT_DIR / "examples/variational_autoencoder/autoencoder_models/disentanglement.pkl", "rb") as f:
             self.image_handler = pickle.load(f)
 
     def make_disentanglement_generation_animation(self):
         animation_list = []
         for image_index, image in enumerate(self.image_handler["images"]):
-            image_mobject = util.construct_image_mobject(image, height=self.image_height)
+            image_mobject = construct_image_mobject(image, height=self.image_height)
             r, c = self.image_handler["bin_indices"][image_index]
             # Move the image to the correct location
             r_offset = -1.2
@@ -80,7 +84,11 @@ class DisentanglementScene(Scene):
 
     def construct(self):
        # Make the VAE decoder
-        vae_decoder = VAEDecoder()
+        vae_decoder = NeuralNetwork([
+            FeedForwardLayer(3),
+            FeedForwardLayer(5),
+        ], layer_spacing=0.55)
 
         vae_decoder.shift([-0.55, 0, 0])
         self.play(Create(vae_decoder), run_time=1)
         # Make the embedding
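
The recurring fix in these examples swaps os.environ["PROJECT_ROOT"] for a pathlib anchor. A minimal sketch of why this works, assuming the script sits two directory levels below the repository root (e.g. a hypothetical <repo>/examples/disentanglement/script.py):

from pathlib import Path

# parents[0] is the directory holding this file, parents[1] is examples/,
# and parents[2] is the repository root, independent of the current
# working directory and with no PROJECT_ROOT environment variable needed.
ROOT_DIR = Path(__file__).parents[2]
pkl_path = ROOT_DIR / "examples/variational_autoencoder/autoencoder_models/disentanglement.pkl"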
@@ -1,4 +1,6 @@
 import random
+from pathlib import Path
+
 from PIL import Image
 from manim import *
 from manim_ml.neural_network.layers.embedding import EmbeddingLayer
@@ -8,6 +10,8 @@ from manim_ml.neural_network.layers.vector import VectorLayer
 
 from manim_ml.neural_network.neural_network import NeuralNetwork
 
+ROOT_DIR = Path(__file__).parents[2]
+
 config.pixel_height = 1080
 config.pixel_width = 1080
 config.frame_height = 8.3
@@ -25,7 +29,7 @@ class GAN(Mobject):
     def make_entities(self, image_height=1.2):
         """Makes all of the network entities"""
         # Make the fake image layer
-        default_image = Image.open('../../assets/gan/fake_image.png')
+        default_image = Image.open(ROOT_DIR / 'assets/gan/fake_image.png')
         numpy_image = np.asarray(default_image)
         self.fake_image_layer = ImageLayer(numpy_image, height=image_height, show_image_on_create=False)
         # Make the Generator Network
@@ -45,7 +49,7 @@ class GAN(Mobject):
         ], layer_spacing=0.1)
         self.add(self.discriminator)
         # Make Ground Truth Dataset
-        default_image = Image.open('../../assets/gan/real_image.jpg')
+        default_image = Image.open(ROOT_DIR / 'assets/gan/real_image.jpg')
         numpy_image = np.asarray(default_image)
         self.ground_truth_layer = ImageLayer(numpy_image, height=image_height)
         self.add(self.ground_truth_layer)
@@ -1,13 +1,17 @@
 """Visualization of VAE Interpolation"""
-import sys
-import os
-sys.path.append(os.environ["PROJECT_ROOT"])
+from pathlib import Path
+
 from manim import *
-import pickle
 import numpy as np
-import manim_ml.neural_network as neural_network
-import examples.variational_autoencoder.variational_autoencoder as variational_autoencoder
+from PIL import Image
+from manim_ml.neural_network.layers import EmbeddingLayer
+from manim_ml.neural_network.layers import FeedForwardLayer
+from manim_ml.neural_network.layers import ImageLayer
+from manim_ml.neural_network.neural_network import NeuralNetwork
+
+ROOT_DIR = Path(__file__).parents[2]
 
 """
 The VAE Scene for the twitter video.
@@ -24,7 +28,17 @@ class InterpolationScene(MovingCameraScene):
 
     def construct(self):
         # Set Scene config
-        vae = variational_autoencoder.VariationalAutoencoder(dot_radius=0.035, layer_spacing=0.5)
+        numpy_image = np.asarray(Image.open(ROOT_DIR / 'assets/mnist/digit.jpeg'))
+        vae = NeuralNetwork([
+            ImageLayer(numpy_image, height=1.4),
+            FeedForwardLayer(5),
+            FeedForwardLayer(3),
+            EmbeddingLayer(dist_theme="ellipse").scale(2),
+            FeedForwardLayer(3),
+            FeedForwardLayer(5),
+            ImageLayer(numpy_image, height=1.4),
+        ])
+
         vae.move_to(ORIGIN)
         vae.encoder.shift(LEFT*0.5)
         vae.decoder.shift(RIGHT*0.5)
@@ -1,6 +1,8 @@
 """
 Here is a animated explanatory figure for the "Oracle Guided Image Synthesis with Relative Queries" paper.
 """
+from pathlib import Path
+
 from manim import *
 from manim_ml.neural_network.layers import triplet
 from manim_ml.neural_network.layers.image import ImageLayer
@@ -19,6 +21,8 @@ config.pixel_width = 1900
 config.frame_height = 6.0
 config.frame_width = 6.0
 
+ROOT_DIR = Path(__file__).parents[3]
+
 class Localizer():
     """
     Holds the localizer object, which contains the queries, images, etc.
@@ -30,8 +34,8 @@ class Localizer():
         self.index = -1
         self.axes = axes
         self.num_queries = 3
-        self.assets_path = "../../../assets/oracle_guidance"
-        self.ground_truth_image_path = os.path.join(self.assets_path, "ground_truth.jpg")
+        self.assets_path = ROOT_DIR / "assets/oracle_guidance"
+        self.ground_truth_image_path = self.assets_path / "ground_truth.jpg"
         self.ground_truth_location = np.array([2, 3])
         # Prior distribution
         print("initial gaussian")
@@ -119,7 +123,7 @@ class OracleGuidanceVisualization(Scene):
         self.title = None
         # Set image paths
         # VAE embedding animation image paths
-        self.assets_path = "../../../assets/oracle_guidance"
+        self.assets_path = ROOT_DIR / "assets/oracle_guidance"
         self.input_embed_image_path = os.path.join(self.assets_path, "input_image.jpg")
         self.output_embed_image_path = os.path.join(self.assets_path, "output_image.jpg")
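
In the Localizer hunk the assets path becomes a pathlib.Path, yet the OracleGuidanceVisualization hunk still builds its image paths with os.path.join. Both compose the same location, because os.path.join accepts path-like objects and returns a str; a small sketch (POSIX-style paths and a hypothetical /repo root assumed):

import os
from pathlib import Path

assets_path = Path("/repo") / "assets/oracle_guidance"  # hypothetical root
joined = os.path.join(assets_path, "input_image.jpg")   # str
slashed = assets_path / "input_image.jpg"               # Path
assert joined == str(slashed)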
@@ -4,323 +4,36 @@ In this module I define Manim visualizations for Variational Autoencoders
 and Traditional Autoencoders.
 
 """
+from pathlib import Path
+
 from manim import *
-import pickle
 import numpy as np
-import os
 from PIL import Image
-import manim_ml.neural_network as neural_network
-from manim_ml.neural_network.embedding import EmbeddingLayer
-from manim_ml.neural_network.feed_forward import FeedForwardLayer
-from manim_ml.neural_network.image import ImageLayer
+from manim_ml.neural_network.layers import EmbeddingLayer
+from manim_ml.neural_network.layers import FeedForwardLayer
+from manim_ml.neural_network.layers import ImageLayer
 from manim_ml.neural_network.neural_network import NeuralNetwork
 
-class VariationalAutoencoder(VGroup):
-    """Variational Autoencoder Manim Visualization"""
-
-    def __init__(
-        self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE,
-        dot_radius=0.05, ellipse_stroke_width=2.0, layer_spacing=0.5
-    ):
-        super(VGroup, self).__init__()
-        self.encoder_nodes_per_layer = encoder_nodes_per_layer
-        self.decoder_nodes_per_layer = decoder_nodes_per_layer
-        self.point_color = point_color
-        self.dot_radius = dot_radius
-        self.layer_spacing = layer_spacing
-        self.ellipse_stroke_width = ellipse_stroke_width
-        # Make the VMobjects
-        self.encoder, self.decoder = self._construct_encoder_and_decoder()
-        self.embedding = self._construct_embedding()
-        # Setup the relative locations
-        self.embedding.move_to(self.encoder)
-        self.embedding.shift([1.4 * self.encoder.width, 0, 0])
-        self.decoder.move_to(self.embedding)
-        self.decoder.shift([self.decoder.width * 1.4, 0, 0])
-        # Add the objects to the VAE object
-        self.add(self.encoder)
-        self.add(self.decoder)
-        self.add(self.embedding)
-
-    def _construct_encoder_and_decoder(self):
-        """Makes the VAE encoder and decoder"""
-        # Make the encoder
-        layer_node_count = self.encoder_nodes_per_layer
-        encoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius, layer_spacing=self.layer_spacing)
-        encoder.scale(1.2)
-        # Make the decoder
-        layer_node_count = self.decoder_nodes_per_layer
-        decoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius, layer_spacing=self.layer_spacing)
-        decoder.scale(1.2)
-
-        return encoder, decoder
-
-    def _construct_embedding(self):
-        """Makes a Gaussian-like embedding"""
-        embedding = VGroup()
-        # Sample points from a Gaussian
-        num_points = 200
-        standard_deviation = [0.9, 0.9]
-        mean = [0, 0]
-        points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
-        # Make an axes
-        embedding.axes = Axes(
-            x_range=[-3, 3],
-            y_range=[-3, 3],
-            x_length=2.2,
-            y_length=2.2,
-            tips=False,
-        )
-        # Add each point to the axes
-        self.point_dots = VGroup()
-        for point in points:
-            point_location = embedding.axes.coords_to_point(*point)
-            dot = Dot(point_location, color=self.point_color, radius=self.dot_radius/2)
-            self.point_dots.add(dot)
-
-        embedding.add(self.point_dots)
-        return embedding
-
-    def _construct_image_mobject(self, input_image, height=2.3):
-        """Constructs an ImageMobject from a numpy grayscale image"""
-        # Convert image to rgb
-        input_image = np.repeat(input_image, 3, axis=0)
-        input_image = np.rollaxis(input_image, 0, start=3)
-        # Make the ImageMobject
-        image_mobject = ImageMobject(input_image, image_mode="RGB")
-        image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        image_mobject.height = height
-
-        return image_mobject
-
-    def _construct_input_output_images(self, image_pair):
-        """Places the input and output images for the AE"""
-        # Takes the image pair
-        # image_pair is assumed to be [2, x, y]
-        input_image = image_pair[0][None, :, :]
-        recon_image = image_pair[1][None, :, :]
-        # Make the image mobjects
-        input_image_object = self._construct_image_mobject(input_image)
-        recon_image_object = self._construct_image_mobject(recon_image)
-
-        return input_image_object, recon_image_object
-
-    def make_dot_convergence_animation(self, location, run_time=1.5):
-        """Makes dots converge on a specific location"""
-        # Move to location
-        animations = []
-        for dot in self.encoder.dots:
-            coords = self.embedding.axes.coords_to_point(*location)
-            animations.append(dot.animate.move_to(coords))
-        move_animations = AnimationGroup(*animations, run_time=1.5)
-        # Follow up with remove animations
-        remove_animations = []
-        for dot in self.encoder.dots:
-            remove_animations.append(FadeOut(dot))
-        remove_animations = AnimationGroup(*remove_animations, run_time=0.2)
-
-        animation_group = Succession(move_animations, remove_animations, lag_ratio=1.0)
-
-        return animation_group
-
-    def make_dot_divergence_animation(self, location, run_time=3.0):
-        """Makes dots diverge from the given location and move the decoder"""
-        animations = []
-        for node in self.decoder.layers[0].node_group:
-            new_dot = Dot(location, radius=self.dot_radius, color=RED)
-            per_node_succession = Succession(
-                Create(new_dot),
-                new_dot.animate.move_to(node.get_center()),
-            )
-            animations.append(per_node_succession)
-
-        animation_group = AnimationGroup(*animations)
-        return animation_group
-
-    def make_reset_vae_animation(self):
-        """Resets the VAE to just the neural network"""
-        animation_group = AnimationGroup(
-            FadeOut(self.input_image),
-            FadeOut(self.output_image),
-            FadeOut(self.distribution_objects),
-            run_time=1.0
-        )
-
-        return animation_group
-
-    def make_forward_pass_animation(self, image_pair, run_time=1.5):
-        """Overriden forward pass animation specific to a VAE"""
-        per_unit_runtime = run_time
-        # Setup images
-        self.input_image, self.output_image = self._construct_input_output_images(image_pair)
-        self.input_image.move_to(self.encoder.get_left())
-        self.input_image.shift(LEFT)
-        self.output_image.move_to(self.decoder.get_right())
-        self.output_image.shift(RIGHT*1.3)
-        # Make encoder forward pass
-        encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime)
-        # Make red dot in embedding
-        mean = [1.0, 1.5]
-        mean_point = self.embedding.axes.coords_to_point(*mean)
-        std = [0.8, 1.2]
-        # Make the dot convergence animation
-        dot_convergence_animation = self.make_dot_convergence_animation(mean, run_time=per_unit_runtime)
-        encoding_succesion = Succession(
-            encoder_forward_pass,
-            dot_convergence_animation
-        )
-        # Make an ellipse centered at mean_point with std outline
-        center_dot = Dot(mean_point, radius=self.dot_radius, color=RED)
-        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width)
-        ellipse.move_to(mean_point)
-        self.distribution_objects = VGroup(
-            center_dot,
-            ellipse
-        )
-        # Make ellipse animation
-        ellipse_animation = AnimationGroup(
-            GrowFromCenter(center_dot),
-            GrowFromCenter(ellipse),
-        )
-        # Make the dot divergence animation
-        sampled_point = [0.51, 1.0]
-        divergence_point = self.embedding.axes.coords_to_point(*sampled_point)
-        dot_divergence_animation = self.make_dot_divergence_animation(divergence_point, run_time=per_unit_runtime)
-        # Make decoder foward pass
-        decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
-        # Add the animations to the group
-        animation_group = AnimationGroup(
-            FadeIn(self.input_image),
-            encoding_succesion,
-            ellipse_animation,
-            dot_divergence_animation,
-            decoder_forward_pass,
-            FadeIn(self.output_image),
-            lag_ratio=1,
-        )
-
-        return animation_group
-
-    def make_interpolation_animation(self, interpolation_images, frame_rate=5):
-        """Makes an animation interpolation"""
-        num_images = len(interpolation_images)
-        # Make madeup path
-        interpolation_latent_path = np.linspace([-0.7, -1.2], [1.2, 1.5], num=num_images)
-        # Make the path animation
-        first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
-        last_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[-1])
-        moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
-        self.add(moving_dot)
-        animation_list = [Create(Line(first_dot_location, last_dot_location, color=RED), run_time=0.1*num_images)]
-        for image_index in range(num_images - 1):
-            next_index = image_index + 1
-            # Get path
-            next_point = interpolation_latent_path[next_index]
-            next_position = self.embedding.axes.coords_to_point(*next_point)
-            # Draw path from current point to next point
-            move_animation = moving_dot.animate(run_time=0.1*num_images).move_to(next_position)
-            animation_list.append(move_animation)
-
-        interpolation_animation = AnimationGroup(*animation_list)
-        # Make the images animation
-        animation_list = [Wait(0.5)]
-        for numpy_image in interpolation_images:
-            numpy_image = numpy_image[None, :, :]
-            manim_image = self._construct_image_mobject(numpy_image)
-            # Move the image to the correct location
-            manim_image.move_to(self.output_image)
-            # Add the image
-            animation_list.append(FadeIn(manim_image, run_time=0.1))
-            # Wait
-            # animation_list.append(Wait(1 / frame_rate))
-            # Remove the image
-            # animation_list.append(FadeOut(manim_image, run_time=0.1))
-        images_animation = AnimationGroup(*animation_list, lag_ratio=1.0)
-        # Combine the two into an AnimationGroup
-        animation_group = AnimationGroup(
-            interpolation_animation,
-            images_animation
-        )
-
-        return animation_group
-
-class VariationalAutoencoder(VGroup):
-
-    def __init__(self):
-        embedding_layer = EmbeddingLayer()
-
-        image = Image.open('images/image.jpeg')
-        numpy_image = np.asarray(image)
-        # Make nn
-        neural_network = NeuralNetwork([
-            ImageLayer(numpy_image, height=1.4),
-            FeedForwardLayer(5),
-            FeedForwardLayer(3),
-            embedding_layer,
-            FeedForwardLayer(3),
-            FeedForwardLayer(5),
-            ImageLayer(numpy_image, height=1.4),
-        ])
-
-        neural_network.scale(1.3)
-
-        self.play(Create(neural_network))
-        self.play(neural_network.make_forward_pass_animation(run_time=15))
-
-class MNISTImageHandler():
-    """Deals with loading serialized VAE mnist images from "autoencoder_models" """
-
-    def __init__(
-        self,
-        image_pairs_file_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/image_pairs.pkl"),
-        interpolations_file_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/interpolations.pkl")
-    ):
-        self.image_pairs_file_path = image_pairs_file_path
-        self.interpolations_file_path = interpolations_file_path
-
-        self.image_pairs = []
-        self.interpolation_images = []
-        self.interpolation_latent_path = []
-
-        self.load_serialized_data()
-
-    def load_serialized_data(self):
-        with open(self.image_pairs_file_path, "rb") as f:
-            self.image_pairs = pickle.load(f)
-
-        with open(self.interpolations_file_path, "rb") as f:
-            self.interpolation_dict = pickle.load(f)
-        self.interpolation_images = self.interpolation_dict["interpolation_images"]
-        self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]
-
-"""
-The VAE Scene for the twitter video.
-"""
-config.pixel_height = 720
-config.pixel_width = 1280
-config.frame_height = 5.0
-config.frame_width = 5.0
-# Set random seed so point distribution is constant
-np.random.seed(1)
+ROOT_DIR = Path(__file__).parents[2]
 
 class VAEScene(Scene):
     """Scene object for a Variational Autoencoder and Autoencoder"""
 
     def construct(self):
-        # Set Scene config
-        vae = VariationalAutoencoder()
-        mnist_image_handler = MNISTImageHandler()
-        image_pair = mnist_image_handler.image_pairs[3]
-        vae.move_to(ORIGIN)
+        numpy_image = np.asarray(Image.open(ROOT_DIR / 'assets/mnist/digit.jpeg'))
+        vae = NeuralNetwork([
+            ImageLayer(numpy_image, height=1.4),
+            FeedForwardLayer(5),
+            FeedForwardLayer(3),
+            EmbeddingLayer(dist_theme="ellipse").scale(2),
+            FeedForwardLayer(3),
+            FeedForwardLayer(5),
+            ImageLayer(numpy_image, height=1.4),
+        ])
 
         vae.scale(1.3)
-        self.play(Create(vae), run_time=3)
-        # Make a forward pass animation
-        forward_pass_animation = vae.make_forward_pass_animation(image_pair)
-        self.play(forward_pass_animation)
-        # Remove the input and output images
-        reset_animation = vae.make_reset_vae_animation()
-        self.play(reset_animation)
-        # Interpolation animation
-        interpolation_images = mnist_image_handler.interpolation_images
-        interpolation_animation = vae.make_interpolation_animation(interpolation_images)
-        self.play(interpolation_animation)
+        self.play(Create(vae))
+        self.play(vae.make_forward_pass_animation(run_time=15))
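
The net result of this last diff is a much smaller module: the bespoke VariationalAutoencoder VGroup, MNISTImageHandler, and their pickle plumbing are removed, and the scene is rebuilt declaratively from stock manim_ml layers. A sketch of the entire new file, reconstructed from the + lines above with inferred indentation:

"""
In this module I define Manim visualizations for Variational Autoencoders
and Traditional Autoencoders.
"""
from pathlib import Path

from manim import *
import numpy as np
from PIL import Image
from manim_ml.neural_network.layers import EmbeddingLayer
from manim_ml.neural_network.layers import FeedForwardLayer
from manim_ml.neural_network.layers import ImageLayer
from manim_ml.neural_network.neural_network import NeuralNetwork

# Resolve assets from the repository root, as in the other examples
ROOT_DIR = Path(__file__).parents[2]

class VAEScene(Scene):
    """Scene object for a Variational Autoencoder and Autoencoder"""

    def construct(self):
        numpy_image = np.asarray(Image.open(ROOT_DIR / 'assets/mnist/digit.jpeg'))
        # The VAE is now just a NeuralNetwork with an EmbeddingLayer bottleneck
        vae = NeuralNetwork([
            ImageLayer(numpy_image, height=1.4),
            FeedForwardLayer(5),
            FeedForwardLayer(3),
            EmbeddingLayer(dist_theme="ellipse").scale(2),
            FeedForwardLayer(3),
            FeedForwardLayer(5),
            ImageLayer(numpy_image, height=1.4),
        ])

        vae.scale(1.3)
        self.play(Create(vae))
        self.play(vae.make_forward_pass_animation(run_time=15))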