mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-25 17:24:59 +08:00
General changes, got basic visualization of an activation function working for a
convolutinoal layer.
This commit is contained in:
@ -99,7 +99,7 @@ class NeuralNetwork(Group):
|
||||
for layer_index in range(1, len(self.input_layers)):
|
||||
previous_layer = self.input_layers[layer_index - 1]
|
||||
current_layer = self.input_layers[layer_index]
|
||||
current_layer.move_to(previous_layer)
|
||||
current_layer.move_to(previous_layer.get_center())
|
||||
# TODO Temp fix
|
||||
if isinstance(current_layer, EmbeddingLayer) or isinstance(
|
||||
previous_layer, EmbeddingLayer
|
||||
@ -156,6 +156,14 @@ class NeuralNetwork(Group):
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
current_layer.shift(shift_vector)
|
||||
# Place activation function
|
||||
if hasattr(current_layer, "activation_function"):
|
||||
if not current_layer.activation_function is None:
|
||||
current_layer.activation_function.next_to(
|
||||
current_layer,
|
||||
direction=UP
|
||||
)
|
||||
self.add(current_layer.activation_function)
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
"""Draws connecting lines between layers"""
|
||||
@ -220,8 +228,8 @@ class NeuralNetwork(Group):
|
||||
# Get the layer args
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
"""
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
"""
|
||||
before_layer_args = {}
|
||||
current_layer_args = {}
|
||||
@ -244,11 +252,16 @@ class NeuralNetwork(Group):
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(
|
||||
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
|
||||
layer_args=current_layer_args,
|
||||
run_time=per_layer_runtime,
|
||||
**kwargs
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = Succession(*all_animations, lag_ratio=1.0)
|
||||
animation_group = Succession(
|
||||
*all_animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
|
Reference in New Issue
Block a user