diff --git a/examples/media/VAEScene.gif b/examples/media/VAEScene.gif index 8337dfc..cb64176 100644 Binary files a/examples/media/VAEScene.gif and b/examples/media/VAEScene.gif differ diff --git a/manim_ml/neural_network/layers/embedding.py b/manim_ml/neural_network/layers/embedding.py index 2a96c98..bcdcf1d 100644 --- a/manim_ml/neural_network/layers/embedding.py +++ b/manim_ml/neural_network/layers/embedding.py @@ -6,9 +6,10 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer): """NeuralNetwork embedding object that can show probability distributions""" def __init__(self, point_radius=0.02, mean = np.array([0, 0]), - covariance=np.array([[1.5, 0], [0, 1.5]]), **kwargs): + covariance=np.array([[1.5, 0], [0, 1.5]]), dist_theme="gaussian", **kwargs): super(VGroupNeuralNetworkLayer, self).__init__(**kwargs) self.point_radius = point_radius + self.dist_theme = dist_theme self.axes = Axes( tips=False, x_length=0.8, @@ -19,7 +20,8 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer): self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance) self.add(self.point_cloud) # Make latent distribution - self.latent_distribution = GaussianDistribution(self.axes, mean=mean, cov=covariance) # Use defaults + self.latent_distribution = GaussianDistribution(self.axes, mean=mean, cov=covariance, + dist_theme=self.dist_theme) # Use defaults def sample_point_location_from_distribution(self): """Samples from the current latent distribution""" @@ -49,12 +51,12 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer): return point_dots - def make_forward_pass_animation(self, dist_theme="gaussian", **kwargs): + def make_forward_pass_animation(self, **kwargs): """Forward pass animation""" # Make ellipse object corresponding to the latent distribution self.latent_distribution = GaussianDistribution( self.axes, - dist_theme=dist_theme, + dist_theme=self.dist_theme, cov=np.array([[0.8, 0], [0.0, 0.8]]) ) # Use defaults # Create animation diff --git a/tests/test_variational_autoencoder.py b/tests/test_variational_autoencoder.py index ede8523..acebe51 100644 --- a/tests/test_variational_autoencoder.py +++ b/tests/test_variational_autoencoder.py @@ -28,5 +28,6 @@ class VariationalAutoencoderScene(Scene): neural_network.scale(1.3) - self.play(Create(neural_network)) - self.play(neural_network.make_forward_pass_animation(run_time=15)) \ No newline at end of file + self.play(Create(neural_network), run_time=3) + self.play(neural_network.make_forward_pass_animation(), run_time=5) + self.play(neural_network.make_forward_pass_animation(), run_time=5) \ No newline at end of file