mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-21 04:26:43 +08:00
45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from manim import *
|
|
from abc import ABC, abstractmethod
|
|
|
|
class NeuralNetworkLayer(ABC, Group):
    """Abstract base class for one layer of a neural network visualization.

    Concrete layers must implement ``make_forward_pass_animation``.
    """

    def __init__(self, **kwargs):
        # NOTE: keyword arguments are accepted but not forwarded anywhere.
        # ``super(Group, self)`` begins the lookup *after* Group in the MRO,
        # so Group.__init__ itself is not run.
        # NOTE(review): presumably deliberate (keeps subclass-specific kwargs
        # away from the base initializers) -- confirm before changing.
        super(Group, self).__init__()
        # Layers get z-index 1 (higher than the default), affecting draw order.
        self.set_z_index(1)

    @abstractmethod
    def make_forward_pass_animation(self):
        """Build the animation of a forward pass through this layer.

        Subclasses must override this; the base body returns None.
        """

    def __repr__(self):
        # A layer is represented simply by its concrete class name.
        return type(self).__name__
|
|
|
|
class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
    """Abstract neural network layer that skips ``NeuralNetworkLayer.__init__``.

    NOTE(review): despite the name, this derives from ``Group`` (through
    NeuralNetworkLayer), not from ``VGroup``.
    """

    def __init__(self, **kwargs):
        # The lookup starts after NeuralNetworkLayer in the MRO, so its
        # ``set_z_index(1)`` call does not happen here; ``kwargs`` go to the
        # next initializer in line (Group's, for the bases visible here).
        super(NeuralNetworkLayer, self).__init__(**kwargs)

    @abstractmethod
    def make_forward_pass_animation(self):
        """Build this layer's forward-pass animation (must be overridden)."""
|
|
|
|
class ConnectiveLayer(VGroupNeuralNetworkLayer):
    """Layer connecting an input layer to an output layer.

    Subclasses declare which layer types they connect (``input_class`` /
    ``output_class``) and implement the forward-pass animation between them.
    """

    @abstractmethod
    def __init__(self, input_layer, output_layer, input_class=None, output_class=None):
        """Store the connected layers and check their types.

        Args:
            input_layer: layer the connection starts from.
            output_layer: layer the connection ends at.
            input_class: class (or tuple of classes) that ``input_layer`` must
                be an instance of; ``None`` disables the check.
            output_class: same as ``input_class``, for ``output_layer``.

        Raises:
            AssertionError: if a layer is not an instance of its required
                class (only while assertions are enabled, i.e. not under -O).
        """
        # Lookup starts after VGroupNeuralNetworkLayer in the MRO, i.e.
        # NeuralNetworkLayer.__init__ for the bases visible here.
        super(VGroupNeuralNetworkLayer, self).__init__()
        self.input_layer = input_layer
        self.output_layer = output_layer
        self.input_class = input_class
        self.output_class = output_class
        # Handle input and output class. ``None`` means "no constraint":
        # isinstance(x, None) raises TypeError, so the check is skipped then.
        if self.input_class is not None:
            assert isinstance(input_layer, self.input_class), (
                f"{type(self).__name__} expects input_layer of type "
                f"{self.input_class}, got {type(input_layer).__name__}"
            )
        if self.output_class is not None:
            assert isinstance(output_layer, self.output_class), (
                f"{type(self).__name__} expects output_layer of type "
                f"{self.output_class}, got {type(output_layer).__name__}"
            )
        # Overrides the z-index of 1 set by NeuralNetworkLayer.__init__, so
        # connections are ordered below the layers.
        self.set_z_index(-1)

    @abstractmethod
    def make_forward_pass_animation(self):
        """Build the forward-pass animation from input to output layer.

        Subclasses must override this; the base body returns None.
        """