mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 17:29:45 +08:00
Vector layer (Still work in progress)
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
from tempfile import _TemporaryFileWrapper
|
from tempfile import _TemporaryFileWrapper
|
||||||
from manim_ml.neural_network.layers.paired_query_to_feed_forward import PairedQueryToFeedForward
|
from .feed_forward_to_vector import FeedForwardToVector
|
||||||
|
from .paired_query_to_feed_forward import PairedQueryToFeedForward
|
||||||
from .embedding_to_feed_forward import EmbeddingToFeedForward
|
from .embedding_to_feed_forward import EmbeddingToFeedForward
|
||||||
from .embedding import EmbeddingLayer
|
from .embedding import EmbeddingLayer
|
||||||
from .feed_forward_to_embedding import FeedForwardToEmbedding
|
from .feed_forward_to_embedding import FeedForwardToEmbedding
|
||||||
@ -23,4 +24,5 @@ connective_layers_list = (
|
|||||||
PairedQueryToFeedForward,
|
PairedQueryToFeedForward,
|
||||||
TripletToFeedForward,
|
TripletToFeedForward,
|
||||||
PairedQueryToFeedForward,
|
PairedQueryToFeedForward,
|
||||||
|
FeedForwardToVector,
|
||||||
)
|
)
|
||||||
|
39
manim_ml/neural_network/layers/feed_forward_to_vector.py
Normal file
39
manim_ml/neural_network/layers/feed_forward_to_vector.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||||
|
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||||
|
from manim_ml.neural_network.layers.vector import VectorLayer
|
||||||
|
|
||||||
|
class FeedForwardToVector(ConnectiveLayer):
|
||||||
|
"""Image Layer to FeedForward layer"""
|
||||||
|
input_class = FeedForwardLayer
|
||||||
|
output_class = VectorLayer
|
||||||
|
|
||||||
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
||||||
|
dot_radius=0.05, **kwargs):
|
||||||
|
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=VectorLayer,
|
||||||
|
**kwargs)
|
||||||
|
self.animation_dot_color = animation_dot_color
|
||||||
|
self.dot_radius = dot_radius
|
||||||
|
|
||||||
|
self.feed_forward_layer = input_layer
|
||||||
|
self.vector_layer = output_layer
|
||||||
|
|
||||||
|
def make_forward_pass_animation(self):
|
||||||
|
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||||
|
animations = []
|
||||||
|
# Move the dots to the centers of each of the nodes in the FeedForwardLayer
|
||||||
|
destination = self.vector_layer.get_center()
|
||||||
|
for node in self.feed_forward_layer.node_group:
|
||||||
|
new_dot = Dot(node.get_center(), radius=self.dot_radius, color=self.animation_dot_color)
|
||||||
|
per_node_succession = Succession(
|
||||||
|
Create(new_dot),
|
||||||
|
new_dot.animate.move_to(destination),
|
||||||
|
)
|
||||||
|
animations.append(per_node_succession)
|
||||||
|
|
||||||
|
animation_group = AnimationGroup(*animations)
|
||||||
|
return animation_group
|
||||||
|
|
||||||
|
@override_animation(Create)
|
||||||
|
def _create_override(self):
|
||||||
|
return AnimationGroup()
|
37
manim_ml/neural_network/layers/vector.py
Normal file
37
manim_ml/neural_network/layers/vector.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from manim import *
|
||||||
|
import random
|
||||||
|
|
||||||
|
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
|
||||||
|
|
||||||
|
class VectorLayer(VGroupNeuralNetworkLayer):
|
||||||
|
"""Shows a vector"""
|
||||||
|
|
||||||
|
def __init__(self, num_values, value_func=lambda: random.uniform(0, 1),
|
||||||
|
**kwargs):
|
||||||
|
print("vector layer")
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
print("after init")
|
||||||
|
self.num_values = num_values
|
||||||
|
self.value_func = value_func
|
||||||
|
# Make the vector
|
||||||
|
self.vector_label = self.make_vector()
|
||||||
|
|
||||||
|
def make_vector(self):
|
||||||
|
"""Makes the vector"""
|
||||||
|
if False:
|
||||||
|
# TODO install Latex
|
||||||
|
values = np.array([self.value_func() for i in range(self.num_values)])
|
||||||
|
values = values[None, :].T
|
||||||
|
vector = Matrix(values)
|
||||||
|
|
||||||
|
vector_label = Text(f"[{self.value_func()}]")
|
||||||
|
|
||||||
|
return vector_label
|
||||||
|
|
||||||
|
def make_forward_pass_animation(self):
|
||||||
|
return AnimationGroup()
|
||||||
|
|
||||||
|
@override_animation(Create)
|
||||||
|
def _create_override(self):
|
||||||
|
"""Create animation"""
|
||||||
|
return Create(self.vector_label)
|
Reference in New Issue
Block a user