Refactored LabeledColorImage

This commit is contained in:
Alec Helbling
2022-04-16 00:37:19 -04:00
parent 2306ab39d1
commit 0febbe547d
3 changed files with 35 additions and 47 deletions

View File

@ -1,6 +1,6 @@
from manim import *
from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
from manim_ml.image import GrayscaleImageMobject
from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
import numpy as np
class PairedQueryLayer(NeuralNetworkLayer):
@ -36,33 +36,19 @@ class PairedQueryLayer(NeuralNetworkLayer):
Constructs the assets needed for a query layer
"""
# Handle positive
positive_text = Text("Positive").scale(2)
positive_text.next_to(self.positive, UP, buff=1.0)
positive_rectangle = SurroundingRectangle(
positive_group = LabeledColorImage(
self.positive,
color=GREEN,
buff=0.0,
label="Positive",
stroke_width=self.stroke_width
)
positive_group = Group(
positive_text,
positive_rectangle,
self.positive
)
# Handle negative
negative_text = Text("Negative").scale(2)
negative_text.next_to(self.negative, UP, buff=1.0)
negative_rectangle = SurroundingRectangle(
negative_group = LabeledColorImage(
self.negative,
color=RED,
buff=0.0,
label="Negative",
stroke_width=self.stroke_width
)
negative_group = Group(
negative_text,
negative_rectangle,
self.negative
)
# Distribute the groups uniformly vertically
assets = Group(positive_group, negative_group)
assets.arrange(DOWN, buff=1.5)

View File

@ -1,6 +1,6 @@
from manim import *
from manim_ml.neural_network.layers import NeuralNetworkLayer
from manim_ml.image import GrayscaleImageMobject
from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
import numpy as np
class TripletLayer(NeuralNetworkLayer):
@ -39,47 +39,26 @@ class TripletLayer(NeuralNetworkLayer):
Constructs the assets needed for a triplet layer
"""
# Handle anchor
anchor_text = Text("Anchor").scale(2)
anchor_text.next_to(self.anchor, UP, buff=1.0)
anchor_rectangle = SurroundingRectangle(
anchor_group = LabeledColorImage(
self.anchor,
color=WHITE,
buff=0.0,
label="Anchor",
stroke_width=self.stroke_width
)
anchor_group = Group(
anchor_text,
anchor_rectangle,
self.anchor,
)
# Handle positive
positive_text = Text("Positive").scale(2)
positive_text.next_to(self.positive, UP, buff=1.0)
positive_rectangle = SurroundingRectangle(
positive_group = LabeledColorImage(
self.positive,
color=GREEN,
buff=0.0,
label="Positive",
stroke_width=self.stroke_width
)
positive_group = Group(
positive_text,
positive_rectangle,
self.positive
)
# Handle negative
negative_text = Text("Negative").scale(2)
negative_text.next_to(self.negative, UP, buff=1.0)
negative_rectangle = SurroundingRectangle(
negative_group = LabeledColorImage(
self.negative,
color=RED,
buff=0.0,
label="Negative",
stroke_width=self.stroke_width
)
negative_group = Group(
negative_text,
negative_rectangle,
self.negative
)
# Distribute the groups uniformly vertically
assets = Group(anchor_group, positive_group, negative_group)
assets.arrange(DOWN, buff=1.5)