# Mirror of https://github.com/helblazer811/ManimML.git
# Synced 2025-05-22 04:56:26 +08:00
# 86 lines, 2.8 KiB, Python
from manim import *
|
|
from manim_ml.neural_network.layers import NeuralNetworkLayer
|
|
from manim_ml.image import GrayscaleImageMobject, LabeledColorImage
|
|
import numpy as np
|
|
|
|
class TripletLayer(NeuralNetworkLayer):
    """Layer that shows an anchor / positive / negative image triplet.

    Each image is wrapped in a ``LabeledColorImage`` (white "Anchor",
    green "Positive", red "Negative") and the three labeled images are
    stacked vertically.
    """

    def __init__(self, anchor, positive, negative, stroke_width=5,
                 font_size=22, buff=0.2, **kwargs):
        """Store the triplet images and styling, then build and add the assets.

        Args:
            anchor: Image mobject for the anchor (e.g. as built by
                ``from_paths``).
            positive: Image mobject for the positive example.
            negative: Image mobject for the negative example.
            stroke_width: Forwarded to every ``LabeledColorImage``.
            font_size: Forwarded to every ``LabeledColorImage``.
            buff: Forwarded to every ``LabeledColorImage``.
            **kwargs: Forwarded to ``NeuralNetworkLayer.__init__``.
        """
        super().__init__(**kwargs)
        self.anchor = anchor
        self.positive = positive
        self.negative = negative
        self.buff = buff

        self.stroke_width = stroke_width
        self.font_size = font_size
        # Make the assets
        self.assets = self.make_assets()
        self.add(self.assets)

    @classmethod
    def from_paths(cls, anchor_path, positive_path, negative_path, grayscale=True,
                   font_size=22, buff=0.2, stroke_width=5, **kwargs):
        """Create a triplet layer from three image file paths.

        Args:
            anchor_path: Path of the anchor image.
            positive_path: Path of the positive image.
            negative_path: Path of the negative image.
            grayscale: When true, load the images with
                ``GrayscaleImageMobject.from_path``; otherwise construct
                ``ImageMobject`` instances directly from the paths.
            font_size: Forwarded to the constructor.
            buff: Forwarded to the constructor.
            stroke_width: Forwarded to the constructor (same default as
                ``__init__``).
            **kwargs: Forwarded to the constructor, and from there to
                ``NeuralNetworkLayer.__init__``.

        Returns:
            A new instance of ``cls`` holding the three loaded images.
        """
        # Pick the loader once instead of branching per image.
        load_image = GrayscaleImageMobject.from_path if grayscale else ImageMobject
        anchor = load_image(anchor_path)
        positive = load_image(positive_path)
        negative = load_image(negative_path)
        # Make the layer
        return cls(
            anchor,
            positive,
            negative,
            stroke_width=stroke_width,
            font_size=font_size,
            buff=buff,
            **kwargs,
        )

    def _make_labeled_image(self, image, color, label):
        """Wrap ``image`` in a ``LabeledColorImage`` using this layer's styling."""
        return LabeledColorImage(
            image,
            color=color,
            label=label,
            stroke_width=self.stroke_width,
            font_size=self.font_size,
            buff=self.buff,
        )

    def make_assets(self):
        """Construct the assets needed for a triplet layer.

        Returns:
            A ``Group`` holding the labeled anchor, positive and negative
            images (in that order), arranged downward with a buff of 1.5.
        """
        anchor_group = self._make_labeled_image(self.anchor, WHITE, "Anchor")
        positive_group = self._make_labeled_image(self.positive, GREEN, "Positive")
        negative_group = self._make_labeled_image(self.negative, RED, "Negative")
        # Distribute the groups uniformly vertically
        assets = Group(anchor_group, positive_group, negative_group)
        assets.arrange(DOWN, buff=1.5)

        return assets

    @override_animation(Create)
    def _create_override(self):
        """Return the animation used in place of ``Create``: fade in the assets."""
        # TODO make Create animation that is custom
        return FadeIn(self.assets)

    def make_forward_pass_animation(self, layer_args=None, **kwargs):
        """Forward pass for triplet.

        Args:
            layer_args: Accepted for signature compatibility; not read here.
                Defaults to ``None`` rather than a dict literal so that no
                mutable object is shared between calls.
            **kwargs: Accepted and ignored.

        Returns:
            An empty ``AnimationGroup``.
        """
        return AnimationGroup()
|