From 0febbe547d0ff0caf1b3960cfb06be1eb4942d31 Mon Sep 17 00:00:00 2001
From: Alec Helbling <alechelbling1@gmail.com>
Date: Sat, 16 Apr 2022 00:37:19 -0400
Subject: [PATCH] Refactored LabeledColorImage

---
 manim_ml/image.py                             | 23 ++++++++++++
 .../neural_network/layers/paired_query.py     | 24 +++----------
 manim_ml/neural_network/layers/triplet.py     | 35 ++++---------------
 3 files changed, 35 insertions(+), 47 deletions(-)

diff --git a/manim_ml/image.py b/manim_ml/image.py
index b26e28e..64ce0e6 100644
--- a/manim_ml/image.py
+++ b/manim_ml/image.py
@@ -30,3 +30,26 @@ class GrayscaleImageMobject(ImageMobject):
     @override_animation(Create)
     def create(self, run_time=2):
         return FadeIn(self)
+
+class LabeledColorImage(Group):
+    """Labeled Color Image"""
+
+    def __init__(self, image, color=RED, label="Positive", stroke_width=5):
+        super().__init__()
+        self.image = image
+        self.color = color
+        self.label = label
+        self.stroke_width = stroke_width
+
+        text = Text(label).scale(2)
+        text.next_to(self.image, UP, buff=1.0)
+        rectangle = SurroundingRectangle(
+            self.image, 
+            color=color,
+            buff=0.0,
+            stroke_width=self.stroke_width
+        )
+
+        self.add(text)
+        self.add(rectangle)
+        self.add(self.image)
\ No newline at end of file
diff --git a/manim_ml/neural_network/layers/paired_query.py b/manim_ml/neural_network/layers/paired_query.py
index e2b1ac2..add053e 100644
--- a/manim_ml/neural_network/layers/paired_query.py
+++ b/manim_ml/neural_network/layers/paired_query.py
@@ -1,6 +1,6 @@
 from manim import *
 from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
-from manim_ml.image import GrayscaleImageMobject
+from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
 import numpy as np
 
 class PairedQueryLayer(NeuralNetworkLayer):
@@ -36,33 +36,19 @@ class PairedQueryLayer(NeuralNetworkLayer):
             Constructs the assets needed for a query layer
         """
         # Handle positive
-        positive_text = Text("Positive").scale(2)
-        positive_text.next_to(self.positive, UP, buff=1.0)
-        positive_rectangle = SurroundingRectangle(
+        positive_group = LabeledColorImage(
             self.positive, 
             color=GREEN,
-            buff=0.0,
+            label="Positive", 
             stroke_width=self.stroke_width
         )
-        positive_group = Group(
-            positive_text,
-            positive_rectangle,
-            self.positive
-        )
         # Handle negative
-        negative_text = Text("Negative").scale(2)
-        negative_text.next_to(self.negative, UP, buff=1.0)
-        negative_rectangle = SurroundingRectangle(
+        negative_group = LabeledColorImage(
             self.negative, 
             color=RED,
-            buff=0.0,
+            label="Negative", 
             stroke_width=self.stroke_width
         )
-        negative_group = Group(
-            negative_text,
-            negative_rectangle,
-            self.negative
-        )
         # Distribute the groups uniformly vertically
         assets = Group(positive_group, negative_group)
         assets.arrange(DOWN, buff=1.5)
diff --git a/manim_ml/neural_network/layers/triplet.py b/manim_ml/neural_network/layers/triplet.py
index 9042173..1415bad 100644
--- a/manim_ml/neural_network/layers/triplet.py
+++ b/manim_ml/neural_network/layers/triplet.py
@@ -1,6 +1,6 @@
 from manim import *
 from manim_ml.neural_network.layers import NeuralNetworkLayer
-from manim_ml.image import GrayscaleImageMobject
+from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
 import numpy as np
 
 class TripletLayer(NeuralNetworkLayer):
@@ -39,47 +39,26 @@ class TripletLayer(NeuralNetworkLayer):
             Constructs the assets needed for a triplet layer
         """
         # Handle anchor
-        anchor_text = Text("Anchor").scale(2)
-        anchor_text.next_to(self.anchor, UP, buff=1.0)
-        anchor_rectangle = SurroundingRectangle(
+        anchor_group = LabeledColorImage(
             self.anchor, 
             color=WHITE,
-            buff=0.0,
+            label="Anchor", 
             stroke_width=self.stroke_width
         )
-        anchor_group = Group(
-            anchor_text,
-            anchor_rectangle,
-            self.anchor,
-        )
         # Handle positive
-        positive_text = Text("Positive").scale(2)
-        positive_text.next_to(self.positive, UP, buff=1.0)
-        positive_rectangle = SurroundingRectangle(
+        positive_group = LabeledColorImage(
             self.positive, 
             color=GREEN,
-            buff=0.0,
+            label="Positive", 
             stroke_width=self.stroke_width
         )
-        positive_group = Group(
-            positive_text,
-            positive_rectangle,
-            self.positive
-        )
         # Handle negative
-        negative_text = Text("Negative").scale(2)
-        negative_text.next_to(self.negative, UP, buff=1.0)
-        negative_rectangle = SurroundingRectangle(
+        negative_group = LabeledColorImage(
             self.negative, 
             color=RED,
-            buff=0.0,
+            label="Negative", 
             stroke_width=self.stroke_width
         )
-        negative_group = Group(
-            negative_text,
-            negative_rectangle,
-            self.negative
-        )
         # Distribute the groups uniformly vertically
         assets = Group(anchor_group, positive_group, negative_group)
         assets.arrange(DOWN, buff=1.5)