General changes, got basic visualization of an activation function working for a

convolutinoal layer.
This commit is contained in:
Alec Helbling
2023-01-24 15:44:48 -05:00
parent 5291d9db8a
commit ce184af78e
34 changed files with 1575 additions and 479 deletions

View File

@ -99,7 +99,7 @@ class NeuralNetwork(Group):
for layer_index in range(1, len(self.input_layers)):
previous_layer = self.input_layers[layer_index - 1]
current_layer = self.input_layers[layer_index]
current_layer.move_to(previous_layer)
current_layer.move_to(previous_layer.get_center())
# TODO Temp fix
if isinstance(current_layer, EmbeddingLayer) or isinstance(
previous_layer, EmbeddingLayer
@ -156,6 +156,14 @@ class NeuralNetwork(Group):
f"Unrecognized layout direction: {layout_direction}"
)
current_layer.shift(shift_vector)
# Place activation function
if hasattr(current_layer, "activation_function"):
if not current_layer.activation_function is None:
current_layer.activation_function.next_to(
current_layer,
direction=UP
)
self.add(current_layer.activation_function)
def _construct_connective_layers(self):
"""Draws connecting lines between layers"""
@ -220,8 +228,8 @@ class NeuralNetwork(Group):
# Get the layer args
if isinstance(layer, ConnectiveLayer):
"""
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting and itself.
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting and itself.
"""
before_layer_args = {}
current_layer_args = {}
@ -244,11 +252,16 @@ class NeuralNetwork(Group):
current_layer_args = layer_args[layer]
# Perform the forward pass of the current layer
layer_forward_pass = layer.make_forward_pass_animation(
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
layer_args=current_layer_args,
run_time=per_layer_runtime,
**kwargs
)
all_animations.append(layer_forward_pass)
# Make the animation group
animation_group = Succession(*all_animations, lag_ratio=1.0)
animation_group = Succession(
*all_animations,
lag_ratio=1.0
)
return animation_group