mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-18 12:07:46 +08:00
Convolutional Layers
This commit is contained in:
@ -9,9 +9,7 @@ Example:
|
||||
# Create the object with default style settings
|
||||
NeuralNetwork(layer_node_count)
|
||||
"""
|
||||
from cv2 import AGAST_FEATURE_DETECTOR_NONMAX_SUPPRESSION
|
||||
from manim import *
|
||||
import warnings
|
||||
import textwrap
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
|
||||
@ -104,6 +102,7 @@ class NeuralNetwork(Group):
|
||||
|
||||
def replace_layer(self, old_layer, new_layer):
|
||||
"""Replaces given layer object"""
|
||||
raise NotImplementedError()
|
||||
remove_animation = self.remove_layer(insert_index)
|
||||
insert_animation = self.insert_layer(layer, insert_index)
|
||||
# Make the animation
|
||||
@ -119,6 +118,7 @@ class NeuralNetwork(Group):
|
||||
**kwargs):
|
||||
"""Generates an animation for feed forward propagation"""
|
||||
all_animations = []
|
||||
per_layer_runtime = run_time/len(self.all_layers)
|
||||
for layer_index, layer in enumerate(self.all_layers):
|
||||
# Get the layer args
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
@ -139,10 +139,18 @@ class NeuralNetwork(Group):
|
||||
if layer in layer_args:
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(layer_args=current_layer_args, **kwargs)
|
||||
layer_forward_pass = layer.make_forward_pass_animation(
|
||||
layer_args=current_layer_args,
|
||||
run_time=per_layer_runtime,
|
||||
**kwargs
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0)
|
||||
animation_group = Succession(
|
||||
*all_animations,
|
||||
run_time=run_time,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
@ -176,7 +184,15 @@ class NeuralNetwork(Group):
|
||||
def set_z_index(self, z_index_value: float, family=False):
|
||||
"""Overriden set_z_index"""
|
||||
# Setting family=False stops sub-neural networks from inheriting parent z_index
|
||||
return super().set_z_index(z_index_value, family=False)
|
||||
for layer in self.all_layers:
|
||||
if not isinstance(NeuralNetwork):
|
||||
layer.set_z_index(z_index_value)
|
||||
|
||||
def scale(self, scale_factor, **kwargs):
|
||||
"""Overriden scale"""
|
||||
for layer in self.all_layers:
|
||||
layer.scale(scale_factor, **kwargs)
|
||||
# super().scale(scale_factor)
|
||||
|
||||
def __repr__(self, metadata=["z_index", "title_text"]):
|
||||
"""Print string representation of layers"""
|
||||
|
Reference in New Issue
Block a user