mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-17 18:55:54 +08:00
Refactored some code in neural network, added max pooling to feed forward
This commit is contained in:
@ -10,6 +10,7 @@ from manim_ml.neural_network.layers.image_to_convolutional_2d import (
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
|
||||
MaxPooling2DToConvolutional2D,
|
||||
)
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_feed_forward import MaxPooling2DToFeedForward
|
||||
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
|
||||
from .convolutional_2d import Convolutional2DLayer
|
||||
from .feed_forward_to_vector import FeedForwardToVector
|
||||
@ -40,9 +41,9 @@ connective_layers_list = (
|
||||
PairedQueryToFeedForward,
|
||||
FeedForwardToVector,
|
||||
Convolutional2DToConvolutional2D,
|
||||
Convolutional2DToConvolutional2D,
|
||||
ImageToConvolutional2DLayer,
|
||||
Convolutional2DToFeedForward,
|
||||
Convolutional2DToMaxPooling2D,
|
||||
MaxPooling2DToConvolutional2D,
|
||||
MaxPooling2DToFeedForward,
|
||||
)
|
||||
|
@ -0,0 +1,27 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import Convolutional2DToFeedForward
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
|
||||
class MaxPooling2DToFeedForward(Convolutional2DToFeedForward):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
input_class = MaxPooling2DLayer
|
||||
output_class = FeedForwardLayer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_layer: MaxPooling2DLayer,
|
||||
output_layer: FeedForwardLayer,
|
||||
passing_flash_color=ORANGE,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
@ -4,7 +4,6 @@ from manim import *
|
||||
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
|
||||
from manim_ml.neural_network.layers import connective_layers_list
|
||||
|
||||
|
||||
def get_connective_layer(input_layer, output_layer):
|
||||
"""
|
||||
Deduces the relevant connective layer
|
||||
@ -13,16 +12,15 @@ def get_connective_layer(input_layer, output_layer):
|
||||
for candidate_class in connective_layers_list:
|
||||
input_class = candidate_class.input_class
|
||||
output_class = candidate_class.output_class
|
||||
if isinstance(input_layer, input_class) and isinstance(
|
||||
output_layer, output_class
|
||||
):
|
||||
if isinstance(input_layer, input_class) and \
|
||||
isinstance(output_layer, output_class):
|
||||
connective_layer_class = candidate_class
|
||||
break
|
||||
|
||||
if connective_layer_class is None:
|
||||
connective_layer_class = BlankConnective
|
||||
warnings.warn(
|
||||
f"Unrecognized input/output class pair: {input_class} and {output_class}"
|
||||
f"Unrecognized input/output class pair: {input_layer} and {output_layer}"
|
||||
)
|
||||
# Make the instance now
|
||||
connective_layer = connective_layer_class(input_layer, output_layer)
|
||||
|
@ -10,6 +10,7 @@ Example:
|
||||
NeuralNetwork(layer_node_count)
|
||||
"""
|
||||
import textwrap
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
@ -95,63 +96,18 @@ class NeuralNetwork(Group):
|
||||
previous_layer = self.input_layers[layer_index - 1]
|
||||
current_layer = self.input_layers[layer_index]
|
||||
current_layer.move_to(previous_layer.get_center())
|
||||
# TODO Temp fix
|
||||
if isinstance(current_layer, EmbeddingLayer) or isinstance(
|
||||
previous_layer, EmbeddingLayer
|
||||
):
|
||||
if layout_direction == "left_to_right":
|
||||
shift_vector = np.array(
|
||||
[
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
0,
|
||||
]
|
||||
)
|
||||
elif layout_direction == "top_to_bottom":
|
||||
shift_vector = np.array(
|
||||
[
|
||||
0,
|
||||
-(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
else:
|
||||
if layout_direction == "left_to_right":
|
||||
shift_vector = np.array(
|
||||
[
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
+ self.layer_spacing,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
)
|
||||
elif layout_direction == "top_to_bottom":
|
||||
shift_vector = np.array(
|
||||
[
|
||||
0,
|
||||
-(
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
x_shift = previous_layer.get_width() / 2 \
|
||||
+ current_layer.get_width() / 2 \
|
||||
+ self.layer_spacing
|
||||
),
|
||||
0,
|
||||
]
|
||||
)
|
||||
shift_vector = np.array([x_shift, 0, 0])
|
||||
elif layout_direction == "top_to_bottom":
|
||||
y_shift = -((
|
||||
previous_layer.get_width() / 2 \
|
||||
+ current_layer.get_width() / 2
|
||||
) + self.layer_spacing)
|
||||
|
||||
shift_vector = np.array([0, y_shift, 0])
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
|
Reference in New Issue
Block a user