mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-21 04:26:43 +08:00
Added updated VAEScene to the readme.
This commit is contained in:
@ -6,9 +6,10 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
"""NeuralNetwork embedding object that can show probability distributions"""
|
||||
|
||||
def __init__(self, point_radius=0.02, mean = np.array([0, 0]),
|
||||
covariance=np.array([[1.5, 0], [0, 1.5]]), **kwargs):
|
||||
covariance=np.array([[1.5, 0], [0, 1.5]]), dist_theme="gaussian", **kwargs):
|
||||
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
||||
self.point_radius = point_radius
|
||||
self.dist_theme = dist_theme
|
||||
self.axes = Axes(
|
||||
tips=False,
|
||||
x_length=0.8,
|
||||
@ -19,7 +20,8 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance)
|
||||
self.add(self.point_cloud)
|
||||
# Make latent distribution
|
||||
self.latent_distribution = GaussianDistribution(self.axes, mean=mean, cov=covariance) # Use defaults
|
||||
self.latent_distribution = GaussianDistribution(self.axes, mean=mean, cov=covariance,
|
||||
dist_theme=self.dist_theme) # Use defaults
|
||||
|
||||
def sample_point_location_from_distribution(self):
|
||||
"""Samples from the current latent distribution"""
|
||||
@ -49,12 +51,12 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
|
||||
return point_dots
|
||||
|
||||
def make_forward_pass_animation(self, dist_theme="gaussian", **kwargs):
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
"""Forward pass animation"""
|
||||
# Make ellipse object corresponding to the latent distribution
|
||||
self.latent_distribution = GaussianDistribution(
|
||||
self.axes,
|
||||
dist_theme=dist_theme,
|
||||
dist_theme=self.dist_theme,
|
||||
cov=np.array([[0.8, 0], [0.0, 0.8]])
|
||||
) # Use defaults
|
||||
# Create animation
|
||||
|
Reference in New Issue
Block a user