Files
2022-12-29 14:09:16 -05:00

28 lines
964 B
Python

import warnings
from manim import *
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
from ..layers import connective_layers_list
def get_connective_layer(input_layer, output_layer):
    """Build the connective layer that joins ``input_layer`` to ``output_layer``.

    Candidates in ``connective_layers_list`` are tried in order; the first
    one whose ``input_class`` / ``output_class`` both match the given layers
    (via ``isinstance``) is selected. If none matches, a warning is emitted
    and ``BlankConnective`` is used instead.

    Args:
        input_layer: Layer on the input side of the connection.
        output_layer: Layer on the output side of the connection.

    Returns:
        An instance of the selected connective class, constructed as
        ``connective_class(input_layer, output_layer)``.
    """
    connective_layer_class = None
    for candidate_class in connective_layers_list:
        if isinstance(input_layer, candidate_class.input_class) and isinstance(
            output_layer, candidate_class.output_class
        ):
            connective_layer_class = candidate_class
            break

    if connective_layer_class is None:
        connective_layer_class = BlankConnective
        # Report the classes of the layers actually passed in. The loop's
        # per-candidate classes would describe the last candidate tried
        # (and would be unbound if the candidate list were empty).
        warnings.warn(
            "Unrecognized input/output class pair: "
            f"{type(input_layer)} and {type(output_layer)}"
        )

    # Make the instance now
    return connective_layer_class(input_layer, output_layer)