Changed default functionality of activation functions so they are all lined up with eachother.

This commit is contained in:
Alec Helbling
2023-04-09 20:15:07 -04:00
parent 2a50124ae2
commit ca7978929a
13 changed files with 397 additions and 39 deletions

View File

@ -1,3 +1,5 @@
from argparse import Namespace
from manim import *
import manim
from manim_ml.utils.colorschemes.colorschemes import light_mode, dark_mode, ColorScheme
@ -5,6 +7,14 @@ class ManimMLConfig:
def __init__(self, default_color_scheme=dark_mode):
self._color_scheme = default_color_scheme
self.three_d_config = Namespace(
three_d_x_rotation = 90 * DEGREES,
three_d_y_rotation = 0 * DEGREES,
rotation_angle = 75 * DEGREES,
rotation_axis = [0.02, 1.0, 0.0]
# rotation_axis = [0.0, 0.9, 0.0]
#rotation_axis = [0.0, 0.9, 0.0]
)
@property
def color_scheme(self):

View File

@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
import random
import manim_ml.neural_network.activation_functions.relu as relu
import manim_ml
class ActivationFunction(ABC, VGroup):
"""Abstract parent class for defining activation functions"""
@ -16,9 +16,9 @@ class ActivationFunction(ABC, VGroup):
x_length=0.5,
y_length=0.3,
show_function_name=True,
active_color=ORANGE,
plot_color=BLUE,
rectangle_color=WHITE,
active_color=manim_ml.config.color_scheme.active_color,
plot_color=manim_ml.config.color_scheme.primary_color,
rectangle_color=manim_ml.config.color_scheme.secondary_color,
):
super(VGroup, self).__init__()
self.function_name = function_name
@ -46,6 +46,7 @@ class ActivationFunction(ABC, VGroup):
"include_numbers": False,
"stroke_width": 0.5,
"include_ticks": False,
"color": self.rectangle_color
},
)
self.add(self.axes)

View File

@ -4,7 +4,6 @@ from manim_ml.neural_network.activation_functions.activation_function import (
ActivationFunction,
)
class ReLUFunction(ActivationFunction):
"""Rectified Linear Unit Activation Function"""

View File

@ -2,7 +2,6 @@ import manim_ml
from manim_ml.neural_network.neural_network import NeuralNetwork
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
class FeedForwardNeuralNetwork(NeuralNetwork):
"""NeuralNetwork with just feed forward layers"""

View File

