mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 03:57:40 +08:00
Finshed adding construct_layer methods for refactor.
This commit is contained in:
@ -16,11 +16,20 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
**kwargs
|
||||
):
|
||||
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
||||
self.mean = mean
|
||||
self.covariance = covariance
|
||||
self.gaussian_distributions = VGroup()
|
||||
self.add(self.gaussian_distributions)
|
||||
self.point_radius = point_radius
|
||||
self.dist_theme = dist_theme
|
||||
self.paired_query_mode = paired_query_mode
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
**kwargs
|
||||
):
|
||||
self.axes = Axes(
|
||||
tips=False,
|
||||
x_length=0.8,
|
||||
@ -33,12 +42,15 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
self.add(self.axes)
|
||||
self.axes.move_to(self.get_center())
|
||||
# Make point cloud
|
||||
self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance)
|
||||
self.point_cloud = self.construct_gaussian_point_cloud(
|
||||
self.mean,
|
||||
self.covariance
|
||||
)
|
||||
self.add(self.point_cloud)
|
||||
# Make latent distribution
|
||||
self.latent_distribution = GaussianDistribution(
|
||||
self.axes, mean=mean, cov=covariance
|
||||
) # Use defaults
|
||||
self.axes, mean=self.mean, cov=self.covariance
|
||||
) # Use defaults
|
||||
|
||||
def add_gaussian_distribution(self, gaussian_distribution):
|
||||
"""Adds given GaussianDistribution to the list"""
|
||||
|
Reference in New Issue
Block a user