Added updated VAEScene to the readme.

This commit is contained in:
Alec Helbling
2022-04-23 23:10:15 -04:00
parent 0152be64b0
commit 7d04bf55ec
3 changed files with 9 additions and 6 deletions

View File

@ -6,9 +6,10 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
"""NeuralNetwork embedding object that can show probability distributions"""
def __init__(self, point_radius=0.02, mean = np.array([0, 0]),
covariance=np.array([[1.5, 0], [0, 1.5]]), **kwargs):
covariance=np.array([[1.5, 0], [0, 1.5]]), dist_theme="gaussian", **kwargs):
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
self.point_radius = point_radius
self.dist_theme = dist_theme
self.axes = Axes(
tips=False,
x_length=0.8,
@ -19,7 +20,8 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance)
self.add(self.point_cloud)
# Make latent distribution
self.latent_distribution = GaussianDistribution(self.axes, mean=mean, cov=covariance) # Use defaults
self.latent_distribution = GaussianDistribution(self.axes, mean=mean, cov=covariance,
dist_theme=self.dist_theme) # Use defaults
def sample_point_location_from_distribution(self):
"""Samples from the current latent distribution"""
@ -49,12 +51,12 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
return point_dots
def make_forward_pass_animation(self, dist_theme="gaussian", **kwargs):
def make_forward_pass_animation(self, **kwargs):
"""Forward pass animation"""
# Make ellipse object corresponding to the latent distribution
self.latent_distribution = GaussianDistribution(
self.axes,
dist_theme=dist_theme,
dist_theme=self.dist_theme,
cov=np.array([[0.8, 0], [0.0, 0.8]])
) # Use defaults
# Create animation