Finshed adding construct_layer methods for refactor.

This commit is contained in:
Alec Helbling
2023-01-15 16:52:22 +09:00
parent 42b6e37b16
commit 99dbda915b
18 changed files with 60 additions and 11 deletions

View File

@ -16,11 +16,20 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
**kwargs
):
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
self.mean = mean
self.covariance = covariance
self.gaussian_distributions = VGroup()
self.add(self.gaussian_distributions)
self.point_radius = point_radius
self.dist_theme = dist_theme
self.paired_query_mode = paired_query_mode
def construct_layer(
self,
input_layer: 'NeuralNetworkLayer',
output_layer: 'NeuralNetworkLayer',
**kwargs
):
self.axes = Axes(
tips=False,
x_length=0.8,
@ -33,12 +42,15 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
self.add(self.axes)
self.axes.move_to(self.get_center())
# Make point cloud
self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance)
self.point_cloud = self.construct_gaussian_point_cloud(
self.mean,
self.covariance
)
self.add(self.point_cloud)
# Make latent distribution
self.latent_distribution = GaussianDistribution(
self.axes, mean=mean, cov=covariance
) # Use defaults
self.axes, mean=self.mean, cov=self.covariance
) # Use defaults
def add_gaussian_distribution(self, gaussian_distribution):
"""Adds given GaussianDistribution to the list"""