mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-28 19:51:06 +08:00
Finished oracle guidance video. Integrated various changes necessary to complete this.
This commit is contained in:
@ -13,6 +13,7 @@ from cv2 import AGAST_FEATURE_DETECTOR_NONMAX_SUPPRESSION
|
||||
from manim import *
|
||||
import warnings
|
||||
import textwrap
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
@ -59,7 +60,11 @@ class NeuralNetwork(Group):
|
||||
previous_layer = self.input_layers[layer_index - 1]
|
||||
current_layer = self.input_layers[layer_index]
|
||||
current_layer.move_to(previous_layer)
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + self.layer_spacing, 0, 0])
|
||||
# TODO Temp fix
|
||||
if isinstance(current_layer, EmbeddingLayer) or isinstance(previous_layer, EmbeddingLayer):
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2 - 0.2), 0, 0])
|
||||
else:
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + self.layer_spacing, 0, 0])
|
||||
current_layer.shift(shift_vector)
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
|
Reference in New Issue
Block a user