Refactored some code in neural network, added max pooling to feed forward

This commit is contained in:
Alec Helbling
2023-01-31 10:30:49 -05:00
parent ae6fd8a230
commit c14972fa4b
5 changed files with 49 additions and 67 deletions

View File

@ -10,6 +10,7 @@ from manim_ml.neural_network.layers.image_to_convolutional_2d import (
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
MaxPooling2DToConvolutional2D,
)
from manim_ml.neural_network.layers.max_pooling_2d_to_feed_forward import MaxPooling2DToFeedForward
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
from .convolutional_2d import Convolutional2DLayer
from .feed_forward_to_vector import FeedForwardToVector
@ -40,9 +41,9 @@ connective_layers_list = (
PairedQueryToFeedForward,
FeedForwardToVector,
Convolutional2DToConvolutional2D,
Convolutional2DToConvolutional2D,
ImageToConvolutional2DLayer,
Convolutional2DToFeedForward,
Convolutional2DToMaxPooling2D,
MaxPooling2DToConvolutional2D,
MaxPooling2DToFeedForward,
)

View File

@ -0,0 +1,27 @@
from manim import *
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import Convolutional2DToFeedForward
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
class MaxPooling2DToFeedForward(Convolutional2DToFeedForward):
"""Feed Forward to Embedding Layer"""
input_class = MaxPooling2DLayer
output_class = FeedForwardLayer
def __init__(
self,
input_layer: MaxPooling2DLayer,
output_layer: FeedForwardLayer,
passing_flash_color=ORANGE,
**kwargs
):
super().__init__(input_layer, output_layer, **kwargs)
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)

View File

@ -4,7 +4,6 @@ from manim import *
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
from manim_ml.neural_network.layers import connective_layers_list
def get_connective_layer(input_layer, output_layer):
"""
Deduces the relevant connective layer
@ -13,16 +12,15 @@ def get_connective_layer(input_layer, output_layer):
for candidate_class in connective_layers_list:
input_class = candidate_class.input_class
output_class = candidate_class.output_class
if isinstance(input_layer, input_class) and isinstance(
output_layer, output_class
):
if isinstance(input_layer, input_class) and \
isinstance(output_layer, output_class):
connective_layer_class = candidate_class
break
if connective_layer_class is None:
connective_layer_class = BlankConnective
warnings.warn(
f"Unrecognized input/output class pair: {input_class} and {output_class}"
f"Unrecognized input/output class pair: {input_layer} and {output_layer}"
)
# Make the instance now
connective_layer = connective_layer_class(input_layer, output_layer)