mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-26 04:33:17 +08:00
Refactored some code in neural network, added max pooling to feed forward
This commit is contained in:
@ -10,6 +10,7 @@ from manim_ml.neural_network.layers.image_to_convolutional_2d import (
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
|
||||
MaxPooling2DToConvolutional2D,
|
||||
)
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_feed_forward import MaxPooling2DToFeedForward
|
||||
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
|
||||
from .convolutional_2d import Convolutional2DLayer
|
||||
from .feed_forward_to_vector import FeedForwardToVector
|
||||
@ -40,9 +41,9 @@ connective_layers_list = (
|
||||
PairedQueryToFeedForward,
|
||||
FeedForwardToVector,
|
||||
Convolutional2DToConvolutional2D,
|
||||
Convolutional2DToConvolutional2D,
|
||||
ImageToConvolutional2DLayer,
|
||||
Convolutional2DToFeedForward,
|
||||
Convolutional2DToMaxPooling2D,
|
||||
MaxPooling2DToConvolutional2D,
|
||||
MaxPooling2DToFeedForward,
|
||||
)
|
||||
|
@ -0,0 +1,27 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import Convolutional2DToFeedForward
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
|
||||
class MaxPooling2DToFeedForward(Convolutional2DToFeedForward):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
input_class = MaxPooling2DLayer
|
||||
output_class = FeedForwardLayer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_layer: MaxPooling2DLayer,
|
||||
output_layer: FeedForwardLayer,
|
||||
passing_flash_color=ORANGE,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
@ -4,7 +4,6 @@ from manim import *
|
||||
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
|
||||
from manim_ml.neural_network.layers import connective_layers_list
|
||||
|
||||
|
||||
def get_connective_layer(input_layer, output_layer):
|
||||
"""
|
||||
Deduces the relevant connective layer
|
||||
@ -13,16 +12,15 @@ def get_connective_layer(input_layer, output_layer):
|
||||
for candidate_class in connective_layers_list:
|
||||
input_class = candidate_class.input_class
|
||||
output_class = candidate_class.output_class
|
||||
if isinstance(input_layer, input_class) and isinstance(
|
||||
output_layer, output_class
|
||||
):
|
||||
if isinstance(input_layer, input_class) and \
|
||||
isinstance(output_layer, output_class):
|
||||
connective_layer_class = candidate_class
|
||||
break
|
||||
|
||||
if connective_layer_class is None:
|
||||
connective_layer_class = BlankConnective
|
||||
warnings.warn(
|
||||
f"Unrecognized input/output class pair: {input_class} and {output_class}"
|
||||
f"Unrecognized input/output class pair: {input_layer} and {output_layer}"
|
||||
)
|
||||
# Make the instance now
|
||||
connective_layer = connective_layer_class(input_layer, output_layer)
|
||||
|
Reference in New Issue
Block a user