mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-05 19:15:25 +08:00
Removed Conv2D because it can be done using just Conv3D and renamed Conv3D to Conv2D to correspond to the spatial conv dimenson not the scene dimension, which is more inline with convention.
This commit is contained in:
@ -51,11 +51,20 @@ class NeuralNetwork(Group):
|
||||
self.layout_direction = layout_direction
|
||||
# TODO take layer_node_count [0, (1, 2), 0]
|
||||
# and make it have explicit distinct subspaces
|
||||
# Construct all of the layers
|
||||
self._construct_input_layers()
|
||||
# Place the layers
|
||||
self._place_layers(layout=layout, layout_direction=layout_direction)
|
||||
self._place_layers(
|
||||
layout=layout,
|
||||
layout_direction=layout_direction
|
||||
)
|
||||
# Make the connective layers
|
||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||
# Make overhead title
|
||||
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
|
||||
self.title = Text(
|
||||
self.title_text,
|
||||
font_size=DEFAULT_FONT_SIZE / 2
|
||||
)
|
||||
self.title.next_to(self, UP, 1.0)
|
||||
self.add(self.title)
|
||||
# Place layers at correct z index
|
||||
@ -67,6 +76,21 @@ class NeuralNetwork(Group):
|
||||
# Print neural network
|
||||
print(repr(self))
|
||||
|
||||
def _construct_input_layers(self):
|
||||
"""Constructs each of the input layers in context
|
||||
of their adjacent layers"""
|
||||
prev_layer = None
|
||||
next_layer = None
|
||||
# Go through all the input layers and run their construct method
|
||||
for layer_index in range(len(self.input_layers)):
|
||||
current_layer = self.input_layers[layer_index]
|
||||
if layer_index < len(self.input_layers) - 1:
|
||||
next_layer = self.input_layers[layer_index + 1]
|
||||
if layer_index > 0:
|
||||
prev_layer = self.input_layers[layer_index - 1]
|
||||
# Run the construct layer method for each
|
||||
current_layer.construct_layer(prev_layer, next_layer)
|
||||
|
||||
def _place_layers(self, layout="linear", layout_direction="top_to_bottom"):
|
||||
"""Creates the neural network"""
|
||||
# TODO implement more sophisticated custom layouts
|
||||
@ -300,7 +324,6 @@ class NeuralNetwork(Group):
|
||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||
return string_repr
|
||||
|
||||
|
||||
class FeedForwardNeuralNetwork(NeuralNetwork):
|
||||
"""NeuralNetwork with just feed forward layers"""
|
||||
|
||||
|
Reference in New Issue
Block a user