mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-25 00:40:54 +08:00
31 lines
1012 B
Python
31 lines
1012 B
Python
import warnings
|
|
|
|
from manim import *
|
|
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
|
|
from manim_ml.neural_network.layers import connective_layers_list
|
|
|
|
|
|
def get_connective_layer(input_layer, output_layer):
    """Deduce and build the connective layer joining two layers.

    Scans ``connective_layers_list`` in order and selects the first class
    whose ``input_class`` / ``output_class`` attributes match ``input_layer``
    and ``output_layer`` via ``isinstance``. If no candidate matches, it
    falls back to ``BlankConnective`` and emits a warning naming the
    unrecognized layer classes.

    Args:
        input_layer: The layer feeding into the connection.
        output_layer: The layer receiving the connection.

    Returns:
        An instance of the selected connective class, constructed as
        ``connective_layer_class(input_layer, output_layer)``.
    """
    for candidate_class in connective_layers_list:
        input_class = candidate_class.input_class
        output_class = candidate_class.output_class
        if isinstance(input_layer, input_class) and isinstance(
            output_layer, output_class
        ):
            connective_layer_class = candidate_class
            break
    else:
        # No candidate matched (or the candidate list is empty).
        connective_layer_class = BlankConnective
        # Report the classes of the layers actually passed in. Referencing
        # the loop variables here would name the last candidate's classes,
        # or raise NameError when connective_layers_list is empty.
        warnings.warn(
            "Unrecognized input/output class pair: "
            f"{type(input_layer)} and {type(output_layer)}"
        )
    # Make the instance now
    connective_layer = connective_layer_class(input_layer, output_layer)

    return connective_layer
|