Removed Conv2D because it can be done using just Conv3D and renamed Conv3D to Conv2D to correspond to the spatial conv dimenson not the scene dimension, which is more inline with convention.

This commit is contained in:
Alec Helbling
2023-01-15 14:35:26 +09:00
parent ba63116b37
commit 42b6e37b16
23 changed files with 358 additions and 467 deletions

View File

@ -51,11 +51,20 @@ class NeuralNetwork(Group):
self.layout_direction = layout_direction
# TODO take layer_node_count [0, (1, 2), 0]
# and make it have explicit distinct subspaces
# Construct all of the layers
self._construct_input_layers()
# Place the layers
self._place_layers(layout=layout, layout_direction=layout_direction)
self._place_layers(
layout=layout,
layout_direction=layout_direction
)
# Make the connective layers
self.connective_layers, self.all_layers = self._construct_connective_layers()
# Make overhead title
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
self.title = Text(
self.title_text,
font_size=DEFAULT_FONT_SIZE / 2
)
self.title.next_to(self, UP, 1.0)
self.add(self.title)
# Place layers at correct z index
@ -67,6 +76,21 @@ class NeuralNetwork(Group):
# Print neural network
print(repr(self))
def _construct_input_layers(self):
"""Constructs each of the input layers in context
of their adjacent layers"""
prev_layer = None
next_layer = None
# Go through all the input layers and run their construct method
for layer_index in range(len(self.input_layers)):
current_layer = self.input_layers[layer_index]
if layer_index < len(self.input_layers) - 1:
next_layer = self.input_layers[layer_index + 1]
if layer_index > 0:
prev_layer = self.input_layers[layer_index - 1]
# Run the construct layer method for each
current_layer.construct_layer(prev_layer, next_layer)
def _place_layers(self, layout="linear", layout_direction="top_to_bottom"):
"""Creates the neural network"""
# TODO implement more sophisticated custom layouts
@ -300,7 +324,6 @@ class NeuralNetwork(Group):
string_repr = "NeuralNetwork([\n" + inner_string + "])"
return string_repr
class FeedForwardNeuralNetwork(NeuralNetwork):
"""NeuralNetwork with just feed forward layers"""