mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 17:29:45 +08:00
Working forward pass for triplet layers.
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
from manim import *
|
from manim import *
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
class GrayscaleImageMobject(ImageMobject):
|
class GrayscaleImageMobject(ImageMobject):
|
||||||
"""Mobject for creating images in Manim from numpy arrays"""
|
"""Mobject for creating images in Manim from numpy arrays"""
|
||||||
@ -18,6 +19,14 @@ class GrayscaleImageMobject(ImageMobject):
|
|||||||
self.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
|
self.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
|
||||||
self.scale_to_fit_height(height)
|
self.scale_to_fit_height(height)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_path(cls, path, height=2.3):
|
||||||
|
"""Loads image from path"""
|
||||||
|
image = Image.open(path)
|
||||||
|
numpy_image = np.asarray(image)
|
||||||
|
|
||||||
|
return cls(numpy_image, height=height)
|
||||||
|
|
||||||
@override_animation(Create)
|
@override_animation(Create)
|
||||||
def create(self, run_time=2):
|
def create(self, run_time=2):
|
||||||
return FadeIn(self)
|
return FadeIn(self)
|
||||||
|
@ -6,4 +6,6 @@ from .feed_forward_to_image import FeedForwardToImage
|
|||||||
from .feed_forward import FeedForwardLayer
|
from .feed_forward import FeedForwardLayer
|
||||||
from .image_to_feed_forward import ImageToFeedForward
|
from .image_to_feed_forward import ImageToFeedForward
|
||||||
from .image import ImageLayer
|
from .image import ImageLayer
|
||||||
from .parent_layers import ConnectiveLayer, NeuralNetworkLayer
|
from .parent_layers import ConnectiveLayer, NeuralNetworkLayer
|
||||||
|
from .triplet import TripletLayer
|
||||||
|
from .triplet_to_feed_forward import TripletToFeedForward
|
@ -40,7 +40,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
|||||||
dots.append(dot)
|
dots.append(dot)
|
||||||
# Make the animation
|
# Make the animation
|
||||||
if self.passing_flash:
|
if self.passing_flash:
|
||||||
anim = ShowPassingFlash(edge.copy().set_color(self.animation_dot_color), time_width=0.2, run_time=3)
|
anim = ShowPassingFlash(edge.copy().set_color(self.animation_dot_color), time_width=0.2)
|
||||||
else:
|
else:
|
||||||
anim = MoveAlongPath(dot, edge, run_time=run_time, rate_function=sigmoid)
|
anim = MoveAlongPath(dot, edge, run_time=run_time, rate_function=sigmoid)
|
||||||
path_animations.append(anim)
|
path_animations.append(anim)
|
||||||
|
@ -29,7 +29,7 @@ class ImageToFeedForward(ConnectiveLayer):
|
|||||||
)
|
)
|
||||||
animations.append(per_node_succession)
|
animations.append(per_node_succession)
|
||||||
dots.append(new_dot)
|
dots.append(new_dot)
|
||||||
self.add(VGroup(*dots))
|
|
||||||
animation_group = AnimationGroup(*animations)
|
animation_group = AnimationGroup(*animations)
|
||||||
return animation_group
|
return animation_group
|
||||||
|
|
||||||
|
0
manim_ml/neural_network/layers/paired_query.py
Normal file
0
manim_ml/neural_network/layers/paired_query.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||||
|
|
||||||
|
class PairedQueryToFeedForward(ConnectiveLayer):
|
||||||
|
|
||||||
|
pass
|
96
manim_ml/neural_network/layers/triplet.py
Normal file
96
manim_ml/neural_network/layers/triplet.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers import NeuralNetworkLayer
|
||||||
|
from manim_ml.image import GrayscaleImageMobject
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class TripletLayer(NeuralNetworkLayer):
|
||||||
|
"""Shows triplet images"""
|
||||||
|
|
||||||
|
def __init__(self, anchor, positive, negative, stroke_width=5):
|
||||||
|
super().__init__()
|
||||||
|
self.anchor = anchor
|
||||||
|
self.positive = positive
|
||||||
|
self.negative = negative
|
||||||
|
|
||||||
|
self.stroke_width = stroke_width
|
||||||
|
# Make the assets
|
||||||
|
self.assets = self.make_assets()
|
||||||
|
self.add(self.assets)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_paths(cls, anchor_path, positive_path, negative_path, grayscale=True):
|
||||||
|
"""Creates a triplet using the anchor paths"""
|
||||||
|
# Load images from path
|
||||||
|
if grayscale:
|
||||||
|
anchor = GrayscaleImageMobject.from_path(anchor_path)
|
||||||
|
positive = GrayscaleImageMobject.from_path(positive_path)
|
||||||
|
negative = GrayscaleImageMobject.from_path(negative_path)
|
||||||
|
else:
|
||||||
|
anchor = ImageMobject(anchor_path)
|
||||||
|
positive = ImageMobject(positive_path)
|
||||||
|
negative = ImageMobject(negative_path)
|
||||||
|
# Make the layer
|
||||||
|
triplet_layer = cls(anchor, positive, negative)
|
||||||
|
|
||||||
|
return triplet_layer
|
||||||
|
|
||||||
|
def make_assets(self):
|
||||||
|
"""
|
||||||
|
Constructs the assets needed for a triplet layer
|
||||||
|
"""
|
||||||
|
# Handle anchor
|
||||||
|
anchor_text = Text("Anchor").scale(2)
|
||||||
|
anchor_text.next_to(self.anchor, UP, buff=1.0)
|
||||||
|
anchor_rectangle = SurroundingRectangle(
|
||||||
|
self.anchor,
|
||||||
|
color=WHITE,
|
||||||
|
buff=0.0,
|
||||||
|
stroke_width=self.stroke_width
|
||||||
|
)
|
||||||
|
anchor_group = Group(
|
||||||
|
anchor_text,
|
||||||
|
anchor_rectangle,
|
||||||
|
self.anchor,
|
||||||
|
)
|
||||||
|
# Handle positive
|
||||||
|
positive_text = Text("Positive").scale(2)
|
||||||
|
positive_text.next_to(self.positive, UP, buff=1.0)
|
||||||
|
positive_rectangle = SurroundingRectangle(
|
||||||
|
self.positive,
|
||||||
|
color=GREEN,
|
||||||
|
buff=0.0,
|
||||||
|
stroke_width=self.stroke_width
|
||||||
|
)
|
||||||
|
positive_group = Group(
|
||||||
|
positive_text,
|
||||||
|
positive_rectangle,
|
||||||
|
self.positive
|
||||||
|
)
|
||||||
|
# Handle negative
|
||||||
|
negative_text = Text("Negative").scale(2)
|
||||||
|
negative_text.next_to(self.negative, UP, buff=1.0)
|
||||||
|
negative_rectangle = SurroundingRectangle(
|
||||||
|
self.negative,
|
||||||
|
color=RED,
|
||||||
|
buff=0.0,
|
||||||
|
stroke_width=self.stroke_width
|
||||||
|
)
|
||||||
|
negative_group = Group(
|
||||||
|
negative_text,
|
||||||
|
negative_rectangle,
|
||||||
|
self.negative
|
||||||
|
)
|
||||||
|
# Distribute the groups uniformly vertically
|
||||||
|
assets = Group(anchor_group, positive_group, negative_group)
|
||||||
|
assets.arrange(DOWN, buff=1.5)
|
||||||
|
|
||||||
|
return assets
|
||||||
|
|
||||||
|
@override_animation(Create)
|
||||||
|
def _create_layer(self):
|
||||||
|
# TODO make Create animation that is custom
|
||||||
|
return FadeIn(self.assets)
|
||||||
|
|
||||||
|
def make_forward_pass_animation(self):
|
||||||
|
"""Forward pass for triplet"""
|
||||||
|
return AnimationGroup()
|
43
manim_ml/neural_network/layers/triplet_to_feed_forward.py
Normal file
43
manim_ml/neural_network/layers/triplet_to_feed_forward.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||||
|
|
||||||
|
class TripletToFeedForward(ConnectiveLayer):
|
||||||
|
"""TripletLayer to FeedForward layer"""
|
||||||
|
|
||||||
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
||||||
|
dot_radius=0.02):
|
||||||
|
self.animation_dot_color = animation_dot_color
|
||||||
|
self.dot_radius = dot_radius
|
||||||
|
|
||||||
|
self.feed_forward_layer = output_layer
|
||||||
|
self.triplet_layer = input_layer
|
||||||
|
super().__init__(input_layer, output_layer)
|
||||||
|
|
||||||
|
def make_forward_pass_animation(self):
|
||||||
|
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||||
|
animations = []
|
||||||
|
# Loop through each image
|
||||||
|
images = [self.triplet_layer.anchor, self.triplet_layer.positive, self.triplet_layer.negative]
|
||||||
|
for image_mobject in images:
|
||||||
|
image_animations = []
|
||||||
|
dots = []
|
||||||
|
# Move dots from each image to the centers of each of the nodes in the FeedForwardLayer
|
||||||
|
image_location = image_mobject.get_center()
|
||||||
|
for node in self.feed_forward_layer.node_group:
|
||||||
|
new_dot = Dot(image_location, radius=self.dot_radius, color=self.animation_dot_color)
|
||||||
|
per_node_succession = Succession(
|
||||||
|
Create(new_dot),
|
||||||
|
new_dot.animate.move_to(node.get_center()),
|
||||||
|
)
|
||||||
|
image_animations.append(per_node_succession)
|
||||||
|
dots.append(new_dot)
|
||||||
|
|
||||||
|
animations.append(AnimationGroup(*image_animations))
|
||||||
|
|
||||||
|
animation_group = AnimationGroup(*animations)
|
||||||
|
|
||||||
|
return animation_group
|
||||||
|
|
||||||
|
@override_animation(Create)
|
||||||
|
def _create_override(self):
|
||||||
|
return AnimationGroup()
|
@ -16,8 +16,9 @@ import textwrap
|
|||||||
from manim_ml.neural_network.layers import \
|
from manim_ml.neural_network.layers import \
|
||||||
FeedForwardLayer, FeedForwardToFeedForward, ImageLayer, \
|
FeedForwardLayer, FeedForwardToFeedForward, ImageLayer, \
|
||||||
ImageToFeedForward, FeedForwardToImage, EmbeddingLayer, \
|
ImageToFeedForward, FeedForwardToImage, EmbeddingLayer, \
|
||||||
EmbeddingToFeedForward, FeedForwardToEmbedding
|
EmbeddingToFeedForward, FeedForwardToEmbedding, TripletLayer, \
|
||||||
|
TripletToFeedForward
|
||||||
|
|
||||||
class NeuralNetwork(Group):
|
class NeuralNetwork(Group):
|
||||||
|
|
||||||
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
||||||
@ -102,6 +103,12 @@ class NeuralNetwork(Group):
|
|||||||
animation_dot_color=self.animation_dot_color, dot_radius=self.dot_radius)
|
animation_dot_color=self.animation_dot_color, dot_radius=self.dot_radius)
|
||||||
connective_layers.add(layer)
|
connective_layers.add(layer)
|
||||||
all_layers.add(layer)
|
all_layers.add(layer)
|
||||||
|
elif isinstance(current_layer, TripletLayer) \
|
||||||
|
and isinstance(next_layer, FeedForwardLayer):
|
||||||
|
# TripletLayer to FeedForwardLayer
|
||||||
|
layer = TripletToFeedForward(current_layer, next_layer)
|
||||||
|
connective_layers.add(layer)
|
||||||
|
all_layers.add(layer)
|
||||||
else:
|
else:
|
||||||
warnings.warn(f"Warning: unimplemented connection for layer types: {type(current_layer)} and {type(next_layer)}")
|
warnings.warn(f"Warning: unimplemented connection for layer types: {type(current_layer)} and {type(next_layer)}")
|
||||||
# Add final layer
|
# Add final layer
|
||||||
|
32
tests/test_triplet.py
Normal file
32
tests/test_triplet.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers import TripletLayer, triplet
|
||||||
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||||
|
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||||
|
|
||||||
|
config.pixel_height = 720
|
||||||
|
config.pixel_width = 1280
|
||||||
|
config.frame_height = 6.0
|
||||||
|
config.frame_width = 6.0
|
||||||
|
|
||||||
|
class TripletScene(Scene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
anchor_path = "../assets/triplet/anchor.jpg"
|
||||||
|
positive_path = "../assets/triplet/positive.jpg"
|
||||||
|
negative_path = "../assets/triplet/negative.jpg"
|
||||||
|
|
||||||
|
triplet_layer = TripletLayer.from_paths(anchor_path, positive_path, negative_path, grayscale=False)
|
||||||
|
|
||||||
|
triplet_layer.scale(0.08)
|
||||||
|
|
||||||
|
neural_network = NeuralNetwork([
|
||||||
|
triplet_layer,
|
||||||
|
FeedForwardLayer(5),
|
||||||
|
FeedForwardLayer(3)
|
||||||
|
])
|
||||||
|
|
||||||
|
neural_network.scale(1)
|
||||||
|
|
||||||
|
self.play(Create(neural_network), run_time=3)
|
||||||
|
|
||||||
|
self.play(neural_network.make_forward_pass_animation(), run_time=10)
|
Reference in New Issue
Block a user