Refactored LabeledColorImage

This commit is contained in:
Alec Helbling
2022-04-16 00:37:19 -04:00
parent 2306ab39d1
commit 0febbe547d
3 changed files with 35 additions and 47 deletions

View File

@ -30,3 +30,26 @@ class GrayscaleImageMobject(ImageMobject):
@override_animation(Create) @override_animation(Create)
def create(self, run_time=2): def create(self, run_time=2):
return FadeIn(self) return FadeIn(self)
class LabeledColorImage(Group):
"""Labeled Color Image"""
def __init__(self, image, color=RED, label="Positive", stroke_width=5):
super().__init__()
self.image = image
self.color = color
self.label = label
self.stroke_width = stroke_width
text = Text(label).scale(2)
text.next_to(self.image, UP, buff=1.0)
rectangle = SurroundingRectangle(
self.image,
color=color,
buff=0.0,
stroke_width=self.stroke_width
)
self.add(text)
self.add(rectangle)
self.add(self.image)

View File

@ -1,6 +1,6 @@
from manim import * from manim import *
from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
from manim_ml.image import GrayscaleImageMobject from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
import numpy as np import numpy as np
class PairedQueryLayer(NeuralNetworkLayer): class PairedQueryLayer(NeuralNetworkLayer):
@ -36,33 +36,19 @@ class PairedQueryLayer(NeuralNetworkLayer):
Constructs the assets needed for a query layer Constructs the assets needed for a query layer
""" """
# Handle positive # Handle positive
positive_text = Text("Positive").scale(2) positive_group = LabeledColorImage(
positive_text.next_to(self.positive, UP, buff=1.0)
positive_rectangle = SurroundingRectangle(
self.positive, self.positive,
color=GREEN, color=GREEN,
buff=0.0, label="Positive",
stroke_width=self.stroke_width stroke_width=self.stroke_width
) )
positive_group = Group(
positive_text,
positive_rectangle,
self.positive
)
# Handle negative # Handle negative
negative_text = Text("Negative").scale(2) negative_group = LabeledColorImage(
negative_text.next_to(self.negative, UP, buff=1.0)
negative_rectangle = SurroundingRectangle(
self.negative, self.negative,
color=RED, color=RED,
buff=0.0, label="Negative",
stroke_width=self.stroke_width stroke_width=self.stroke_width
) )
negative_group = Group(
negative_text,
negative_rectangle,
self.negative
)
# Distribute the groups uniformly vertically # Distribute the groups uniformly vertically
assets = Group(positive_group, negative_group) assets = Group(positive_group, negative_group)
assets.arrange(DOWN, buff=1.5) assets.arrange(DOWN, buff=1.5)

View File

@ -1,6 +1,6 @@
from manim import * from manim import *
from manim_ml.neural_network.layers import NeuralNetworkLayer from manim_ml.neural_network.layers import NeuralNetworkLayer
from manim_ml.image import GrayscaleImageMobject from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
import numpy as np import numpy as np
class TripletLayer(NeuralNetworkLayer): class TripletLayer(NeuralNetworkLayer):
@ -39,47 +39,26 @@ class TripletLayer(NeuralNetworkLayer):
Constructs the assets needed for a triplet layer Constructs the assets needed for a triplet layer
""" """
# Handle anchor # Handle anchor
anchor_text = Text("Anchor").scale(2) anchor_group = LabeledColorImage(
anchor_text.next_to(self.anchor, UP, buff=1.0)
anchor_rectangle = SurroundingRectangle(
self.anchor, self.anchor,
color=WHITE, color=WHITE,
buff=0.0, label="Anchor",
stroke_width=self.stroke_width stroke_width=self.stroke_width
) )
anchor_group = Group(
anchor_text,
anchor_rectangle,
self.anchor,
)
# Handle positive # Handle positive
positive_text = Text("Positive").scale(2) positive_group = LabeledColorImage(
positive_text.next_to(self.positive, UP, buff=1.0)
positive_rectangle = SurroundingRectangle(
self.positive, self.positive,
color=GREEN, color=GREEN,
buff=0.0, label="Positive",
stroke_width=self.stroke_width stroke_width=self.stroke_width
) )
positive_group = Group(
positive_text,
positive_rectangle,
self.positive
)
# Handle negative # Handle negative
negative_text = Text("Negative").scale(2) negative_group = LabeledColorImage(
negative_text.next_to(self.negative, UP, buff=1.0)
negative_rectangle = SurroundingRectangle(
self.negative, self.negative,
color=RED, color=RED,
buff=0.0, label="Negative",
stroke_width=self.stroke_width stroke_width=self.stroke_width
) )
negative_group = Group(
negative_text,
negative_rectangle,
self.negative
)
# Distribute the groups uniformly vertically # Distribute the groups uniformly vertically
assets = Group(anchor_group, positive_group, negative_group) assets = Group(anchor_group, positive_group, negative_group)
assets.arrange(DOWN, buff=1.5) assets.arrange(DOWN, buff=1.5)