mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 03:57:40 +08:00
Refactored some code in neural network, added max pooling to feed forward
This commit is contained in:
@ -10,6 +10,7 @@ from manim_ml.neural_network.layers.image_to_convolutional_2d import (
|
|||||||
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
|
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
|
||||||
MaxPooling2DToConvolutional2D,
|
MaxPooling2DToConvolutional2D,
|
||||||
)
|
)
|
||||||
|
from manim_ml.neural_network.layers.max_pooling_2d_to_feed_forward import MaxPooling2DToFeedForward
|
||||||
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
|
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
|
||||||
from .convolutional_2d import Convolutional2DLayer
|
from .convolutional_2d import Convolutional2DLayer
|
||||||
from .feed_forward_to_vector import FeedForwardToVector
|
from .feed_forward_to_vector import FeedForwardToVector
|
||||||
@ -40,9 +41,9 @@ connective_layers_list = (
|
|||||||
PairedQueryToFeedForward,
|
PairedQueryToFeedForward,
|
||||||
FeedForwardToVector,
|
FeedForwardToVector,
|
||||||
Convolutional2DToConvolutional2D,
|
Convolutional2DToConvolutional2D,
|
||||||
Convolutional2DToConvolutional2D,
|
|
||||||
ImageToConvolutional2DLayer,
|
ImageToConvolutional2DLayer,
|
||||||
Convolutional2DToFeedForward,
|
Convolutional2DToFeedForward,
|
||||||
Convolutional2DToMaxPooling2D,
|
Convolutional2DToMaxPooling2D,
|
||||||
MaxPooling2DToConvolutional2D,
|
MaxPooling2DToConvolutional2D,
|
||||||
|
MaxPooling2DToFeedForward,
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import Convolutional2DToFeedForward
|
||||||
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||||
|
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||||
|
|
||||||
|
class MaxPooling2DToFeedForward(Convolutional2DToFeedForward):
|
||||||
|
"""Feed Forward to Embedding Layer"""
|
||||||
|
|
||||||
|
input_class = MaxPooling2DLayer
|
||||||
|
output_class = FeedForwardLayer
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_layer: MaxPooling2DLayer,
|
||||||
|
output_layer: FeedForwardLayer,
|
||||||
|
passing_flash_color=ORANGE,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(input_layer, output_layer, **kwargs)
|
||||||
|
|
||||||
|
def construct_layer(
|
||||||
|
self,
|
||||||
|
input_layer: "NeuralNetworkLayer",
|
||||||
|
output_layer: "NeuralNetworkLayer",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
return super().construct_layer(input_layer, output_layer, **kwargs)
|
@ -4,7 +4,6 @@ from manim import *
|
|||||||
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
|
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
|
||||||
from manim_ml.neural_network.layers import connective_layers_list
|
from manim_ml.neural_network.layers import connective_layers_list
|
||||||
|
|
||||||
|
|
||||||
def get_connective_layer(input_layer, output_layer):
|
def get_connective_layer(input_layer, output_layer):
|
||||||
"""
|
"""
|
||||||
Deduces the relevant connective layer
|
Deduces the relevant connective layer
|
||||||
@ -13,16 +12,15 @@ def get_connective_layer(input_layer, output_layer):
|
|||||||
for candidate_class in connective_layers_list:
|
for candidate_class in connective_layers_list:
|
||||||
input_class = candidate_class.input_class
|
input_class = candidate_class.input_class
|
||||||
output_class = candidate_class.output_class
|
output_class = candidate_class.output_class
|
||||||
if isinstance(input_layer, input_class) and isinstance(
|
if isinstance(input_layer, input_class) and \
|
||||||
output_layer, output_class
|
isinstance(output_layer, output_class):
|
||||||
):
|
|
||||||
connective_layer_class = candidate_class
|
connective_layer_class = candidate_class
|
||||||
break
|
break
|
||||||
|
|
||||||
if connective_layer_class is None:
|
if connective_layer_class is None:
|
||||||
connective_layer_class = BlankConnective
|
connective_layer_class = BlankConnective
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Unrecognized input/output class pair: {input_class} and {output_class}"
|
f"Unrecognized input/output class pair: {input_layer} and {output_layer}"
|
||||||
)
|
)
|
||||||
# Make the instance now
|
# Make the instance now
|
||||||
connective_layer = connective_layer_class(input_layer, output_layer)
|
connective_layer = connective_layer_class(input_layer, output_layer)
|
||||||
|
@ -10,6 +10,7 @@ Example:
|
|||||||
NeuralNetwork(layer_node_count)
|
NeuralNetwork(layer_node_count)
|
||||||
"""
|
"""
|
||||||
import textwrap
|
import textwrap
|
||||||
|
import numpy as np
|
||||||
from manim import *
|
from manim import *
|
||||||
|
|
||||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||||
@ -95,67 +96,22 @@ class NeuralNetwork(Group):
|
|||||||
previous_layer = self.input_layers[layer_index - 1]
|
previous_layer = self.input_layers[layer_index - 1]
|
||||||
current_layer = self.input_layers[layer_index]
|
current_layer = self.input_layers[layer_index]
|
||||||
current_layer.move_to(previous_layer.get_center())
|
current_layer.move_to(previous_layer.get_center())
|
||||||
# TODO Temp fix
|
if layout_direction == "left_to_right":
|
||||||
if isinstance(current_layer, EmbeddingLayer) or isinstance(
|
x_shift = previous_layer.get_width() / 2 \
|
||||||
previous_layer, EmbeddingLayer
|
+ current_layer.get_width() / 2 \
|
||||||
):
|
+ self.layer_spacing
|
||||||
if layout_direction == "left_to_right":
|
shift_vector = np.array([x_shift, 0, 0])
|
||||||
shift_vector = np.array(
|
elif layout_direction == "top_to_bottom":
|
||||||
[
|
y_shift = -((
|
||||||
(
|
previous_layer.get_width() / 2 \
|
||||||
previous_layer.get_width() / 2
|
+ current_layer.get_width() / 2
|
||||||
+ current_layer.get_width() / 2
|
) + self.layer_spacing)
|
||||||
- 0.2
|
|
||||||
),
|
shift_vector = np.array([0, y_shift, 0])
|
||||||
0,
|
|
||||||
0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif layout_direction == "top_to_bottom":
|
|
||||||
shift_vector = np.array(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
-(
|
|
||||||
previous_layer.get_width() / 2
|
|
||||||
+ current_layer.get_width() / 2
|
|
||||||
- 0.2
|
|
||||||
),
|
|
||||||
0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"Unrecognized layout direction: {layout_direction}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if layout_direction == "left_to_right":
|
raise Exception(
|
||||||
shift_vector = np.array(
|
f"Unrecognized layout direction: {layout_direction}"
|
||||||
[
|
)
|
||||||
previous_layer.get_width() / 2
|
|
||||||
+ current_layer.get_width() / 2
|
|
||||||
+ self.layer_spacing,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif layout_direction == "top_to_bottom":
|
|
||||||
shift_vector = np.array(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
-(
|
|
||||||
(
|
|
||||||
previous_layer.get_width() / 2
|
|
||||||
+ current_layer.get_width() / 2
|
|
||||||
)
|
|
||||||
+ self.layer_spacing
|
|
||||||
),
|
|
||||||
0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"Unrecognized layout direction: {layout_direction}"
|
|
||||||
)
|
|
||||||
current_layer.shift(shift_vector)
|
current_layer.shift(shift_vector)
|
||||||
|
|
||||||
# After all layers have been placed place their activation functions
|
# After all layers have been placed place their activation functions
|
||||||
|
2
setup.py
2
setup.py
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="manim_ml",
|
name="manim_ml",
|
||||||
version="0.0.12",
|
version="0.0.14",
|
||||||
description=(" Machine Learning Animations in python using Manim."),
|
description=(" Machine Learning Animations in python using Manim."),
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user