@ -5,6 +5,7 @@ from manim_ml.neural_network.activation_functions.activation_function import (
)
import numpy as np
from manim import *
import manim_ml
from manim_ml.neural_network.layers.parent_layers import (
ThreeDLayer,
@ -168,9 +169,9 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
# Rotate stuff properly
# normal_vector = self.feature_maps[0].get_normal_vector()
self.rotate(
ThreeDLayer.rotation_angle,
manim_ml.config.three_d_config.rotation_angle,
about_point=self.get_center(),
axis=ThreeDLayer.rotation_axis,
axis=manim_ml.config.three_d_config.rotation_axis,
)
self.construct_activation_function()

View File

@ -4,6 +4,7 @@ from manim import *
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
from manim_ml.utils.mobjects.gridded_rectangle import GriddedRectangle
import manim_ml
from manim.utils.space_ops import rotation_matrix
@ -14,7 +15,9 @@ def get_rotated_shift_vectors(input_layer, normalized=False):
right_shift = np.array([input_layer.cell_width, 0, 0])
down_shift = np.array([0, -input_layer.cell_width, 0])
# Make rotation matrix
rot_mat = rotation_matrix(ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis)
rot_mat = rotation_matrix(
manim_ml.config.three_d_config.rotation_angle,
manim_ml.config.three_d_config.rotation_axis)
# Rotate the vectors
right_shift = np.dot(right_shift, rot_mat.T)
down_shift = np.dot(down_shift, rot_mat.T)
@ -84,9 +87,9 @@ class Filters(VGroup):
)
# normal_vector = rectangle.get_normal_vector()
rectangle.rotate(
ThreeDLayer.rotation_angle,
manim_ml.config.three_d_config.rotation_angle,
about_point=rectangle.get_center(),
axis=ThreeDLayer.rotation_axis,
axis=manim_ml.config.three_d_config.rotation_axis,
)
# Move the rectangle to the corner of the feature map
rectangle.next_to(
@ -133,9 +136,9 @@ class Filters(VGroup):
)
# Rotate the rectangle
rectangle.rotate(
ThreeDLayer.rotation_angle,
manim_ml.config.three_d_config.rotation_angle,
about_point=rectangle.get_center(),
axis=ThreeDLayer.rotation_axis,
axis=manim_ml.config.three_d_config.rotation_axis,
)
# Move the rectangle to the corner location
rectangle.next_to(

View File

@ -10,6 +10,8 @@ from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeD
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
import manim_ml
class Uncreate(Create):
def __init__(
@ -123,9 +125,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
# of the conv maps
gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells)
gridded_rectangle_group.rotate(
ThreeDLayer.rotation_angle,
manim_ml.config.three_d_config.rotation_angle,
about_point=gridded_rectangle.get_center(),
axis=ThreeDLayer.rotation_axis,
axis=manim_ml.config.three_d_config.rotation_axis,
)
gridded_rectangle_group.next_to(
feature_map.get_corners_dict()["top_left"],
@ -163,9 +165,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
show_grid_lines=True,
)
output_gridded_rectangle.rotate(
ThreeDLayer.rotation_angle,
manim_ml.config.three_d_config.rotation_angle,
about_point=output_gridded_rectangle.get_center(),
axis=ThreeDLayer.rotation_axis,
axis=manim_ml.three_d_config.rotation_axis,
)
output_gridded_rectangle.move_to(
self.output_layer.feature_maps[feature_map_index].copy()

View File

@ -9,6 +9,8 @@ from manim_ml.neural_network.layers.parent_layers import (
)
from manim_ml.utils.mobjects.gridded_rectangle import GriddedRectangle
import manim_ml
class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Handles rendering a convolutional layer for a nn"""
@ -61,8 +63,8 @@ class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
# Make rotation of image
rotation = ApplyMethod(
image_mobject.rotate,
ThreeDLayer.rotation_angle,
ThreeDLayer.rotation_axis,
manim_ml.config.three_d_config.rotation_angle,
manim_ml.config.three_d_config.rotation_axis,
image_mobject.get_center(),
run_time=0.5,
)

View File

@ -5,7 +5,7 @@ from manim_ml.neural_network.layers.parent_layers import (
ThreeDLayer,
VGroupNeuralNetworkLayer,
)
import manim_ml
class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Max pooling layer for Convolutional2DLayer
@ -59,9 +59,9 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
)
self.add(self.feature_maps)
self.rotate(
ThreeDLayer.rotation_angle,
manim_ml.config.three_d_config.rotation_angle,
about_point=self.get_center(),
axis=ThreeDLayer.rotation_axis,
axis=manim_ml.config.three_d_config.rotation_axis
)
self.feature_map_size = (
input_layer.feature_map_size[0] / self.kernel_size,

View File

@ -56,12 +56,8 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
class ThreeDLayer(ABC):
"""Abstract class for 3D layers"""
pass
# Angle of ThreeD layers is static context
three_d_x_rotation = 90 * DEGREES # -90 * DEGREES
three_d_y_rotation = 0 * DEGREES # -10 * DEGREES
rotation_angle = 60 * DEGREES
rotation_axis = [0.0, 0.9, 0.0]
class ConnectiveLayer(VGroupNeuralNetworkLayer):

View File

@ -22,6 +22,7 @@ from manim_ml.neural_network.animations.neural_network_transformations import (
InsertLayer,
RemoveLayer,
)
import manim_ml
class NeuralNetwork(Group):
"""Neural Network Visualization Container Class"""
@ -30,7 +31,7 @@ class NeuralNetwork(Group):
self,
input_layers,
layer_spacing=0.2,
animation_dot_color=RED,
animation_dot_color=manim_ml.config.color_scheme.active_color,
edge_width=2.5,
dot_radius=0.03,
title=" ",
@ -173,24 +174,28 @@ class NeuralNetwork(Group):
current_layer.shift(shift_vector)
# After all layers have been placed place their activation functions
layer_max_height = max([layer.get_height() for layer in self.input_layers])
for current_layer in self.input_layers:
# Place activation function
if hasattr(current_layer, "activation_function"):
if not current_layer.activation_function is None:
up_movement = np.array(
[
0,
current_layer.get_height() / 2
+ current_layer.activation_function.get_height() / 2
+ 0.5 * self.layer_spacing,
0,
]
)
# Get max height of layer
up_movement = np.array([
0,
layer_max_height / 2
+ current_layer.activation_function.get_height() / 2
+ 0.5 * self.layer_spacing,
0,
])
current_layer.activation_function.move_to(
current_layer,
)
current_layer.activation_function.shift(up_movement)
self.add(current_layer.activation_function)
current_layer.activation_function.shift(
up_movement
)
self.add(
current_layer.activation_function
)
def _construct_connective_layers(self):
"""Draws connecting lines between layers"""

View File

@ -82,3 +82,6 @@ class ListGroup(Mobject):
if self.current_index < len(self.items):
return self.items[self.current_index]
raise StopIteration
def __repr__(self):
return f"ListGroup({self.items})"

View File

@ -0,0 +1,337 @@
r"""
A directive for including Manim videos in a Sphinx document
"""
from __future__ import annotations
import csv
import itertools as it
import os
import re
import shutil
import sys
from pathlib import Path
from timeit import timeit
import jinja2
from docutils import nodes
from docutils.parsers.rst import Directive, directives # type: ignore
from docutils.statemachine import StringList
from manim import QUALITIES
classnamedict = {}
class SkipManimNode(nodes.Admonition, nodes.Element):
"""Auxiliary node class that is used when the ``skip-manim`` tag is present
or ``.pot`` files are being built.
Skips rendering the manim directive and outputs a placeholder instead.
"""
pass
def visit(self, node, name=""):
self.visit_admonition(node, name)
if not isinstance(node[0], nodes.title):
node.insert(0, nodes.title("skip-manim", "Example Placeholder"))
def depart(self, node):
self.depart_admonition(node)
def process_name_list(option_input: str, reference_type: str) -> list[str]:
r"""Reformats a string of space separated class names
as a list of strings containing valid Sphinx references.
Tests
-----
::
>>> process_name_list("Tex TexTemplate", "class")
[':class:`~.Tex`', ':class:`~.TexTemplate`']
>>> process_name_list("Scene.play Mobject.rotate", "func")
[':func:`~.Scene.play`', ':func:`~.Mobject.rotate`']
"""
return [f":{reference_type}:`~.{name}`" for name in option_input.split()]
class ManimDirective(Directive):
r"""The manim directive, rendering videos while building
the documentation.
See the module docstring for documentation.
"""
has_content = True
required_arguments = 1
optional_arguments = 0
option_spec = {
"hide_source": bool,
"no_autoplay": bool,
"quality": lambda arg: directives.choice(
arg,
("low", "medium", "high", "fourk"),
),
"save_as_gif": bool,
"save_last_frame": bool,
"ref_modules": lambda arg: process_name_list(arg, "mod"),
"ref_classes": lambda arg: process_name_list(arg, "class"),
"ref_functions": lambda arg: process_name_list(arg, "func"),
"ref_methods": lambda arg: process_name_list(arg, "meth"),
}
final_argument_whitespace = True
def run(self):
# Rendering is skipped if the tag skip-manim is present,
# or if we are making the pot-files
should_skip = (
"skip-manim" in self.state.document.settings.env.app.builder.tags.tags
or self.state.document.settings.env.app.builder.name == "gettext"
)
if should_skip:
node = SkipManimNode()
self.state.nested_parse(
StringList(
[
f"Placeholder block for ``{self.arguments[0]}``.",
"",
".. code-block:: python",
"",
]
+ [" " + line for line in self.content]
),
self.content_offset,
node,
)
return [node]
from manim import config, tempconfig
global classnamedict
clsname = self.arguments[0]
if clsname not in classnamedict:
classnamedict[clsname] = 1
else:
classnamedict[clsname] += 1
hide_source = "hide_source" in self.options
no_autoplay = "no_autoplay" in self.options
save_as_gif = "save_as_gif" in self.options
save_last_frame = "save_last_frame" in self.options
assert not (save_as_gif and save_last_frame)
ref_content = (
self.options.get("ref_modules", [])
+ self.options.get("ref_classes", [])
+ self.options.get("ref_functions", [])
+ self.options.get("ref_methods", [])
)
if ref_content:
ref_block = "References: " + " ".join(ref_content)
else:
ref_block = ""
if "quality" in self.options:
quality = f'{self.options["quality"]}_quality'
else:
quality = "example_quality"
frame_rate = QUALITIES[quality]["frame_rate"]
pixel_height = QUALITIES[quality]["pixel_height"]
pixel_width = QUALITIES[quality]["pixel_width"]
state_machine = self.state_machine
document = state_machine.document
source_file_name = Path(document.attributes["source"])
source_rel_name = source_file_name.relative_to(setup.confdir)
source_rel_dir = source_rel_name.parents[0]
dest_dir = Path(setup.app.builder.outdir, source_rel_dir).absolute()
if not dest_dir.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
source_block = [
".. code-block:: python",
"",
" from manim import *\n",
*(" " + line for line in self.content),
]
source_block = "\n".join(source_block)
config.media_dir = (Path(setup.confdir) / "media").absolute()
config.images_dir = "{media_dir}/images"
config.video_dir = "{media_dir}/videos/{quality}"
output_file = f"{clsname}-{classnamedict[clsname]}"
config.assets_dir = Path("_static")
config.progress_bar = "none"
config.verbosity = "WARNING"
example_config = {
"frame_rate": frame_rate,
"no_autoplay": no_autoplay,
"pixel_height": pixel_height,
"pixel_width": pixel_width,
"save_last_frame": save_last_frame,
"write_to_movie": not save_last_frame,
"output_file": output_file,
}
if save_last_frame:
example_config["format"] = None
if save_as_gif:
example_config["format"] = "gif"
user_code = self.content
if user_code[0].startswith(">>> "): # check whether block comes from doctest
user_code = [
line[4:] for line in user_code if line.startswith((">>> ", "... "))
]
code = [
"from manim import *",
*user_code,
f"{clsname}().render()",
]
with tempconfig(example_config):
run_time = timeit(lambda: exec("\n".join(code), globals()), number=1)
video_dir = config.get_dir("video_dir")
images_dir = config.get_dir("images_dir")
_write_rendering_stats(
clsname,
run_time,
self.state.document.settings.env.docname,
)
# copy video file to output directory
if not (save_as_gif or save_last_frame):
filename = f"{output_file}.mp4"
filesrc = video_dir / filename
destfile = Path(dest_dir, filename)
shutil.copyfile(filesrc, destfile)
elif save_as_gif:
filename = f"{output_file}.gif"
filesrc = video_dir / filename
elif save_last_frame:
filename = f"{output_file}.png"
filesrc = images_dir / filename
else:
raise ValueError("Invalid combination of render flags received.")
rendered_template = jinja2.Template(TEMPLATE).render(
clsname=clsname,
clsname_lowercase=clsname.lower(),
hide_source=hide_source,
filesrc_rel=Path(filesrc).relative_to(setup.confdir).as_posix(),
no_autoplay=no_autoplay,
output_file=output_file,
save_last_frame=save_last_frame,
save_as_gif=save_as_gif,
source_block=source_block,
ref_block=ref_block,
)
state_machine.insert_input(
rendered_template.split("\n"),
source=document.attributes["source"],
)
return []
rendering_times_file_path = Path("../rendering_times.csv")
def _write_rendering_stats(scene_name, run_time, file_name):
with rendering_times_file_path.open("a") as file:
csv.writer(file).writerow(
[
re.sub(r"^(reference\/)|(manim\.)", "", file_name),
scene_name,
"%.3f" % run_time,
],
)
def _log_rendering_times(*args):
if rendering_times_file_path.exists():
with rendering_times_file_path.open() as file:
data = list(csv.reader(file))
if len(data) == 0:
sys.exit()
print("\nRendering Summary\n-----------------\n")
max_file_length = max(len(row[0]) for row in data)
for key, group in it.groupby(data, key=lambda row: row[0]):
key = key.ljust(max_file_length + 1, ".")
group = list(group)
if len(group) == 1:
row = group[0]
print(f"{key}{row[2].rjust(7, '.')}s {row[1]}")
continue
time_sum = sum(float(row[2]) for row in group)
print(
f"{key}{f'{time_sum:.3f}'.rjust(7, '.')}s => {len(group)} EXAMPLES",
)
for row in group:
print(f"{' '*(max_file_length)} {row[2].rjust(7)}s {row[1]}")
print("")
def _delete_rendering_times(*args):
if rendering_times_file_path.exists():
rendering_times_file_path.unlink()
def setup(app):
app.add_node(SkipManimNode, html=(visit, depart))
setup.app = app
setup.config = app.config
setup.confdir = app.confdir
app.add_directive("manim", ManimDirective)
app.connect("builder-inited", _delete_rendering_times)
app.connect("build-finished", _log_rendering_times)
metadata = {"parallel_read_safe": False, "parallel_write_safe": True}
return metadata
TEMPLATE = r"""
{% if not hide_source %}
.. raw:: html
<div id="{{ clsname_lowercase }}" class="admonition admonition-manim-example">
<p class="admonition-title">Example: {{ clsname }} <a class="headerlink" href="#{{ clsname_lowercase }}">¶</a></p>
{% endif %}
{% if not (save_as_gif or save_last_frame) %}
.. raw:: html
<video
class="manim-video"
controls
loop
{{ '' if no_autoplay else 'autoplay' }}
src="./{{ output_file }}.mp4">
</video>
{% elif save_as_gif %}
.. image:: /{{ filesrc_rel }}
:align: center
{% elif save_last_frame %}
.. image:: /{{ filesrc_rel }}
:align: center
{% endif %}
{% if not hide_source %}
{{ source_block }}
{{ ref_block }}
.. raw:: html
</div>
{% endif %}
"""