mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 17:29:45 +08:00
Working forward pass for triplet layers.
This commit is contained in:
32
tests/test_triplet.py
Normal file
32
tests/test_triplet.py
Normal file
@ -0,0 +1,32 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers import TripletLayer, triplet
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
config.pixel_height = 720
|
||||
config.pixel_width = 1280
|
||||
config.frame_height = 6.0
|
||||
config.frame_width = 6.0
|
||||
|
||||
class TripletScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
anchor_path = "../assets/triplet/anchor.jpg"
|
||||
positive_path = "../assets/triplet/positive.jpg"
|
||||
negative_path = "../assets/triplet/negative.jpg"
|
||||
|
||||
triplet_layer = TripletLayer.from_paths(anchor_path, positive_path, negative_path, grayscale=False)
|
||||
|
||||
triplet_layer.scale(0.08)
|
||||
|
||||
neural_network = NeuralNetwork([
|
||||
triplet_layer,
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3)
|
||||
])
|
||||
|
||||
neural_network.scale(1)
|
||||
|
||||
self.play(Create(neural_network), run_time=3)
|
||||
|
||||
self.play(neural_network.make_forward_pass_animation(), run_time=10)
|
Reference in New Issue
Block a user