From ca7978929a4cfca354ab797eeb9f077f5f9bd3b2 Mon Sep 17 00:00:00 2001 From: Alec Helbling Date: Sun, 9 Apr 2023 20:15:07 -0400 Subject: [PATCH] Changed default functionality of activation functions so they are all lined up with eachother. --- manim_ml/__init__.py | 10 + .../activation_function.py | 9 +- .../activation_functions/relu.py | 1 - .../architectures/feed_forward.py | 1 - .../neural_network/layers/convolutional_2d.py | 5 +- .../convolutional_2d_to_convolutional_2d.py | 13 +- .../convolutional_2d_to_max_pooling_2d.py | 10 +- .../layers/image_to_convolutional_2d.py | 6 +- .../neural_network/layers/max_pooling_2d.py | 6 +- .../neural_network/layers/parent_layers.py | 6 +- manim_ml/neural_network/neural_network.py | 29 +- manim_ml/utils/mobjects/list_group.py | 3 + manim_ml/utils/testing/doc_directive.py | 337 ++++++++++++++++++ 13 files changed, 397 insertions(+), 39 deletions(-) create mode 100644 manim_ml/utils/testing/doc_directive.py diff --git a/manim_ml/__init__.py b/manim_ml/__init__.py index 4336383..136af09 100644 --- a/manim_ml/__init__.py +++ b/manim_ml/__init__.py @@ -1,3 +1,5 @@ +from argparse import Namespace +from manim import * import manim from manim_ml.utils.colorschemes.colorschemes import light_mode, dark_mode, ColorScheme @@ -5,6 +7,14 @@ class ManimMLConfig: def __init__(self, default_color_scheme=dark_mode): self._color_scheme = default_color_scheme + self.three_d_config = Namespace( + three_d_x_rotation = 90 * DEGREES, + three_d_y_rotation = 0 * DEGREES, + rotation_angle = 75 * DEGREES, + rotation_axis = [0.02, 1.0, 0.0] + # rotation_axis = [0.0, 0.9, 0.0] + #rotation_axis = [0.0, 0.9, 0.0] + ) @property def color_scheme(self): diff --git a/manim_ml/neural_network/activation_functions/activation_function.py b/manim_ml/neural_network/activation_functions/activation_function.py index 427bcab..bb28b1a 100644 --- a/manim_ml/neural_network/activation_functions/activation_function.py +++ b/manim_ml/neural_network/activation_functions/activation_function.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod import random import manim_ml.neural_network.activation_functions.relu as relu - +import manim_ml class ActivationFunction(ABC, VGroup): """Abstract parent class for defining activation functions""" @@ -16,9 +16,9 @@ class ActivationFunction(ABC, VGroup): x_length=0.5, y_length=0.3, show_function_name=True, - active_color=ORANGE, - plot_color=BLUE, - rectangle_color=WHITE, + active_color=manim_ml.config.color_scheme.active_color, + plot_color=manim_ml.config.color_scheme.primary_color, + rectangle_color=manim_ml.config.color_scheme.secondary_color, ): super(VGroup, self).__init__() self.function_name = function_name @@ -46,6 +46,7 @@ class ActivationFunction(ABC, VGroup): "include_numbers": False, "stroke_width": 0.5, "include_ticks": False, + "color": self.rectangle_color }, ) self.add(self.axes) diff --git a/manim_ml/neural_network/activation_functions/relu.py b/manim_ml/neural_network/activation_functions/relu.py index f776da5..ffdd891 100644 --- a/manim_ml/neural_network/activation_functions/relu.py +++ b/manim_ml/neural_network/activation_functions/relu.py @@ -4,7 +4,6 @@ from manim_ml.neural_network.activation_functions.activation_function import ( ActivationFunction, ) - class ReLUFunction(ActivationFunction): """Rectified Linear Unit Activation Function""" diff --git a/manim_ml/neural_network/architectures/feed_forward.py b/manim_ml/neural_network/architectures/feed_forward.py index a247720..5a1f059 100644 --- a/manim_ml/neural_network/architectures/feed_forward.py +++ b/manim_ml/neural_network/architectures/feed_forward.py @@ -2,7 +2,6 @@ import manim_ml from manim_ml.neural_network.neural_network import NeuralNetwork from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer - class FeedForwardNeuralNetwork(NeuralNetwork): """NeuralNetwork with just feed forward layers""" diff --git a/manim_ml/neural_network/layers/convolutional_2d.py b/manim_ml/neural_network/layers/convolutional_2d.py index b833ccb..48849d8 100644 --- a/manim_ml/neural_network/layers/convolutional_2d.py +++ b/manim_ml/neural_network/layers/convolutional_2d.py @@ -5,6 +5,7 @@ from manim_ml.neural_network.activation_functions.activation_function import ( ) import numpy as np from manim import * +import manim_ml from manim_ml.neural_network.layers.parent_layers import ( ThreeDLayer, @@ -168,9 +169,9 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): # Rotate stuff properly # normal_vector = self.feature_maps[0].get_normal_vector() self.rotate( - ThreeDLayer.rotation_angle, + manim_ml.config.three_d_config.rotation_angle, about_point=self.get_center(), - axis=ThreeDLayer.rotation_axis, + axis=manim_ml.config.three_d_config.rotation_axis, ) self.construct_activation_function() diff --git a/manim_ml/neural_network/layers/convolutional_2d_to_convolutional_2d.py b/manim_ml/neural_network/layers/convolutional_2d_to_convolutional_2d.py index 2baab5f..7decf76 100644 --- a/manim_ml/neural_network/layers/convolutional_2d_to_convolutional_2d.py +++ b/manim_ml/neural_network/layers/convolutional_2d_to_convolutional_2d.py @@ -4,6 +4,7 @@ from manim import * from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer from manim_ml.utils.mobjects.gridded_rectangle import GriddedRectangle +import manim_ml from manim.utils.space_ops import rotation_matrix @@ -14,7 +15,9 @@ def get_rotated_shift_vectors(input_layer, normalized=False): right_shift = np.array([input_layer.cell_width, 0, 0]) down_shift = np.array([0, -input_layer.cell_width, 0]) # Make rotation matrix - rot_mat = rotation_matrix(ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis) + rot_mat = rotation_matrix( + manim_ml.config.three_d_config.rotation_angle, + manim_ml.config.three_d_config.rotation_axis) # Rotate the vectors right_shift = np.dot(right_shift, rot_mat.T) down_shift = np.dot(down_shift, rot_mat.T) @@ -84,9 +87,9 @@ class Filters(VGroup): ) # normal_vector = rectangle.get_normal_vector() rectangle.rotate( - ThreeDLayer.rotation_angle, + manim_ml.config.three_d_config.rotation_angle, about_point=rectangle.get_center(), - axis=ThreeDLayer.rotation_axis, + axis=manim_ml.config.three_d_config.rotation_axis, ) # Move the rectangle to the corner of the feature map rectangle.next_to( @@ -133,9 +136,9 @@ class Filters(VGroup): ) # Rotate the rectangle rectangle.rotate( - ThreeDLayer.rotation_angle, + manim_ml.config.three_d_config.rotation_angle, about_point=rectangle.get_center(), - axis=ThreeDLayer.rotation_axis, + axis=manim_ml.config.three_d_config.rotation_axis, ) # Move the rectangle to the corner location rectangle.next_to( diff --git a/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py b/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py index 6eec4e2..6394a39 100644 --- a/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py +++ b/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py @@ -10,6 +10,8 @@ from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeD from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer +import manim_ml + class Uncreate(Create): def __init__( @@ -123,9 +125,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): # of the conv maps gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells) gridded_rectangle_group.rotate( - ThreeDLayer.rotation_angle, + manim_ml.config.three_d_config.rotation_angle, about_point=gridded_rectangle.get_center(), - axis=ThreeDLayer.rotation_axis, + axis=manim_ml.config.three_d_config.rotation_axis, ) gridded_rectangle_group.next_to( feature_map.get_corners_dict()["top_left"], @@ -163,9 +165,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): show_grid_lines=True, ) output_gridded_rectangle.rotate( - ThreeDLayer.rotation_angle, + manim_ml.config.three_d_config.rotation_angle, about_point=output_gridded_rectangle.get_center(), - axis=ThreeDLayer.rotation_axis, + axis=manim_ml.three_d_config.rotation_axis, ) output_gridded_rectangle.move_to( self.output_layer.feature_maps[feature_map_index].copy() diff --git a/manim_ml/neural_network/layers/image_to_convolutional_2d.py b/manim_ml/neural_network/layers/image_to_convolutional_2d.py index 6d3b769..bd29f37 100644 --- a/manim_ml/neural_network/layers/image_to_convolutional_2d.py +++ b/manim_ml/neural_network/layers/image_to_convolutional_2d.py @@ -9,6 +9,8 @@ from manim_ml.neural_network.layers.parent_layers import ( ) from manim_ml.utils.mobjects.gridded_rectangle import GriddedRectangle +import manim_ml + class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): """Handles rendering a convolutional layer for a nn""" @@ -61,8 +63,8 @@ class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): # Make rotation of image rotation = ApplyMethod( image_mobject.rotate, - ThreeDLayer.rotation_angle, - ThreeDLayer.rotation_axis, + manim_ml.config.three_d_config.rotation_angle, + manim_ml.config.three_d_config.rotation_axis, image_mobject.get_center(), run_time=0.5, ) diff --git a/manim_ml/neural_network/layers/max_pooling_2d.py b/manim_ml/neural_network/layers/max_pooling_2d.py index 5e5f3ea..1e6a13d 100644 --- a/manim_ml/neural_network/layers/max_pooling_2d.py +++ b/manim_ml/neural_network/layers/max_pooling_2d.py @@ -5,7 +5,7 @@ from manim_ml.neural_network.layers.parent_layers import ( ThreeDLayer, VGroupNeuralNetworkLayer, ) - +import manim_ml class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): """Max pooling layer for Convolutional2DLayer @@ -59,9 +59,9 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): ) self.add(self.feature_maps) self.rotate( - ThreeDLayer.rotation_angle, + manim_ml.config.three_d_config.rotation_angle, about_point=self.get_center(), - axis=ThreeDLayer.rotation_axis, + axis=manim_ml.config.three_d_config.rotation_axis ) self.feature_map_size = ( input_layer.feature_map_size[0] / self.kernel_size, diff --git a/manim_ml/neural_network/layers/parent_layers.py b/manim_ml/neural_network/layers/parent_layers.py index 0346cc5..a910bd3 100644 --- a/manim_ml/neural_network/layers/parent_layers.py +++ b/manim_ml/neural_network/layers/parent_layers.py @@ -56,12 +56,8 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer): class ThreeDLayer(ABC): """Abstract class for 3D layers""" - + pass # Angle of ThreeD layers is static context - three_d_x_rotation = 90 * DEGREES # -90 * DEGREES - three_d_y_rotation = 0 * DEGREES # -10 * DEGREES - rotation_angle = 60 * DEGREES - rotation_axis = [0.0, 0.9, 0.0] class ConnectiveLayer(VGroupNeuralNetworkLayer): diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index f10b69b..32dba26 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -22,6 +22,7 @@ from manim_ml.neural_network.animations.neural_network_transformations import ( InsertLayer, RemoveLayer, ) +import manim_ml class NeuralNetwork(Group): """Neural Network Visualization Container Class""" @@ -30,7 +31,7 @@ class NeuralNetwork(Group): self, input_layers, layer_spacing=0.2, - animation_dot_color=RED, + animation_dot_color=manim_ml.config.color_scheme.active_color, edge_width=2.5, dot_radius=0.03, title=" ", @@ -173,24 +174,28 @@ class NeuralNetwork(Group): current_layer.shift(shift_vector) # After all layers have been placed place their activation functions + layer_max_height = max([layer.get_height() for layer in self.input_layers]) for current_layer in self.input_layers: # Place activation function if hasattr(current_layer, "activation_function"): if not current_layer.activation_function is None: - up_movement = np.array( - [ - 0, - current_layer.get_height() / 2 - + current_layer.activation_function.get_height() / 2 - + 0.5 * self.layer_spacing, - 0, - ] - ) + # Get max height of layer + up_movement = np.array([ + 0, + layer_max_height / 2 + + current_layer.activation_function.get_height() / 2 + + 0.5 * self.layer_spacing, + 0, + ]) current_layer.activation_function.move_to( current_layer, ) - current_layer.activation_function.shift(up_movement) - self.add(current_layer.activation_function) + current_layer.activation_function.shift( + up_movement + ) + self.add( + current_layer.activation_function + ) def _construct_connective_layers(self): """Draws connecting lines between layers""" diff --git a/manim_ml/utils/mobjects/list_group.py b/manim_ml/utils/mobjects/list_group.py index 1178264..2a14e9c 100644 --- a/manim_ml/utils/mobjects/list_group.py +++ b/manim_ml/utils/mobjects/list_group.py @@ -82,3 +82,6 @@ class ListGroup(Mobject): if self.current_index < len(self.items): return self.items[self.current_index] raise StopIteration + + def __repr__(self): + return f"ListGroup({self.items})" diff --git a/manim_ml/utils/testing/doc_directive.py b/manim_ml/utils/testing/doc_directive.py new file mode 100644 index 0000000..7607b08 --- /dev/null +++ b/manim_ml/utils/testing/doc_directive.py @@ -0,0 +1,337 @@ +r""" +A directive for including Manim videos in a Sphinx document +""" +from __future__ import annotations + +import csv +import itertools as it +import os +import re +import shutil +import sys +from pathlib import Path +from timeit import timeit + +import jinja2 +from docutils import nodes +from docutils.parsers.rst import Directive, directives # type: ignore +from docutils.statemachine import StringList + +from manim import QUALITIES + +classnamedict = {} + +class SkipManimNode(nodes.Admonition, nodes.Element): + """Auxiliary node class that is used when the ``skip-manim`` tag is present + or ``.pot`` files are being built. + + Skips rendering the manim directive and outputs a placeholder instead. + """ + pass + +def visit(self, node, name=""): + self.visit_admonition(node, name) + if not isinstance(node[0], nodes.title): + node.insert(0, nodes.title("skip-manim", "Example Placeholder")) + +def depart(self, node): + self.depart_admonition(node) + +def process_name_list(option_input: str, reference_type: str) -> list[str]: + r"""Reformats a string of space separated class names + as a list of strings containing valid Sphinx references. + + Tests + ----- + + :: + >>> process_name_list("Tex TexTemplate", "class") + [':class:`~.Tex`', ':class:`~.TexTemplate`'] + >>> process_name_list("Scene.play Mobject.rotate", "func") + [':func:`~.Scene.play`', ':func:`~.Mobject.rotate`'] + """ + return [f":{reference_type}:`~.{name}`" for name in option_input.split()] + +class ManimDirective(Directive): + r"""The manim directive, rendering videos while building + the documentation. + + See the module docstring for documentation. + """ + has_content = True + required_arguments = 1 + optional_arguments = 0 + option_spec = { + "hide_source": bool, + "no_autoplay": bool, + "quality": lambda arg: directives.choice( + arg, + ("low", "medium", "high", "fourk"), + ), + "save_as_gif": bool, + "save_last_frame": bool, + "ref_modules": lambda arg: process_name_list(arg, "mod"), + "ref_classes": lambda arg: process_name_list(arg, "class"), + "ref_functions": lambda arg: process_name_list(arg, "func"), + "ref_methods": lambda arg: process_name_list(arg, "meth"), + } + final_argument_whitespace = True + + def run(self): + # Rendering is skipped if the tag skip-manim is present, + # or if we are making the pot-files + should_skip = ( + "skip-manim" in self.state.document.settings.env.app.builder.tags.tags + or self.state.document.settings.env.app.builder.name == "gettext" + ) + if should_skip: + node = SkipManimNode() + self.state.nested_parse( + StringList( + [ + f"Placeholder block for ``{self.arguments[0]}``.", + "", + ".. code-block:: python", + "", + ] + + [" " + line for line in self.content] + ), + self.content_offset, + node, + ) + return [node] + + from manim import config, tempconfig + + global classnamedict + + clsname = self.arguments[0] + if clsname not in classnamedict: + classnamedict[clsname] = 1 + else: + classnamedict[clsname] += 1 + + hide_source = "hide_source" in self.options + no_autoplay = "no_autoplay" in self.options + save_as_gif = "save_as_gif" in self.options + save_last_frame = "save_last_frame" in self.options + assert not (save_as_gif and save_last_frame) + + ref_content = ( + self.options.get("ref_modules", []) + + self.options.get("ref_classes", []) + + self.options.get("ref_functions", []) + + self.options.get("ref_methods", []) + ) + if ref_content: + ref_block = "References: " + " ".join(ref_content) + + else: + ref_block = "" + + if "quality" in self.options: + quality = f'{self.options["quality"]}_quality' + else: + quality = "example_quality" + frame_rate = QUALITIES[quality]["frame_rate"] + pixel_height = QUALITIES[quality]["pixel_height"] + pixel_width = QUALITIES[quality]["pixel_width"] + + state_machine = self.state_machine + document = state_machine.document + + source_file_name = Path(document.attributes["source"]) + source_rel_name = source_file_name.relative_to(setup.confdir) + source_rel_dir = source_rel_name.parents[0] + dest_dir = Path(setup.app.builder.outdir, source_rel_dir).absolute() + if not dest_dir.exists(): + dest_dir.mkdir(parents=True, exist_ok=True) + + source_block = [ + ".. code-block:: python", + "", + " from manim import *\n", + *(" " + line for line in self.content), + ] + source_block = "\n".join(source_block) + + config.media_dir = (Path(setup.confdir) / "media").absolute() + config.images_dir = "{media_dir}/images" + config.video_dir = "{media_dir}/videos/{quality}" + output_file = f"{clsname}-{classnamedict[clsname]}" + config.assets_dir = Path("_static") + config.progress_bar = "none" + config.verbosity = "WARNING" + + example_config = { + "frame_rate": frame_rate, + "no_autoplay": no_autoplay, + "pixel_height": pixel_height, + "pixel_width": pixel_width, + "save_last_frame": save_last_frame, + "write_to_movie": not save_last_frame, + "output_file": output_file, + } + if save_last_frame: + example_config["format"] = None + if save_as_gif: + example_config["format"] = "gif" + + user_code = self.content + if user_code[0].startswith(">>> "): # check whether block comes from doctest + user_code = [ + line[4:] for line in user_code if line.startswith((">>> ", "... ")) + ] + + code = [ + "from manim import *", + *user_code, + f"{clsname}().render()", + ] + + with tempconfig(example_config): + run_time = timeit(lambda: exec("\n".join(code), globals()), number=1) + video_dir = config.get_dir("video_dir") + images_dir = config.get_dir("images_dir") + + _write_rendering_stats( + clsname, + run_time, + self.state.document.settings.env.docname, + ) + + # copy video file to output directory + if not (save_as_gif or save_last_frame): + filename = f"{output_file}.mp4" + filesrc = video_dir / filename + destfile = Path(dest_dir, filename) + shutil.copyfile(filesrc, destfile) + elif save_as_gif: + filename = f"{output_file}.gif" + filesrc = video_dir / filename + elif save_last_frame: + filename = f"{output_file}.png" + filesrc = images_dir / filename + else: + raise ValueError("Invalid combination of render flags received.") + rendered_template = jinja2.Template(TEMPLATE).render( + clsname=clsname, + clsname_lowercase=clsname.lower(), + hide_source=hide_source, + filesrc_rel=Path(filesrc).relative_to(setup.confdir).as_posix(), + no_autoplay=no_autoplay, + output_file=output_file, + save_last_frame=save_last_frame, + save_as_gif=save_as_gif, + source_block=source_block, + ref_block=ref_block, + ) + state_machine.insert_input( + rendered_template.split("\n"), + source=document.attributes["source"], + ) + + return [] + + +rendering_times_file_path = Path("../rendering_times.csv") + + +def _write_rendering_stats(scene_name, run_time, file_name): + with rendering_times_file_path.open("a") as file: + csv.writer(file).writerow( + [ + re.sub(r"^(reference\/)|(manim\.)", "", file_name), + scene_name, + "%.3f" % run_time, + ], + ) + + +def _log_rendering_times(*args): + if rendering_times_file_path.exists(): + with rendering_times_file_path.open() as file: + data = list(csv.reader(file)) + if len(data) == 0: + sys.exit() + + print("\nRendering Summary\n-----------------\n") + + max_file_length = max(len(row[0]) for row in data) + for key, group in it.groupby(data, key=lambda row: row[0]): + key = key.ljust(max_file_length + 1, ".") + group = list(group) + if len(group) == 1: + row = group[0] + print(f"{key}{row[2].rjust(7, '.')}s {row[1]}") + continue + time_sum = sum(float(row[2]) for row in group) + print( + f"{key}{f'{time_sum:.3f}'.rjust(7, '.')}s => {len(group)} EXAMPLES", + ) + for row in group: + print(f"{' '*(max_file_length)} {row[2].rjust(7)}s {row[1]}") + print("") + + +def _delete_rendering_times(*args): + if rendering_times_file_path.exists(): + rendering_times_file_path.unlink() + + +def setup(app): + app.add_node(SkipManimNode, html=(visit, depart)) + + setup.app = app + setup.config = app.config + setup.confdir = app.confdir + + app.add_directive("manim", ManimDirective) + + app.connect("builder-inited", _delete_rendering_times) + app.connect("build-finished", _log_rendering_times) + + metadata = {"parallel_read_safe": False, "parallel_write_safe": True} + return metadata + + +TEMPLATE = r""" +{% if not hide_source %} +.. raw:: html + +
+

Example: {{ clsname }} ΒΆ

+ +{% endif %} + +{% if not (save_as_gif or save_last_frame) %} +.. raw:: html + + + +{% elif save_as_gif %} +.. image:: /{{ filesrc_rel }} + :align: center + +{% elif save_last_frame %} +.. image:: /{{ filesrc_rel }} + :align: center + +{% endif %} +{% if not hide_source %} +{{ source_block }} + +{{ ref_block }} + +.. raw:: html + +
+ +{% endif %} +"""