Bug fixes and linting for the activation functions addition.

This commit is contained in:
Alec Helbling
2023-01-25 08:40:32 -05:00
parent ce184af78e
commit f56620f047
42 changed files with 1275 additions and 387 deletions

View File

@ -22,6 +22,7 @@ from manim_ml.neural_network.neural_network_transformations import (
RemoveLayer,
)
class NeuralNetwork(Group):
"""Neural Network Visualization Container Class"""
@ -53,17 +54,11 @@ class NeuralNetwork(Group):
# Construct all of the layers
self._construct_input_layers()
# Place the layers
self._place_layers(
layout=layout,
layout_direction=layout_direction
)
self._place_layers(layout=layout, layout_direction=layout_direction)
# Make the connective layers
self.connective_layers, self.all_layers = self._construct_connective_layers()
# Make overhead title
self.title = Text(
self.title_text,
font_size=DEFAULT_FONT_SIZE / 2
)
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
self.title.next_to(self, UP, 1.0)
self.add(self.title)
# Place layers at correct z index
@ -76,7 +71,7 @@ class NeuralNetwork(Group):
print(repr(self))
def _construct_input_layers(self):
"""Constructs each of the input layers in context
"""Constructs each of the input layers in context
of their adjacent layers"""
prev_layer = None
next_layer = None
@ -105,64 +100,82 @@ class NeuralNetwork(Group):
previous_layer, EmbeddingLayer
):
if layout_direction == "left_to_right":
shift_vector = np.array([
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
0,
])
shift_vector = np.array(
[
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
0,
]
)
elif layout_direction == "top_to_bottom":
shift_vector = np.array([
0,
-(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
])
shift_vector = np.array(
[
0,
-(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
else:
if layout_direction == "left_to_right":
shift_vector = np.array([
(
shift_vector = np.array(
[
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing,
0,
0,
])
+ self.layer_spacing,
0,
0,
]
)
elif layout_direction == "top_to_bottom":
shift_vector = np.array([
0,
-(
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing
),
0,
])
shift_vector = np.array(
[
0,
-(
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
current_layer.shift(shift_vector)
# After all layers have been placed place their activation functions
for current_layer in self.input_layers:
# Place activation function
if hasattr(current_layer, "activation_function"):
if not current_layer.activation_function is None:
current_layer.activation_function.next_to(
current_layer,
direction=UP
up_movement = np.array(
[
0,
current_layer.get_height() / 2
+ current_layer.activation_function.get_height() / 2
+ 0.5 * self.layer_spacing,
0,
]
)
current_layer.activation_function.move_to(
current_layer,
)
current_layer.activation_function.shift(up_movement)
self.add(current_layer.activation_function)
def _construct_connective_layers(self):
@ -228,8 +241,8 @@ class NeuralNetwork(Group):
# Get the layer args
if isinstance(layer, ConnectiveLayer):
"""
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting and itself.
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting and itself.
"""
before_layer_args = {}
current_layer_args = {}
@ -252,16 +265,11 @@ class NeuralNetwork(Group):
current_layer_args = layer_args[layer]
# Perform the forward pass of the current layer
layer_forward_pass = layer.make_forward_pass_animation(
layer_args=current_layer_args,
run_time=per_layer_runtime,
**kwargs
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
)
all_animations.append(layer_forward_pass)
# Make the animation group
animation_group = Succession(
*all_animations,
lag_ratio=1.0
)
animation_group = Succession(*all_animations, lag_ratio=1.0)
return animation_group
@ -332,6 +340,7 @@ class NeuralNetwork(Group):
string_repr = "NeuralNetwork([\n" + inner_string + "])"
return string_repr
class FeedForwardNeuralNetwork(NeuralNetwork):
"""NeuralNetwork with just feed forward layers"""