mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-28 19:51:06 +08:00
Bug fixes and linting for the activation functions addition.
This commit is contained in:
@ -22,6 +22,7 @@ from manim_ml.neural_network.neural_network_transformations import (
|
||||
RemoveLayer,
|
||||
)
|
||||
|
||||
|
||||
class NeuralNetwork(Group):
|
||||
"""Neural Network Visualization Container Class"""
|
||||
|
||||
@ -53,17 +54,11 @@ class NeuralNetwork(Group):
|
||||
# Construct all of the layers
|
||||
self._construct_input_layers()
|
||||
# Place the layers
|
||||
self._place_layers(
|
||||
layout=layout,
|
||||
layout_direction=layout_direction
|
||||
)
|
||||
self._place_layers(layout=layout, layout_direction=layout_direction)
|
||||
# Make the connective layers
|
||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||
# Make overhead title
|
||||
self.title = Text(
|
||||
self.title_text,
|
||||
font_size=DEFAULT_FONT_SIZE / 2
|
||||
)
|
||||
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
|
||||
self.title.next_to(self, UP, 1.0)
|
||||
self.add(self.title)
|
||||
# Place layers at correct z index
|
||||
@ -76,7 +71,7 @@ class NeuralNetwork(Group):
|
||||
print(repr(self))
|
||||
|
||||
def _construct_input_layers(self):
|
||||
"""Constructs each of the input layers in context
|
||||
"""Constructs each of the input layers in context
|
||||
of their adjacent layers"""
|
||||
prev_layer = None
|
||||
next_layer = None
|
||||
@ -105,64 +100,82 @@ class NeuralNetwork(Group):
|
||||
previous_layer, EmbeddingLayer
|
||||
):
|
||||
if layout_direction == "left_to_right":
|
||||
shift_vector = np.array([
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
0,
|
||||
])
|
||||
shift_vector = np.array(
|
||||
[
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
0,
|
||||
]
|
||||
)
|
||||
elif layout_direction == "top_to_bottom":
|
||||
shift_vector = np.array([
|
||||
0,
|
||||
-(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
])
|
||||
shift_vector = np.array(
|
||||
[
|
||||
0,
|
||||
-(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
else:
|
||||
if layout_direction == "left_to_right":
|
||||
shift_vector = np.array([
|
||||
(
|
||||
shift_vector = np.array(
|
||||
[
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
+ self.layer_spacing,
|
||||
0,
|
||||
0,
|
||||
])
|
||||
+ self.layer_spacing,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
)
|
||||
elif layout_direction == "top_to_bottom":
|
||||
shift_vector = np.array([
|
||||
0,
|
||||
-(
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
+ self.layer_spacing
|
||||
),
|
||||
0,
|
||||
])
|
||||
shift_vector = np.array(
|
||||
[
|
||||
0,
|
||||
-(
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
+ self.layer_spacing
|
||||
),
|
||||
0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
current_layer.shift(shift_vector)
|
||||
|
||||
# After all layers have been placed place their activation functions
|
||||
for current_layer in self.input_layers:
|
||||
# Place activation function
|
||||
if hasattr(current_layer, "activation_function"):
|
||||
if not current_layer.activation_function is None:
|
||||
current_layer.activation_function.next_to(
|
||||
current_layer,
|
||||
direction=UP
|
||||
up_movement = np.array(
|
||||
[
|
||||
0,
|
||||
current_layer.get_height() / 2
|
||||
+ current_layer.activation_function.get_height() / 2
|
||||
+ 0.5 * self.layer_spacing,
|
||||
0,
|
||||
]
|
||||
)
|
||||
current_layer.activation_function.move_to(
|
||||
current_layer,
|
||||
)
|
||||
current_layer.activation_function.shift(up_movement)
|
||||
self.add(current_layer.activation_function)
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
@ -228,8 +241,8 @@ class NeuralNetwork(Group):
|
||||
# Get the layer args
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
"""
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
"""
|
||||
before_layer_args = {}
|
||||
current_layer_args = {}
|
||||
@ -252,16 +265,11 @@ class NeuralNetwork(Group):
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(
|
||||
layer_args=current_layer_args,
|
||||
run_time=per_layer_runtime,
|
||||
**kwargs
|
||||
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = Succession(
|
||||
*all_animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
animation_group = Succession(*all_animations, lag_ratio=1.0)
|
||||
|
||||
return animation_group
|
||||
|
||||
@ -332,6 +340,7 @@ class NeuralNetwork(Group):
|
||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||
return string_repr
|
||||
|
||||
|
||||
class FeedForwardNeuralNetwork(NeuralNetwork):
|
||||
"""NeuralNetwork with just feed forward layers"""
|
||||
|
||||
|
Reference in New Issue
Block a user