Convolutional Layers

This commit is contained in:
Alec Helbling
2022-05-15 13:42:21 -04:00
parent 2ef4dcab44
commit 58aec269cf
11 changed files with 376 additions and 29 deletions

@ -9,9 +9,7 @@ Example:
# Create the object with default style settings
NeuralNetwork(layer_node_count)
"""
from cv2 import AGAST_FEATURE_DETECTOR_NONMAX_SUPPRESSION
from manim import *
import warnings
import textwrap
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
@ -104,6 +102,7 @@ class NeuralNetwork(Group):
def replace_layer(self, old_layer, new_layer):
"""Replaces given layer object"""
raise NotImplementedError()
remove_animation = self.remove_layer(insert_index)
insert_animation = self.insert_layer(layer, insert_index)
# Make the animation
@ -119,6 +118,7 @@ class NeuralNetwork(Group):
**kwargs):
"""Generates an animation for feed forward propagation"""
all_animations = []
per_layer_runtime = run_time/len(self.all_layers)
for layer_index, layer in enumerate(self.all_layers):
# Get the layer args
if isinstance(layer, ConnectiveLayer):
@ -139,10 +139,18 @@ class NeuralNetwork(Group):
if layer in layer_args:
current_layer_args = layer_args[layer]
# Perform the forward pass of the current layer
layer_forward_pass = layer.make_forward_pass_animation(layer_args=current_layer_args, **kwargs)
layer_forward_pass = layer.make_forward_pass_animation(
layer_args=current_layer_args,
run_time=per_layer_runtime,
**kwargs
)
all_animations.append(layer_forward_pass)
# Make the animation group
animation_group = AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0)
animation_group = Succession(
*all_animations,
run_time=run_time,
lag_ratio=1.0
)
return animation_group
@ -176,7 +184,15 @@ class NeuralNetwork(Group):
def set_z_index(self, z_index_value: float, family=False):
"""Overriden set_z_index"""
# Setting family=False stops sub-neural networks from inheriting parent z_index
return super().set_z_index(z_index_value, family=False)
for layer in self.all_layers:
if not isinstance(NeuralNetwork):
layer.set_z_index(z_index_value)
def scale(self, scale_factor, **kwargs):
"""Overriden scale"""
for layer in self.all_layers:
layer.scale(scale_factor, **kwargs)
# super().scale(scale_factor)
def __repr__(self, metadata=["z_index", "title_text"]):
"""Print string representation of layers"""