Working initial visualization of a CNN.

This commit is contained in:
Alec Helbling
2022-12-29 14:09:16 -05:00
parent 330ba170a0
commit 8cee86e884
18 changed files with 384 additions and 236 deletions

View File

@ -1,25 +1,27 @@
from manim import *
import warnings
from manim_ml.neural_network.layers.parent_layers import BlankConnective
from manim import *
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
from ..layers import connective_layers_list
def get_connective_layer(input_layer, output_layer):
"""
Deduces the relevant connective layer
"""
connective_layer = None
for connective_layer_class in connective_layers_list:
input_class = connective_layer_class.input_class
output_class = connective_layer_class.output_class
connective_layer_class = None
for candidate_class in connective_layers_list:
input_class = candidate_class.input_class
output_class = candidate_class.output_class
if isinstance(input_layer, input_class) \
and isinstance(output_layer, output_class):
connective_layer = connective_layer_class(input_layer, output_layer)
connective_layer_class = candidate_class
break
if connective_layer is None:
connective_layer = BlankConnective(input_layer, output_layer)
"""
raise Exception(f"Unrecognized class pair {input_layer.__class__.__name__}" + \
f" and {output_layer.__class__.__name__}")
"""
if connective_layer_class is None:
connective_layer_class = BlankConnective
warnings.warn(f"Unrecognized input/output class pair: {input_class} and {output_class}")
# Make the instance now
connective_layer = connective_layer_class(input_layer, output_layer)
return connective_layer