diff --git a/manim_ml/image.py b/manim_ml/image.py index b26e28e..64ce0e6 100644 --- a/manim_ml/image.py +++ b/manim_ml/image.py @@ -30,3 +30,26 @@ class GrayscaleImageMobject(ImageMobject): @override_animation(Create) def create(self, run_time=2): return FadeIn(self) + +class LabeledColorImage(Group): + """Labeled Color Image""" + + def __init__(self, image, color=RED, label="Positive", stroke_width=5): + super().__init__() + self.image = image + self.color = color + self.label = label + self.stroke_width = stroke_width + + text = Text(label).scale(2) + text.next_to(self.image, UP, buff=1.0) + rectangle = SurroundingRectangle( + self.image, + color=color, + buff=0.0, + stroke_width=self.stroke_width + ) + + self.add(text) + self.add(rectangle) + self.add(self.image) \ No newline at end of file diff --git a/manim_ml/neural_network/layers/paired_query.py b/manim_ml/neural_network/layers/paired_query.py index e2b1ac2..add053e 100644 --- a/manim_ml/neural_network/layers/paired_query.py +++ b/manim_ml/neural_network/layers/paired_query.py @@ -1,6 +1,6 @@ from manim import * from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer -from manim_ml.image import GrayscaleImageMobject +from manim_ml.image import GrayscaleImageMobject, LabeledColorImage import numpy as np class PairedQueryLayer(NeuralNetworkLayer): @@ -36,33 +36,19 @@ class PairedQueryLayer(NeuralNetworkLayer): Constructs the assets needed for a query layer """ # Handle positive - positive_text = Text("Positive").scale(2) - positive_text.next_to(self.positive, UP, buff=1.0) - positive_rectangle = SurroundingRectangle( + positive_group = LabeledColorImage( self.positive, color=GREEN, - buff=0.0, + label="Positive", stroke_width=self.stroke_width ) - positive_group = Group( - positive_text, - positive_rectangle, - self.positive - ) # Handle negative - negative_text = Text("Negative").scale(2) - negative_text.next_to(self.negative, UP, buff=1.0) - negative_rectangle = SurroundingRectangle( + negative_group = LabeledColorImage( self.negative, color=RED, - buff=0.0, + label="Negative", stroke_width=self.stroke_width ) - negative_group = Group( - negative_text, - negative_rectangle, - self.negative - ) # Distribute the groups uniformly vertically assets = Group(positive_group, negative_group) assets.arrange(DOWN, buff=1.5) diff --git a/manim_ml/neural_network/layers/triplet.py b/manim_ml/neural_network/layers/triplet.py index 9042173..1415bad 100644 --- a/manim_ml/neural_network/layers/triplet.py +++ b/manim_ml/neural_network/layers/triplet.py @@ -1,6 +1,6 @@ from manim import * from manim_ml.neural_network.layers import NeuralNetworkLayer -from manim_ml.image import GrayscaleImageMobject +from manim_ml.image import GrayscaleImageMobject, LabeledColorImage import numpy as np class TripletLayer(NeuralNetworkLayer): @@ -39,47 +39,26 @@ class TripletLayer(NeuralNetworkLayer): Constructs the assets needed for a triplet layer """ # Handle anchor - anchor_text = Text("Anchor").scale(2) - anchor_text.next_to(self.anchor, UP, buff=1.0) - anchor_rectangle = SurroundingRectangle( + anchor_group = LabeledColorImage( self.anchor, color=WHITE, - buff=0.0, + label="Anchor", stroke_width=self.stroke_width ) - anchor_group = Group( - anchor_text, - anchor_rectangle, - self.anchor, - ) # Handle positive - positive_text = Text("Positive").scale(2) - positive_text.next_to(self.positive, UP, buff=1.0) - positive_rectangle = SurroundingRectangle( + positive_group = LabeledColorImage( self.positive, color=GREEN, - buff=0.0, + label="Positive", stroke_width=self.stroke_width ) - positive_group = Group( - positive_text, - positive_rectangle, - self.positive - ) # Handle negative - negative_text = Text("Negative").scale(2) - negative_text.next_to(self.negative, UP, buff=1.0) - negative_rectangle = SurroundingRectangle( + negative_group = LabeledColorImage( self.negative, color=RED, - buff=0.0, + label="Negative", stroke_width=self.stroke_width ) - negative_group = Group( - negative_text, - negative_rectangle, - self.negative - ) # Distribute the groups uniformly vertically assets = Group(anchor_group, positive_group, negative_group) assets.arrange(DOWN, buff=1.5)