Refactored some code in neural network, added max pooling to feed forward

This commit is contained in:
Alec Helbling
2023-01-31 10:30:49 -05:00
parent ae6fd8a230
commit c14972fa4b
5 changed files with 49 additions and 67 deletions

View File

@ -10,6 +10,7 @@ from manim_ml.neural_network.layers.image_to_convolutional_2d import (
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import ( from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
MaxPooling2DToConvolutional2D, MaxPooling2DToConvolutional2D,
) )
from manim_ml.neural_network.layers.max_pooling_2d_to_feed_forward import MaxPooling2DToFeedForward
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
from .convolutional_2d import Convolutional2DLayer from .convolutional_2d import Convolutional2DLayer
from .feed_forward_to_vector import FeedForwardToVector from .feed_forward_to_vector import FeedForwardToVector
@ -40,9 +41,9 @@ connective_layers_list = (
PairedQueryToFeedForward, PairedQueryToFeedForward,
FeedForwardToVector, FeedForwardToVector,
Convolutional2DToConvolutional2D, Convolutional2DToConvolutional2D,
Convolutional2DToConvolutional2D,
ImageToConvolutional2DLayer, ImageToConvolutional2DLayer,
Convolutional2DToFeedForward, Convolutional2DToFeedForward,
Convolutional2DToMaxPooling2D, Convolutional2DToMaxPooling2D,
MaxPooling2DToConvolutional2D, MaxPooling2DToConvolutional2D,
MaxPooling2DToFeedForward,
) )

View File

@ -0,0 +1,27 @@
from manim import *
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import Convolutional2DToFeedForward
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
class MaxPooling2DToFeedForward(Convolutional2DToFeedForward):
"""Feed Forward to Embedding Layer"""
input_class = MaxPooling2DLayer
output_class = FeedForwardLayer
def __init__(
self,
input_layer: MaxPooling2DLayer,
output_layer: FeedForwardLayer,
passing_flash_color=ORANGE,
**kwargs
):
super().__init__(input_layer, output_layer, **kwargs)
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)

View File

@ -4,7 +4,6 @@ from manim import *
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
from manim_ml.neural_network.layers import connective_layers_list from manim_ml.neural_network.layers import connective_layers_list
def get_connective_layer(input_layer, output_layer): def get_connective_layer(input_layer, output_layer):
""" """
Deduces the relevant connective layer Deduces the relevant connective layer
@ -13,16 +12,15 @@ def get_connective_layer(input_layer, output_layer):
for candidate_class in connective_layers_list: for candidate_class in connective_layers_list:
input_class = candidate_class.input_class input_class = candidate_class.input_class
output_class = candidate_class.output_class output_class = candidate_class.output_class
if isinstance(input_layer, input_class) and isinstance( if isinstance(input_layer, input_class) and \
output_layer, output_class isinstance(output_layer, output_class):
):
connective_layer_class = candidate_class connective_layer_class = candidate_class
break break
if connective_layer_class is None: if connective_layer_class is None:
connective_layer_class = BlankConnective connective_layer_class = BlankConnective
warnings.warn( warnings.warn(
f"Unrecognized input/output class pair: {input_class} and {output_class}" f"Unrecognized input/output class pair: {input_layer} and {output_layer}"
) )
# Make the instance now # Make the instance now
connective_layer = connective_layer_class(input_layer, output_layer) connective_layer = connective_layer_class(input_layer, output_layer)

View File

@ -10,6 +10,7 @@ Example:
NeuralNetwork(layer_node_count) NeuralNetwork(layer_node_count)
""" """
import textwrap import textwrap
import numpy as np
from manim import * from manim import *
from manim_ml.neural_network.layers.embedding import EmbeddingLayer from manim_ml.neural_network.layers.embedding import EmbeddingLayer
@ -95,67 +96,22 @@ class NeuralNetwork(Group):
previous_layer = self.input_layers[layer_index - 1] previous_layer = self.input_layers[layer_index - 1]
current_layer = self.input_layers[layer_index] current_layer = self.input_layers[layer_index]
current_layer.move_to(previous_layer.get_center()) current_layer.move_to(previous_layer.get_center())
# TODO Temp fix if layout_direction == "left_to_right":
if isinstance(current_layer, EmbeddingLayer) or isinstance( x_shift = previous_layer.get_width() / 2 \
previous_layer, EmbeddingLayer + current_layer.get_width() / 2 \
): + self.layer_spacing
if layout_direction == "left_to_right": shift_vector = np.array([x_shift, 0, 0])
shift_vector = np.array( elif layout_direction == "top_to_bottom":
[ y_shift = -((
( previous_layer.get_width() / 2 \
previous_layer.get_width() / 2 + current_layer.get_width() / 2
+ current_layer.get_width() / 2 ) + self.layer_spacing)
- 0.2
), shift_vector = np.array([0, y_shift, 0])
0,
0,
]
)
elif layout_direction == "top_to_bottom":
shift_vector = np.array(
[
0,
-(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
else: else:
if layout_direction == "left_to_right": raise Exception(
shift_vector = np.array( f"Unrecognized layout direction: {layout_direction}"
[ )
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
+ self.layer_spacing,
0,
0,
]
)
elif layout_direction == "top_to_bottom":
shift_vector = np.array(
[
0,
-(
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
current_layer.shift(shift_vector) current_layer.shift(shift_vector)
# After all layers have been placed place their activation functions # After all layers have been placed place their activation functions

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="manim_ml", name="manim_ml",
version="0.0.12", version="0.0.14",
description=(" Machine Learning Animations in python using Manim."), description=(" Machine Learning Animations in python using Manim."),
packages=find_packages(), packages=find_packages(),
) )