diff --git a/examples/logo/logo.py b/examples/logo/logo.py index e0c5575..05cf73a 100644 --- a/examples/logo/logo.py +++ b/examples/logo/logo.py @@ -2,17 +2,21 @@ Logo for Manim Machine Learning """ from manim import * -from manim_ml.neural_network.neural_network import FeedForwardNeuralNetwork -config.pixel_height = 500 -config.pixel_width = 500 +import manim_ml +manim_ml.config.color_scheme = "light_mode" + +from manim_ml.neural_network.architectures.feed_forward import FeedForwardNeuralNetwork + +config.pixel_height = 1000 +config.pixel_width = 1000 config.frame_height = 4.0 config.frame_width = 4.0 class ManimMLLogo(Scene): def construct(self): - self.text = Text("ManimML") + self.text = Text("ManimML", color=manim_ml.config.color_scheme.text_color) self.text.scale(1.0) self.neural_network = FeedForwardNeuralNetwork( [3, 5, 3, 6, 3], layer_spacing=0.3, node_color=BLUE @@ -23,21 +27,21 @@ class ManimMLLogo(Scene): self.logo_group = Group(self.text, self.neural_network) self.logo_group.scale(1.0) self.logo_group.move_to(ORIGIN) - self.play(Write(self.text)) - self.play(Create(self.neural_network)) + self.play(Write(self.text), run_time=1.0) + self.play(Create(self.neural_network), run_time=3.0) # self.surrounding_rectangle = SurroundingRectangle(self.logo_group, buff=0.3, color=BLUE) - underline = Underline(self.text, color=BLUE) + underline = Underline(self.text, color=BLACK) + animation_group = AnimationGroup( + self.neural_network.make_forward_pass_animation(run_time=5), + Create(underline), + # Create(self.surrounding_rectangle) + ) + # self.surrounding_rectangle = SurroundingRectangle(self.logo_group, buff=0.3, color=BLUE) + underline = Underline(self.text, color=BLACK) animation_group = AnimationGroup( self.neural_network.make_forward_pass_animation(run_time=5), Create(underline), # Create(self.surrounding_rectangle) ) - # self.surrounding_rectangle = SurroundingRectangle(self.logo_group, buff=0.3, color=BLUE) - underline = Underline(self.text, color=BLUE) - animation_group = AnimationGroup( - self.neural_network.make_forward_pass_animation(run_time=5), - Create(underline), - # Create(self.surrounding_rectangle) - ) - self.play(animation_group) + self.play(animation_group, runtime=5.0) self.wait(5) diff --git a/examples/logo/wide_logo.py b/examples/logo/wide_logo.py new file mode 100644 index 0000000..6ce80d5 --- /dev/null +++ b/examples/logo/wide_logo.py @@ -0,0 +1,48 @@ +""" + Logo for Manim Machine Learning +""" +from manim import * + +import manim_ml +manim_ml.config.color_scheme = "light_mode" + +from manim_ml.neural_network.architectures.feed_forward import FeedForwardNeuralNetwork + +config.pixel_height = 1000 +config.pixel_width = 2000 +config.frame_height = 4.0 +config.frame_width = 8.0 + + +class ManimMLLogo(Scene): + def construct(self): + self.text = Text("ManimML", color=manim_ml.config.color_scheme.text_color) + self.text.scale(1.0) + self.neural_network = FeedForwardNeuralNetwork( + [3, 5, 3, 6, 3], layer_spacing=0.3, node_color=BLUE + ) + self.neural_network.scale(0.8) + self.neural_network.next_to(self.text, RIGHT, buff=0.5) + # self.neural_network.move_to(self.text.get_right()) + # self.neural_network.shift(1.25 * DOWN) + self.logo_group = Group(self.text, self.neural_network) + self.logo_group.scale(1.0) + self.logo_group.move_to(ORIGIN) + self.play(Write(self.text), run_time=1.0) + self.play(Create(self.neural_network), run_time=3.0) + # self.surrounding_rectangle = SurroundingRectangle(self.logo_group, buff=0.3, color=BLUE) + underline = Underline(self.text, color=BLACK) + animation_group = AnimationGroup( + self.neural_network.make_forward_pass_animation(run_time=5), + Create(underline), + # Create(self.surrounding_rectangle) + ) + # self.surrounding_rectangle = SurroundingRectangle(self.logo_group, buff=0.3, color=BLUE) + underline = Underline(self.text, color=BLACK) + animation_group = AnimationGroup( + self.neural_network.make_forward_pass_animation(run_time=5), + Create(underline), + # Create(self.surrounding_rectangle) + ) + self.play(animation_group, runtime=5.0) + self.wait(5) diff --git a/manim_ml/__init__.py b/manim_ml/__init__.py index e69de29..fcfe20e 100644 --- a/manim_ml/__init__.py +++ b/manim_ml/__init__.py @@ -0,0 +1,30 @@ +import manim +from manim_ml.utils.colorschemes.colorschemes import light_mode, dark_mode, ColorScheme + +class ManimMLConfig: + + def __init__(self, default_color_scheme=light_mode): + self._color_scheme = default_color_scheme + + @property + def color_scheme(self): + return self._color_scheme + + @color_scheme.setter + def color_scheme(self, value): + if isinstance(value, str): + if value == "dark_mode": + self._color_scheme = dark_mode + elif value == "light_mode": + self._color_scheme = light_mode + else: + raise ValueError( + "Color scheme must be either 'dark_mode' or 'light_mode'" + ) + elif isinstance(value, ColorScheme): + self._color_scheme = value + + manim.config.background_color = self.color_scheme.background_color + +# These are accesible from the manim_ml namespace +config = ManimMLConfig() \ No newline at end of file diff --git a/manim_ml/neural_network/architectures/feed_forward.py b/manim_ml/neural_network/architectures/feed_forward.py index 3955530..a247720 100644 --- a/manim_ml/neural_network/architectures/feed_forward.py +++ b/manim_ml/neural_network/architectures/feed_forward.py @@ -1,11 +1,19 @@ +import manim_ml +from manim_ml.neural_network.neural_network import NeuralNetwork from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer class FeedForwardNeuralNetwork(NeuralNetwork): """NeuralNetwork with just feed forward layers""" - def __init__(self, layer_node_count, node_radius=0.08, node_color=BLUE, **kwargs): - # construct layers + def __init__( + self, + layer_node_count, + node_radius=0.08, + node_color=manim_ml.config.color_scheme.primary_color, + **kwargs + ): + # construct layer layers = [] for num_nodes in layer_node_count: layer = FeedForwardLayer( diff --git a/manim_ml/neural_network/layers/feed_forward.py b/manim_ml/neural_network/layers/feed_forward.py index be75ed0..1cb8c63 100644 --- a/manim_ml/neural_network/layers/feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward.py @@ -5,7 +5,7 @@ from manim_ml.neural_network.activation_functions.activation_function import ( ActivationFunction, ) from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer - +import manim_ml class FeedForwardLayer(VGroupNeuralNetworkLayer): """Handles rendering a layer for a neural network""" @@ -15,14 +15,14 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer): num_nodes, layer_buffer=SMALL_BUFF / 2, node_radius=0.08, - node_color=BLUE, - node_outline_color=WHITE, - rectangle_color=WHITE, + node_color=manim_ml.config.color_scheme.primary_color, + node_outline_color=manim_ml.config.color_scheme.secondary_color, + rectangle_color=manim_ml.config.color_scheme.secondary_color, node_spacing=0.3, - rectangle_fill_color=BLACK, + rectangle_fill_color=manim_ml.config.color_scheme.background_color, node_stroke_width=2.0, rectangle_stroke_width=2.0, - animation_dot_color=RED, + animation_dot_color=manim_ml.config.color_scheme.active_color, activation_function=None, **kwargs ): diff --git a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py index aef52c3..ee7a20e 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py @@ -4,7 +4,7 @@ import numpy as np from manim import * from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer - +import manim_ml class FeedForwardToFeedForward(ConnectiveLayer): """Layer for connecting FeedForward layer to FeedForwardLayer""" @@ -18,8 +18,8 @@ class FeedForwardToFeedForward(ConnectiveLayer): output_layer, passing_flash=True, dot_radius=0.05, - animation_dot_color=RED, - edge_color=WHITE, + animation_dot_color=manim_ml.config.color_scheme.active_color, + edge_color=manim_ml.config.color_scheme.secondary_color, edge_width=1.5, camera=None, **kwargs diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index 2c86ceb..22a3099 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -23,14 +23,12 @@ from manim_ml.neural_network.animations.neural_network_transformations import ( RemoveLayer, ) - class NeuralNetwork(Group): """Neural Network Visualization Container Class""" def __init__( self, input_layers, - edge_color=WHITE, layer_spacing=0.2, animation_dot_color=RED, edge_width=2.5, @@ -44,7 +42,6 @@ class NeuralNetwork(Group): self.input_layers_dict = self.make_input_layers_dict(input_layers) self.input_layers = ListGroup(*self.input_layers_dict.values()) self.edge_width = edge_width - self.edge_color = edge_color self.layer_spacing = layer_spacing self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius @@ -61,6 +58,8 @@ class NeuralNetwork(Group): self._place_layers(layout=layout, layout_direction=layout_direction) # Make the connective layers self.connective_layers, self.all_layers = self._construct_connective_layers() + # Place the connective layers + self._place_connective_layers() # Make overhead title self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2) self.title.next_to(self, UP, 1.0) @@ -221,6 +220,19 @@ class NeuralNetwork(Group): # Handle layering return connective_layers, all_layers + def _place_connective_layers(self): + """Places the connective layers + """ + # Place each of the connective layers halfway between the adjacent layers + for connective_layer in self.connective_layers: + layer_midpoint = ( + connective_layer.input_layer.get_center() + + connective_layer.output_layer.get_center() + ) / 2 + print(connective_layer.input_layer.get_center()) + print(connective_layer.output_layer.get_center()) + connective_layer.move_to(layer_midpoint) + def insert_layer(self, layer, insert_index): """Inserts a layer at the given index""" neural_network = self @@ -353,7 +365,9 @@ class NeuralNetwork(Group): layer.scale(scale_factor, **kwargs) # Place layers with scaled spacing self.layer_spacing *= scale_factor + # self.connective_layers, self.all_layers = self._construct_connective_layers() self._place_layers(layout=self.layout, layout_direction=self.layout_direction) + self._place_connective_layers() def filter_layers(self, function): """Filters layers of the network given function""" diff --git a/manim_ml/utils/colorschemes/__init__.py b/manim_ml/utils/colorschemes/__init__.py new file mode 100644 index 0000000..06f80da --- /dev/null +++ b/manim_ml/utils/colorschemes/__init__.py @@ -0,0 +1 @@ +from manim_ml.utils.colorschemes.colorschemes import light_mode, dark_mode \ No newline at end of file diff --git a/manim_ml/utils/colorschemes/colorschemes.py b/manim_ml/utils/colorschemes/colorschemes.py new file mode 100644 index 0000000..31d064c --- /dev/null +++ b/manim_ml/utils/colorschemes/colorschemes.py @@ -0,0 +1,26 @@ +from manim import * +from dataclasses import dataclass + +@dataclass +class ColorScheme: + primary_color: str + secondary_color: str + active_color: str + text_color: str + background_color: str + +dark_mode = ColorScheme( + primary_color=BLUE, + secondary_color=WHITE, + active_color=ORANGE, + text_color=WHITE, + background_color=BLACK +) + +light_mode = ColorScheme( + primary_color=BLUE, + secondary_color=BLACK, + active_color=ORANGE, + text_color=BLACK, + background_color=WHITE +) diff --git a/tests/test_color_scheme.py b/tests/test_color_scheme.py new file mode 100644 index 0000000..cc12853 --- /dev/null +++ b/tests/test_color_scheme.py @@ -0,0 +1,22 @@ +from manim import * + +from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer +import manim_ml + +manim_ml.config.color_scheme = "light_mode" + +config.pixel_height = 1200 +config.pixel_width = 1900 +config.frame_height = 6.0 +config.frame_width = 6.0 + +class FeedForwardScene(Scene): + def construct(self): + nn = NeuralNetwork([ + FeedForwardLayer(3), + FeedForwardLayer(5), + FeedForwardLayer(3) + ]) + + self.add(nn) + self.play(nn.make_forward_pass_animation())