Reformatted the code using black, allowd for different orientation NNs, made an option for highlighting the active filter in a CNN forward pass.

This commit is contained in:
Alec Helbling
2023-01-09 15:52:37 +09:00
parent 39b0b133ce
commit ba63116b37
19 changed files with 485 additions and 283 deletions

View File

@ -0,0 +1,45 @@
from manim import *
from PIL import Image
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.layers.parent_layers import ThreeDLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 800
config.frame_height = 6.0
config.frame_width = 6.0
class CombinedScene(ThreeDScene):
def construct(self):
image = Image.open("../../assets/doggo.jpeg")
numpy_image = np.asarray(image)
# Rotate the Three D layer position
ThreeDLayer.rotation_angle = 15 * DEGREES
ThreeDLayer.rotation_axis = [1, -1.0, 0]
# Make nn
nn = NeuralNetwork(
[
ImageLayer(numpy_image, height=1.5),
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
Convolutional3DLayer(3, 5, 5, 1, 1, filter_spacing=0.18),
],
layer_spacing=0.25,
layout_direction="top_to_bottom",
)
# Center the nn
nn.move_to(ORIGIN)
nn.scale(1.5)
self.add(nn)
# Play animation
forward_pass = nn.make_forward_pass_animation(
corner_pulses=False,
all_filters_at_once=False,
highlight_active_feature_map=True,
)
self.wait(1)
self.play(forward_pass)

View File

@ -3,8 +3,10 @@ import numpy as np
from collections import deque from collections import deque
from sklearn.tree import _tree as ctree from sklearn.tree import _tree as ctree
class AABB: class AABB:
"""Axis-aligned bounding box""" """Axis-aligned bounding box"""
def __init__(self, n_features): def __init__(self, n_features):
self.limits = np.array([[-np.inf, np.inf]] * n_features) self.limits = np.array([[-np.inf, np.inf]] * n_features)
@ -18,6 +20,7 @@ class AABB:
return left, right return left, right
def tree_bounds(tree, n_features=None): def tree_bounds(tree, n_features=None):
"""Compute final decision rule for each node in tree""" """Compute final decision rule for each node in tree"""
if n_features is None: if n_features is None:
@ -33,6 +36,7 @@ def tree_bounds(tree, n_features=None):
queue.extend([l, r]) queue.extend([l, r])
return aabbs return aabbs
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None): def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
"""Extract decision areas. """Extract decision areas.
@ -69,21 +73,27 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None)
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2]) rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
return rectangles return rectangles
def plot_areas(rectangles): def plot_areas(rectangles):
for rect in rectangles: for rect in rectangles:
color = ['b', 'r'][int(rect[4])] color = ["b", "r"][int(rect[4])]
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1]) print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
rp = Rectangle([rect[0], rect[2]], rp = Rectangle(
[rect[0], rect[2]],
rect[1] - rect[0], rect[1] - rect[0],
rect[3] - rect[2], color=color, alpha=0.3) rect[3] - rect[2],
color=color,
alpha=0.3,
)
plt.gca().add_artist(rp) plt.gca().add_artist(rp)
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]): def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color # get all polygons of each color
polygon_dict = { polygon_dict = {
str(BLUE).lower(): [], str(BLUE).lower(): [],
str(GREEN).lower(): [], str(GREEN).lower(): [],
str(ORANGE).lower():[] str(ORANGE).lower(): [],
} }
for polygon in all_polygons: for polygon in all_polygons:
print(polygon_dict) print(polygon_dict)
@ -143,8 +153,10 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# Remove implementation-markers from the polygon. # Remove implementation-markers from the polygon.
poly = [point for point, _ in polygon] poly = [point for point, _ in polygon]
for vertex in poly: for vertex in poly:
if vertex in edges_h: edges_h.pop(vertex) if vertex in edges_h:
if vertex in edges_v: edges_v.pop(vertex) edges_h.pop(vertex)
if vertex in edges_v:
edges_v.pop(vertex)
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0) polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
return_polygons.append(polygon) return_polygons.append(polygon)
return return_polygons return return_polygons

View File

@ -6,18 +6,23 @@
TODO reimplement the decision 2D decision tree surface drawing. TODO reimplement the decision 2D decision tree surface drawing.
""" """
from manim import * from manim import *
from manim_ml.decision_tree.classification_areas import compute_decision_areas, merge_overlapping_polygons from manim_ml.decision_tree.classification_areas import (
compute_decision_areas,
merge_overlapping_polygons,
)
import manim_ml.decision_tree.helpers as helpers import manim_ml.decision_tree.helpers as helpers
from manim_ml.one_to_one_sync import OneToOneSync from manim_ml.one_to_one_sync import OneToOneSync
import numpy as np import numpy as np
from PIL import Image from PIL import Image
class LeafNode(Group): class LeafNode(Group):
"""Leaf node in tree""" """Leaf node in tree"""
def __init__(self, class_index, display_type="image", class_image_paths=[], def __init__(
class_colors=[]): self, class_index, display_type="image", class_image_paths=[], class_colors=[]
):
super().__init__() super().__init__()
self.display_type = display_type self.display_type = display_type
self.class_image_paths = class_image_paths self.class_image_paths = class_image_paths
@ -39,13 +44,14 @@ class LeafNode(Group):
width=node.width + 0.05, width=node.width + 0.05,
height=node.height + 0.05, height=node.height + 0.05,
color=self.class_colors[class_index], color=self.class_colors[class_index],
stroke_width=6 stroke_width=6,
) )
rectangle.move_to(node.get_center()) rectangle.move_to(node.get_center())
rectangle.shift([-0.02, 0.02, 0]) rectangle.shift([-0.02, 0.02, 0])
self.add(rectangle) self.add(rectangle)
self.add(node) self.add(node)
class SplitNode(VGroup): class SplitNode(VGroup):
"""Node for splitting decision in tree""" """Node for splitting decision in tree"""
@ -53,25 +59,24 @@ class SplitNode(VGroup):
super().__init__() super().__init__()
node_text = f"{feature}\n<= {threshold:.2f} cm" node_text = f"{feature}\n<= {threshold:.2f} cm"
# Draw decision text # Draw decision text
decision_text = Text( decision_text = Text(node_text, color=WHITE)
node_text,
color=WHITE
)
# Draw the surrounding box # Draw the surrounding box
bounding_box = SurroundingRectangle( bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
decision_text,
buff=0.3,
color=WHITE
)
self.add(bounding_box) self.add(bounding_box)
self.add(decision_text) self.add(decision_text)
class DecisionTreeDiagram(Group): class DecisionTreeDiagram(Group):
"""Decision Tree Diagram Class for Manim""" """Decision Tree Diagram Class for Manim"""
def __init__(self, sklearn_tree, feature_names=None, def __init__(
class_names=None, class_images_paths=None, self,
class_colors=[RED, GREEN, BLUE]): sklearn_tree,
feature_names=None,
class_names=None,
class_images_paths=None,
class_colors=[RED, GREEN, BLUE],
):
super().__init__() super().__init__()
self.tree = sklearn_tree self.tree = sklearn_tree
self.feature_names = feature_names self.feature_names = feature_names
@ -87,14 +92,13 @@ class DecisionTreeDiagram(Group):
node_index, node_index,
): ):
"""Make node""" """Make node"""
is_split_node = self.tree.children_left[node_index] != self.tree.children_right[node_index] is_split_node = (
self.tree.children_left[node_index] != self.tree.children_right[node_index]
)
if is_split_node: if is_split_node:
node_feature = self.tree.feature[node_index] node_feature = self.tree.feature[node_index]
node_threshold = self.tree.threshold[node_index] node_threshold = self.tree.threshold[node_index]
node = SplitNode( node = SplitNode(self.feature_names[node_feature], node_threshold)
self.feature_names[node_feature],
node_threshold
)
else: else:
# Get the most abundant class for the given leaf node # Get the most abundant class for the given leaf node
# Make the leaf node object # Make the leaf node object
@ -102,7 +106,7 @@ class DecisionTreeDiagram(Group):
node = LeafNode( node = LeafNode(
class_index=tree_class_index, class_index=tree_class_index,
class_colors=self.class_colors, class_colors=self.class_colors,
class_image_paths=self.class_image_paths class_image_paths=self.class_image_paths,
) )
return node return node
@ -113,11 +117,7 @@ class DecisionTreeDiagram(Group):
bottom_node_top_location = bottom.get_center() bottom_node_top_location = bottom.get_center()
bottom_node_top_location[1] += bottom.height / 2 bottom_node_top_location[1] += bottom.height / 2
line = Line( line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)
top_node_bottom_location,
bottom_node_top_location,
color=WHITE
)
return line return line
@ -143,31 +143,49 @@ class DecisionTreeDiagram(Group):
# traverse tree # traverse tree
def recurse(node_index, depth, direction, parent_object, parent_node): def recurse(node_index, depth, direction, parent_object, parent_node):
# make the node object # make the node object
is_leaf = self.tree.children_left[node_index] == self.tree.children_right[node_index] is_leaf = (
self.tree.children_left[node_index]
== self.tree.children_right[node_index]
)
node_object = self._make_node(node_index=node_index) node_object = self._make_node(node_index=node_index)
nodes_map[node_index] = node_object nodes_map[node_index] = node_object
node_height = node_object.height node_height = node_object.height
# set the node position # set the node position
direction_factor = -1 if direction == "left" else 1 direction_factor = -1 if direction == "left" else 1
shift_right_amount = 0.9 * direction_factor * scale_factor * tree_width / (2 ** depth) / 2 shift_right_amount = (
0.9 * direction_factor * scale_factor * tree_width / (2**depth) / 2
)
if is_leaf: if is_leaf:
shift_down_amount = -1.0 * scale_factor * node_height shift_down_amount = -1.0 * scale_factor * node_height
else: else:
shift_down_amount = -1.8 * scale_factor * node_height shift_down_amount = -1.8 * scale_factor * node_height
node_object \ node_object.match_x(parent_object).match_y(parent_object).shift(
.match_x(parent_object) \ [shift_right_amount, shift_down_amount, 0]
.match_y(parent_object) \ )
.shift([shift_right_amount, shift_down_amount, 0])
tree_group.add(node_object) tree_group.add(node_object)
# make a connection # make a connection
connection = self._make_connection(parent_object, node_object, is_leaf=is_leaf) connection = self._make_connection(
parent_object, node_object, is_leaf=is_leaf
)
edge_name = str(parent_node) + "," + str(node_index) edge_name = str(parent_node) + "," + str(node_index)
edge_map[edge_name] = connection edge_map[edge_name] = connection
tree_group.add(connection) tree_group.add(connection)
# recurse # recurse
if not is_leaf: if not is_leaf:
recurse(self.tree.children_left[node_index], depth + 1, "left", node_object, node_index) recurse(
recurse(self.tree.children_right[node_index], depth + 1, "right", node_object, node_index) self.tree.children_left[node_index],
depth + 1,
"left",
node_object,
node_index,
)
recurse(
self.tree.children_right[node_index],
depth + 1,
"right",
node_object,
node_index,
)
recurse(self.tree.children_left[0], 1, "left", root_node, 0) recurse(self.tree.children_left[0], 1, "left", root_node, 0)
recurse(self.tree.children_right[0], 1, "right", root_node, 0) recurse(self.tree.children_right[0], 1, "right", root_node, 0)
@ -185,9 +203,7 @@ class DecisionTreeDiagram(Group):
# Compute parent mapping # Compute parent mapping
parent_mapping = helpers.compute_node_to_parent_mapping(self.tree) parent_mapping = helpers.compute_node_to_parent_mapping(self.tree)
# Create the root node # Create the root node
animations.append( animations.append(Create(self.nodes_map[0]))
Create(self.nodes_map[0])
)
# Iterate through the nodes # Iterate through the nodes
queue = [0] queue = [0]
while len(queue) > 0: while len(queue) > 0:
@ -211,21 +227,16 @@ class DecisionTreeDiagram(Group):
FadeIn(right_node), FadeIn(right_node),
Create(left_parent_edge), Create(left_parent_edge),
Create(right_parent_edge), Create(right_parent_edge),
lag_ratio=0.0 lag_ratio=0.0,
)
animations.append(
split_animation
) )
animations.append(split_animation)
# Add the children to the queue # Add the children to the queue
if left_child != -1: if left_child != -1:
queue.append(left_child) queue.append(left_child)
if right_child != -1: if right_child != -1:
queue.append(right_child) queue.append(right_child)
return AnimationGroup( return AnimationGroup(*animations, lag_ratio=1.0)
*animations,
lag_ratio=1.0
)
@override_animation(Create) @override_animation(Create)
def create_decision_tree(self, traversal_order="bfs"): def create_decision_tree(self, traversal_order="bfs"):
@ -237,8 +248,8 @@ class DecisionTreeDiagram(Group):
else: else:
raise Exception(f"Uncrecognized traversal: {traversal_order}") raise Exception(f"Uncrecognized traversal: {traversal_order}")
class IrisDatasetPlot(VGroup):
class IrisDatasetPlot(VGroup):
def __init__(self, iris): def __init__(self, iris):
points = iris.data[:, 0:2] points = iris.data[:, 0:2]
labels = iris.feature_names labels = iris.feature_names
@ -249,19 +260,13 @@ class IrisDatasetPlot(VGroup):
self.axes_group = self._make_axes_group(points, labels) self.axes_group = self._make_axes_group(points, labels)
# Make legend # Make legend
self.legend_group = self._make_legend( self.legend_group = self._make_legend(
[BLUE, ORANGE, GREEN], [BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
iris.target_names,
self.axes_group
) )
# Make title # Make title
# title_text = "Iris Dataset Plot" # title_text = "Iris Dataset Plot"
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0]) # self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
# Make all group # Make all group
self.all_group = Group( self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
self.point_group,
self.axes_group,
self.legend_group
)
# scale the groups # scale the groups
self.point_group.scale(1.6) self.point_group.scale(1.6)
self.point_group.match_x(self.axes_group) self.point_group.match_x(self.axes_group)
@ -279,7 +284,7 @@ class IrisDatasetPlot(VGroup):
Create(self.axes_group, run_time=2), Create(self.axes_group, run_time=2),
# add title # add title
# Create(self.title), # Create(self.title),
Create(self.legend_group) Create(self.legend_group),
) )
return animation_group return animation_group
@ -300,7 +305,9 @@ class IrisDatasetPlot(VGroup):
setosa = Text("Setosa", color=BLUE) setosa = Text("Setosa", color=BLUE)
verisicolor = Text("Verisicolor", color=ORANGE) verisicolor = Text("Verisicolor", color=ORANGE)
virginica = Text("Virginica", color=GREEN) virginica = Text("Virginica", color=GREEN)
labels = VGroup(setosa, verisicolor, virginica).arrange(direction=RIGHT, aligned_edge=LEFT, buff=2.0) labels = VGroup(setosa, verisicolor, virginica).arrange(
direction=RIGHT, aligned_edge=LEFT, buff=2.0
)
labels.scale(0.5) labels.scale(0.5)
legend_group.add(labels) legend_group.add(labels)
# surrounding rectangle # surrounding rectangle
@ -314,10 +321,14 @@ class IrisDatasetPlot(VGroup):
return legend_group return legend_group
def _make_axes_group(self, points, labels, font='Source Han Sans', font_scale=0.75): def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
axes_group = VGroup() axes_group = VGroup()
# make the axes # make the axes
x_range = [np.amin(points, axis=0)[0] - 0.2, np.amax(points, axis=0)[0] - 0.2, 0.5] x_range = [
np.amin(points, axis=0)[0] - 0.2,
np.amax(points, axis=0)[0] - 0.2,
0.5,
]
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5] y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
axes = Axes( axes = Axes(
x_range=x_range, x_range=x_range,
@ -330,23 +341,27 @@ class IrisDatasetPlot(VGroup):
axes_group.add(axes) axes_group.add(axes)
# make axis labels # make axis labels
# x_label # x_label
x_label = Text(labels[0], font=font) \ x_label = (
.match_y(axes.get_axes()[0]) \ Text(labels[0], font=font)
.shift([0.5, -0.75, 0]) \ .match_y(axes.get_axes()[0])
.shift([0.5, -0.75, 0])
.scale(font_scale) .scale(font_scale)
)
axes_group.add(x_label) axes_group.add(x_label)
# y_label # y_label
y_label = Text(labels[1], font=font) \ y_label = (
.match_x(axes.get_axes()[1]) \ Text(labels[1], font=font)
.shift([-0.75, 0, 0]) \ .match_x(axes.get_axes()[1])
.rotate(np.pi / 2) \ .shift([-0.75, 0, 0])
.rotate(np.pi / 2)
.scale(font_scale) .scale(font_scale)
)
axes_group.add(y_label) axes_group.add(y_label)
return axes_group return axes_group
class DecisionTreeSurface(VGroup):
class DecisionTreeSurface(VGroup):
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]): def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
# take the tree and construct the surface from it # take the tree and construct the surface from it
self.tree_clf = tree_clf self.tree_clf = tree_clf
@ -363,11 +378,7 @@ class DecisionTreeSurface(VGroup):
bottom = np.amin(self.data[:, 1]) - 0.2 bottom = np.amin(self.data[:, 1]) - 0.2
maxrange = [left, right, bottom, top] maxrange = [left, right, bottom, top]
rectangles = compute_decision_areas( rectangles = compute_decision_areas(
self.tree_clf, self.tree_clf, maxrange, x=0, y=1, n_features=2
maxrange,
x=0,
y=1,
n_features=2
) )
# turn the rectangle objects into manim rectangles # turn the rectangle objects into manim rectangles
def convert_rectangle_to_polygon(rect): def convert_rectangle_to_polygon(rect):
@ -381,9 +392,16 @@ class DecisionTreeSurface(VGroup):
bottom_right_coord = self.axes.coords_to_point(*bottom_right) bottom_right_coord = self.axes.coords_to_point(*bottom_right)
top_right_coord = self.axes.coords_to_point(*top_right) top_right_coord = self.axes.coords_to_point(*top_right)
top_left_coord = self.axes.coords_to_point(*top_left) top_left_coord = self.axes.coords_to_point(*top_left)
points = [bottom_left_coord, bottom_right_coord, top_right_coord, top_left_coord] points = [
bottom_left_coord,
bottom_right_coord,
top_right_coord,
top_left_coord,
]
# construct a polygon object from those manim coordinates # construct a polygon object from those manim coordinates
rectangle = Polygon(*points, color=color, fill_opacity=0.3, stroke_opacity=0.0) rectangle = Polygon(
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
)
return rectangle return rectangle
manim_rectangles = [] manim_rectangles = []
@ -392,7 +410,9 @@ class DecisionTreeSurface(VGroup):
rectangle = convert_rectangle_to_polygon(rect) rectangle = convert_rectangle_to_polygon(rect)
manim_rectangles.append(rectangle) manim_rectangles.append(rectangle)
manim_rectangles = merge_overlapping_polygons(manim_rectangles, colors=[BLUE, GREEN, ORANGE]) manim_rectangles = merge_overlapping_polygons(
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
)
return manim_rectangles return manim_rectangles
@ -416,6 +436,7 @@ class DecisionTreeSurface(VGroup):
return animation_group return animation_group
class DecisionTreeContainer(OneToOneSync): class DecisionTreeContainer(OneToOneSync):
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding""" """Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""

View File

@ -1,11 +1,14 @@
def compute_node_depths(tree): def compute_node_depths(tree):
"""Computes the depths of nodes for level order traversal""" """Computes the depths of nodes for level order traversal"""
def depth(node_index, current_node_index=0): def depth(node_index, current_node_index=0):
"""Compute the height of a node""" """Compute the height of a node"""
if current_node_index == node_index: if current_node_index == node_index:
return 0 return 0
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]: elif (
tree.children_left[current_node_index]
== tree.children_right[current_node_index]
):
return -1 return -1
else: else:
# Compute the height of each subtree # Compute the height of each subtree
@ -23,13 +26,18 @@ def compute_node_depths(tree):
return node_depths return node_depths
def compute_level_order_traversal(tree): def compute_level_order_traversal(tree):
"""Computes level order traversal of a sklearn tree""" """Computes level order traversal of a sklearn tree"""
def depth(node_index, current_node_index=0): def depth(node_index, current_node_index=0):
"""Compute the height of a node""" """Compute the height of a node"""
if current_node_index == node_index: if current_node_index == node_index:
return 0 return 0
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]: elif (
tree.children_left[current_node_index]
== tree.children_right[current_node_index]
):
return -1 return -1
else: else:
# Compute the height of each subtree # Compute the height of each subtree
@ -49,6 +57,7 @@ def compute_level_order_traversal(tree):
return sorted_inds return sorted_inds
def compute_node_to_parent_mapping(tree): def compute_node_to_parent_mapping(tree):
"""Returns a hashmap mapping node indices to their parent indices""" """Returns a hashmap mapping node indices to their parent indices"""
node_to_parent = {0: -1} # Root has no parent node_to_parent = {0: -1} # Root has no parent

View File

@ -24,6 +24,7 @@ class GriddedRectangle(VGroup):
): ):
super().__init__() super().__init__()
# Fields # Fields
self.color = color
self.mark_paths_closed = mark_paths_closed self.mark_paths_closed = mark_paths_closed
self.close_new_points = close_new_points self.close_new_points = close_new_points
self.grid_xstep = grid_xstep self.grid_xstep = grid_xstep
@ -33,8 +34,6 @@ class GriddedRectangle(VGroup):
self.grid_stroke_opacity = grid_stroke_opacity if show_grid_lines else 0.0 self.grid_stroke_opacity = grid_stroke_opacity if show_grid_lines else 0.0
self.stroke_width = stroke_width self.stroke_width = stroke_width
self.rotation_angles = [0, 0, 0] self.rotation_angles = [0, 0, 0]
self.rectangle_width = width
self.rectangle_height = height
self.show_grid_lines = show_grid_lines self.show_grid_lines = show_grid_lines
# Make rectangle # Make rectangle
self.rectangle = Rectangle( self.rectangle = Rectangle(
@ -129,3 +128,9 @@ class GriddedRectangle(VGroup):
normal_vector = np.cross((vertex_1 - vertex_2), (vertex_1 - vertex_3)) normal_vector = np.cross((vertex_1 - vertex_2), (vertex_1 - vertex_3))
return normal_vector return normal_vector
def set_color(self, color):
"""Sets the color of the gridded rectangle"""
self.color = color
self.rectangle.set_color(color)
self.rectangle.set_stroke_color(color)

View File

@ -3,22 +3,22 @@ import numpy as np
from PIL import Image from PIL import Image
class GrayscaleImageMobject(ImageMobject): class GrayscaleImageMobject(Group):
"""Mobject for creating images in Manim from numpy arrays""" """Mobject for creating images in Manim from numpy arrays"""
def __init__(self, numpy_image, height=2.3): def __init__(self, numpy_image, height=2.3):
super().__init__()
self.numpy_image = numpy_image self.numpy_image = numpy_image
assert len(np.shape(self.numpy_image)) == 2 assert len(np.shape(self.numpy_image)) == 2
input_image = self.numpy_image[None, :, :] input_image = self.numpy_image[None, :, :]
# Convert grayscale to rgb version of grayscale # Convert grayscale to rgb version of grayscale
input_image = np.repeat(input_image, 3, axis=0) input_image = np.repeat(input_image, 3, axis=0)
input_image = np.rollaxis(input_image, 0, start=3) input_image = np.rollaxis(input_image, 0, start=3)
self.image_mobject = ImageMobject(input_image, image_mode="RBG")
super().__init__(input_image, image_mode="RGB") self.add(self.image_mobject)
self.image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
self.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) self.image_mobject.scale_to_fit_height(height)
self.scale_to_fit_height(height)
@classmethod @classmethod
def from_path(cls, path, height=2.3): def from_path(cls, path, height=2.3):
@ -32,6 +32,20 @@ class GrayscaleImageMobject(ImageMobject):
def create(self, run_time=2): def create(self, run_time=2):
return FadeIn(self) return FadeIn(self)
def scale(self, scale_factor, **kwargs):
"""Scales the image mobject"""
# super().scale(scale_factor)
# height = self.height
self.image_mobject.scale(scale_factor)
# self.scale_to_fit_height(2)
# self.apply_points_function_about_point(
# lambda points: scale_factor * points, **kwargs
# )
def set_opacity(self, opacity):
"""Set the opacity"""
self.image_mobject.set_opacity(opacity)
class LabeledColorImage(Group): class LabeledColorImage(Group):
"""Labeled Color Image""" """Labeled Color Image"""

View File

@ -6,10 +6,12 @@ from manim import *
import random import random
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.feed_forward_to_feed_forward import FeedForwardToFeedForward from manim_ml.neural_network.layers.feed_forward_to_feed_forward import (
FeedForwardToFeedForward,
)
class XMark(VGroup): class XMark(VGroup):
def __init__(self, stroke_width=1.0, color=GRAY): def __init__(self, stroke_width=1.0, color=GRAY):
super().__init__() super().__init__()
line_one = Line( line_one = Line(
@ -17,7 +19,7 @@ class XMark(VGroup):
[0.1, -0.1, 0], [0.1, -0.1, 0],
stroke_width=1.0, stroke_width=1.0,
stroke_color=color, stroke_color=color,
z_index=4 z_index=4,
) )
self.add(line_one) self.add(line_one)
line_two = Line( line_two = Line(
@ -25,10 +27,11 @@ class XMark(VGroup):
[0.1, 0.1, 0], [0.1, 0.1, 0],
stroke_width=1.0, stroke_width=1.0,
stroke_color=color, stroke_color=color,
z_index=4 z_index=4,
) )
self.add(line_two) self.add(line_two)
def get_edges_to_drop_out(layer: FeedForwardToFeedForward, layers_to_nodes_to_drop_out): def get_edges_to_drop_out(layer: FeedForwardToFeedForward, layers_to_nodes_to_drop_out):
"""Returns edges to drop out for a given FeedForwardToFeedForward layer""" """Returns edges to drop out for a given FeedForwardToFeedForward layer"""
prev_layer = layer.input_layer prev_layer = layer.input_layer
@ -43,18 +46,21 @@ def get_edges_to_drop_out(layer: FeedForwardToFeedForward, layers_to_nodes_to_dr
prev_node_index = int(edge_index / next_layer.num_nodes) prev_node_index = int(edge_index / next_layer.num_nodes)
next_node_index = edge_index % next_layer.num_nodes next_node_index = edge_index % next_layer.num_nodes
# Check if the edges should be dropped out # Check if the edges should be dropped out
if prev_node_index in prev_layer_nodes_to_dropout \ if (
or next_node_index in next_layer_nodes_to_dropout: prev_node_index in prev_layer_nodes_to_dropout
or next_node_index in next_layer_nodes_to_dropout
):
edges_to_dropout.append(edge) edges_to_dropout.append(edge)
edge_indices_to_dropout.append(edge_index) edge_indices_to_dropout.append(edge_index)
return edges_to_dropout, edge_indices_to_dropout return edges_to_dropout, edge_indices_to_dropout
def make_pre_dropout_animation( def make_pre_dropout_animation(
neural_network, neural_network,
layers_to_nodes_to_drop_out, layers_to_nodes_to_drop_out,
dropped_out_color=GRAY, dropped_out_color=GRAY,
dropped_out_opacity=0.2 dropped_out_opacity=0.2,
): ):
"""Makes an animation that sets up the NN layer for dropout""" """Makes an animation that sets up the NN layer for dropout"""
animations = [] animations = []
@ -70,8 +76,7 @@ def make_pre_dropout_animation(
layers_to_edges_to_dropout = {} layers_to_edges_to_dropout = {}
for layer in feed_forward_to_feed_forward_layers: for layer in feed_forward_to_feed_forward_layers:
layers_to_edges_to_dropout[layer], _ = get_edges_to_drop_out( layers_to_edges_to_dropout[layer], _ = get_edges_to_drop_out(
layer, layer, layers_to_nodes_to_drop_out
layers_to_nodes_to_drop_out
) )
# Dim the colors of the edges # Dim the colors of the edges
dim_edge_colors_animations = [] dim_edge_colors_animations = []
@ -92,12 +97,9 @@ def make_pre_dropout_animation(
) )
""" """
dim_edge_colors_animations.append( dim_edge_colors_animations.append(FadeOut(edge))
FadeOut(edge)
)
dim_edge_colors_animation = AnimationGroup( dim_edge_colors_animation = AnimationGroup(
*dim_edge_colors_animations, *dim_edge_colors_animations, lag_ratio=0.0
lag_ratio=0.0
) )
# Dim the colors of the nodes # Dim the colors of the nodes
dim_nodes_animations = [] dim_nodes_animations = []
@ -113,10 +115,7 @@ def make_pre_dropout_animation(
create_x = Create(x_mark) create_x = Create(x_mark)
dim_nodes_animations.append(create_x) dim_nodes_animations.append(create_x)
dim_nodes_animation = AnimationGroup( dim_nodes_animation = AnimationGroup(*dim_nodes_animations, lag_ratio=0.0)
*dim_nodes_animations,
lag_ratio=0.0
)
animation_group = AnimationGroup( animation_group = AnimationGroup(
dim_edge_colors_animation, dim_edge_colors_animation,
@ -125,6 +124,7 @@ def make_pre_dropout_animation(
return animation_group, x_marks return animation_group, x_marks
def make_post_dropout_animation( def make_post_dropout_animation(
neural_network, neural_network,
layers_to_nodes_to_drop_out, layers_to_nodes_to_drop_out,
@ -143,8 +143,7 @@ def make_post_dropout_animation(
layers_to_edges_to_dropout = {} layers_to_edges_to_dropout = {}
for layer in feed_forward_to_feed_forward_layers: for layer in feed_forward_to_feed_forward_layers:
layers_to_edges_to_dropout[layer], _ = get_edges_to_drop_out( layers_to_edges_to_dropout[layer], _ = get_edges_to_drop_out(
layer, layer, layers_to_nodes_to_drop_out
layers_to_nodes_to_drop_out
) )
# Remove the x marks # Remove the x marks
uncreate_animations = [] uncreate_animations = []
@ -152,10 +151,7 @@ def make_post_dropout_animation(
uncreate_x_mark = Uncreate(x_mark) uncreate_x_mark = Uncreate(x_mark)
uncreate_animations.append(uncreate_x_mark) uncreate_animations.append(uncreate_x_mark)
uncreate_x_marks = AnimationGroup( uncreate_x_marks = AnimationGroup(*uncreate_animations, lag_ratio=0.0)
*uncreate_animations,
lag_ratio=0.0
)
# Add the edges back # Add the edges back
create_edge_animations = [] create_edge_animations = []
for layer in layers_to_edges_to_dropout.keys(): for layer in layers_to_edges_to_dropout.keys():
@ -164,20 +160,12 @@ def make_post_dropout_animation(
for edge_index, edge in enumerate(edges_to_drop_out): for edge_index, edge in enumerate(edges_to_drop_out):
edge_copy = edge.copy() edge_copy = edge.copy()
edges_to_drop_out[edge_index] = edge_copy edges_to_drop_out[edge_index] = edge_copy
create_edge_animations.append( create_edge_animations.append(FadeIn(edge_copy))
FadeIn(edge_copy)
)
create_edge_animation = AnimationGroup( create_edge_animation = AnimationGroup(*create_edge_animations, lag_ratio=0.0)
*create_edge_animations,
lag_ratio=0.0 return AnimationGroup(uncreate_x_marks, create_edge_animation, lag_ratio=0.0)
)
return AnimationGroup(
uncreate_x_marks,
create_edge_animation,
lag_ratio=0.0
)
def make_forward_pass_with_dropout_animation( def make_forward_pass_with_dropout_animation(
neural_network, neural_network,
@ -195,26 +183,16 @@ def make_forward_pass_with_dropout_animation(
) )
# Iterate through network and get feed forward layers # Iterate through network and get feed forward layers
for layer in feed_forward_layers: for layer in feed_forward_layers:
layer_args[layer] = { layer_args[layer] = {"dropout_node_indices": layers_to_nodes_to_drop_out[layer]}
"dropout_node_indices": layers_to_nodes_to_drop_out[layer]
}
for layer in feed_forward_to_feed_forward_layers: for layer in feed_forward_to_feed_forward_layers:
_, edge_indices = get_edges_to_drop_out( _, edge_indices = get_edges_to_drop_out(layer, layers_to_nodes_to_drop_out)
layer, layer_args[layer] = {"edge_indices_to_dropout": edge_indices}
layers_to_nodes_to_drop_out
) return neural_network.make_forward_pass_animation(layer_args=layer_args)
layer_args[layer] = {
"edge_indices_to_dropout": edge_indices
}
return neural_network.make_forward_pass_animation(
layer_args=layer_args
)
def make_neural_network_dropout_animation( def make_neural_network_dropout_animation(
neural_network, neural_network, dropout_rate=0.5, do_forward_pass=True
dropout_rate=0.5,
do_forward_pass=True
): ):
""" """
Makes a dropout animation for a given neural network. Makes a dropout animation for a given neural network.
@ -247,26 +225,19 @@ def make_neural_network_dropout_animation(
layers_to_nodes_to_drop_out[feed_forward_layer] = nodes_to_drop_out layers_to_nodes_to_drop_out[feed_forward_layer] = nodes_to_drop_out
# Make the animation # Make the animation
pre_dropout_animation, x_marks = make_pre_dropout_animation( pre_dropout_animation, x_marks = make_pre_dropout_animation(
neural_network, neural_network, layers_to_nodes_to_drop_out
layers_to_nodes_to_drop_out
) )
if do_forward_pass: if do_forward_pass:
forward_pass_animation = make_forward_pass_with_dropout_animation( forward_pass_animation = make_forward_pass_with_dropout_animation(
neural_network, neural_network, layers_to_nodes_to_drop_out
layers_to_nodes_to_drop_out
) )
else: else:
forward_pass_animation = AnimationGroup() forward_pass_animation = AnimationGroup()
post_dropout_animation = make_post_dropout_animation( post_dropout_animation = make_post_dropout_animation(
neural_network, neural_network, layers_to_nodes_to_drop_out, x_marks
layers_to_nodes_to_drop_out,
x_marks
) )
# Combine the animations into one # Combine the animations into one
return Succession( return Succession(
pre_dropout_animation, pre_dropout_animation, forward_pass_animation, post_dropout_animation
forward_pass_animation,
post_dropout_animation
) )

View File

@ -51,13 +51,6 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
about_point=self.get_center(), about_point=self.get_center(),
axis=ThreeDLayer.rotation_axis, axis=ThreeDLayer.rotation_axis,
) )
"""
self.rotate(
ThreeDLayer.three_d_y_rotation,
about_point=self.get_center(),
axis=[0, 1, 0]
)
"""
def construct_feature_maps(self): def construct_feature_maps(self):
"""Creates the neural network layer""" """Creates the neural network layer"""

View File

@ -260,6 +260,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
pulse_color=ORANGE, pulse_color=ORANGE,
cell_width=0.2, cell_width=0.2,
show_grid_lines=True, show_grid_lines=True,
highlight_color=ORANGE,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
@ -284,6 +285,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
self.line_color = line_color self.line_color = line_color
self.pulse_color = pulse_color self.pulse_color = pulse_color
self.show_grid_lines = show_grid_lines self.show_grid_lines = show_grid_lines
self.highlight_color = highlight_color
def get_rotated_shift_vectors(self): def get_rotated_shift_vectors(self):
""" """
@ -344,11 +346,11 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
animations.append(FadeOut(filters)) animations.append(FadeOut(filters))
return Succession(*animations, lag_ratio=1.0) return Succession(*animations, lag_ratio=1.0)
def animate_filters_one_at_a_time(self): def animate_filters_one_at_a_time(self, highlight_active_feature_map=False):
"""Animates each of the filters one at a time""" """Animates each of the filters one at a time"""
animations = [] animations = []
output_feature_maps = self.output_layer.feature_maps output_feature_maps = self.output_layer.feature_maps
for filter_index in range(len(output_feature_maps)): for feature_map_index in range(len(output_feature_maps)):
# Make filters # Make filters
filters = Filters( filters = Filters(
self.input_layer, self.input_layer,
@ -356,9 +358,28 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
line_color=self.color, line_color=self.color,
cell_width=self.cell_width, cell_width=self.cell_width,
show_grid_lines=self.show_grid_lines, show_grid_lines=self.show_grid_lines,
output_feature_map_to_connect=filter_index, # None means all at once output_feature_map_to_connect=feature_map_index, # None means all at once
) )
animations.append(Create(filters)) animations.append(Create(filters))
# Highlight the feature map
if highlight_active_feature_map:
feature_map = output_feature_maps[feature_map_index]
original_feature_map_color = feature_map.color
# Change the output feature map colors
change_color_animations = []
change_color_animations.append(
ApplyMethod(feature_map.set_color, self.highlight_color)
)
# Change the input feature map colors
input_feature_maps = self.input_layer.feature_maps
for input_feature_map in input_feature_maps:
change_color_animations.append(
ApplyMethod(input_feature_map.set_color, self.highlight_color)
)
# Combine the animations
animations.append(
AnimationGroup(*change_color_animations, lag_ratio=0.0)
)
# Get the rotated shift vectors # Get the rotated shift vectors
right_shift, down_shift = self.get_rotated_shift_vectors() right_shift, down_shift = self.get_rotated_shift_vectors()
left_shift = -1 * right_shift left_shift = -1 * right_shift
@ -403,11 +424,36 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
animations.append(shift_animation) animations.append(shift_animation)
# Remove the filters # Remove the filters
animations.append(FadeOut(filters)) animations.append(FadeOut(filters))
# Un-highlight the feature map
if highlight_active_feature_map:
feature_map = output_feature_maps[feature_map_index]
# Change the output feature map colors
change_color_animations = []
change_color_animations.append(
ApplyMethod(feature_map.set_color, original_feature_map_color)
)
# Change the input feature map colors
input_feature_maps = self.input_layer.feature_maps
for input_feature_map in input_feature_maps:
change_color_animations.append(
ApplyMethod(
input_feature_map.set_color, original_feature_map_color
)
)
# Combine the animations
animations.append(
AnimationGroup(*change_color_animations, lag_ratio=0.0)
)
return Succession(*animations, lag_ratio=1.0) return Succession(*animations, lag_ratio=1.0)
def make_forward_pass_animation( def make_forward_pass_animation(
self, layer_args={}, all_filters_at_once=False, run_time=10.5, **kwargs self,
layer_args={},
all_filters_at_once=False,
highlight_active_feature_map=False,
run_time=10.5,
**kwargs,
): ):
"""Forward pass animation from conv2d to conv2d""" """Forward pass animation from conv2d to conv2d"""
print(f"All filters at once: {all_filters_at_once}") print(f"All filters at once: {all_filters_at_once}")
@ -415,7 +461,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
if all_filters_at_once: if all_filters_at_once:
return self.animate_filters_all_at_once() return self.animate_filters_all_at_once()
else: else:
return self.animate_filters_one_at_a_time() return self.animate_filters_one_at_a_time(
highlight_active_feature_map=highlight_active_feature_map
)
def scale(self, scale_factor, **kwargs): def scale(self, scale_factor, **kwargs):
self.cell_width *= scale_factor self.cell_width *= scale_factor

View File

@ -1,6 +1,7 @@
from manim import * from manim import *
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
class FeedForwardLayer(VGroupNeuralNetworkLayer): class FeedForwardLayer(VGroupNeuralNetworkLayer):
"""Handles rendering a layer for a neural network""" """Handles rendering a layer for a neural network"""
@ -78,16 +79,10 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
# Make highlight animation # Make highlight animation
succession = Succession( succession = Succession(
ApplyMethod( ApplyMethod(
nodes_to_highlight.set_color, nodes_to_highlight.set_color, self.animation_dot_color, run_time=0.25
self.animation_dot_color,
run_time=0.25
), ),
Wait(1.0), Wait(1.0),
ApplyMethod( ApplyMethod(nodes_to_highlight.set_color, self.node_color, run_time=0.25),
nodes_to_highlight.set_color,
self.node_color,
run_time=0.25
),
) )
return succession return succession
@ -97,8 +92,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
if "dropout_node_indices" in layer_args: if "dropout_node_indices" in layer_args:
# Drop out certain nodes # Drop out certain nodes
return self.make_dropout_forward_pass_animation( return self.make_dropout_forward_pass_animation(
layer_args=layer_args, layer_args=layer_args, **kwargs
**kwargs
) )
else: else:
# Make highlight animation # Make highlight animation

View File

@ -68,19 +68,20 @@ class FeedForwardToFeedForward(ConnectiveLayer):
return animation_group return animation_group
def make_forward_pass_animation( def make_forward_pass_animation(
self, self, layer_args={}, run_time=1, feed_forward_dropout=0.0, **kwargs
layer_args={},
run_time=1,
feed_forward_dropout=0.0,
**kwargs
): ):
"""Animation for passing information from one FeedForwardLayer to the next""" """Animation for passing information from one FeedForwardLayer to the next"""
path_animations = [] path_animations = []
dots = [] dots = []
for edge_index, edge in enumerate(self.edges): for edge_index, edge in enumerate(self.edges):
if not edge_index in layer_args["edge_indices_to_dropout"]: if (
not "edge_indices_to_dropout" in layer_args
or not edge_index in layer_args["edge_indices_to_dropout"]
):
dot = Dot( dot = Dot(
color=self.animation_dot_color, fill_opacity=1.0, radius=self.dot_radius color=self.animation_dot_color,
fill_opacity=1.0,
radius=self.dot_radius,
) )
# Add to dots group # Add to dots group
dots.append(dot) dots.append(dot)

View File

@ -56,6 +56,10 @@ class ImageLayer(NeuralNetworkLayer):
"""Override get right""" """Override get right"""
return self.image_mobject.get_right() return self.image_mobject.get_right()
def scale(self, scale_factor, **kwargs):
"""Scales the image mobject"""
self.image_mobject.scale(scale_factor)
@property @property
def width(self): def width(self):
return self.image_mobject.width return self.image_mobject.width

View File

@ -27,10 +27,10 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Maps image to convolutional layer""" """Maps image to convolutional layer"""
# Transform the image from the input layer to the # Transform the image from the input layer to the
num_image_channels = self.input_layer.num_channels num_image_channels = self.input_layer.num_channels
if num_image_channels == 3: if num_image_channels == 1 or num_image_channels == 3: # TODO fix this later
return self.rbg_image_animation()
elif num_image_channels == 1:
return self.grayscale_image_animation() return self.grayscale_image_animation()
elif num_image_channels == 3:
return self.rbg_image_animation()
else: else:
raise Exception( raise Exception(
f"Unrecognized number of image channels: {num_image_channels}" f"Unrecognized number of image channels: {num_image_channels}"
@ -43,7 +43,6 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
# TODO create image mobjects for each channel and transform # TODO create image mobjects for each channel and transform
# it to the feature maps of the output_layer # it to the feature maps of the output_layer
raise NotImplementedError() raise NotImplementedError()
pass
def grayscale_image_animation(self): def grayscale_image_animation(self):
"""Handles animation for 1 channel image""" """Handles animation for 1 channel image"""
@ -80,7 +79,7 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
# Scale the max of width or height to the # Scale the max of width or height to the
# width of the feature_map # width of the feature_map
max_width_height = max(image_mobject.width, image_mobject.height) max_width_height = max(image_mobject.width, image_mobject.height)
scale_factor = target_feature_map.rectangle_width / max_width_height scale_factor = target_feature_map.width / max_width_height
scale_image = ApplyMethod(image_mobject.scale, scale_factor, run_time=0.5) scale_image = ApplyMethod(image_mobject.scale, scale_factor, run_time=0.5)
# Move the image # Move the image
move_image = ApplyMethod(image_mobject.move_to, target_feature_map) move_image = ApplyMethod(image_mobject.move_to, target_feature_map)

View File

@ -69,10 +69,13 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
return super()._create_override() return super()._create_override()
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}(" + \ return (
f"input_layer={self.input_layer.__class__.__name__}," + \ f"{self.__class__.__name__}("
f"output_layer={self.output_layer.__class__.__name__}," + \ + f"input_layer={self.input_layer.__class__.__name__},"
")" + f"output_layer={self.output_layer.__class__.__name__},"
+ ")"
)
class BlankConnective(ConnectiveLayer): class BlankConnective(ConnectiveLayer):
"""Connective layer to be used when the given pair of layers is undefined""" """Connective layer to be used when the given pair of layers is undefined"""

View File

@ -22,6 +22,7 @@ from manim_ml.neural_network.neural_network_transformations import (
RemoveLayer, RemoveLayer,
) )
class NeuralNetwork(Group): class NeuralNetwork(Group):
"""Neural Network Visualization Container Class""" """Neural Network Visualization Container Class"""
@ -34,8 +35,8 @@ class NeuralNetwork(Group):
edge_width=2.5, edge_width=2.5,
dot_radius=0.03, dot_radius=0.03,
title=" ", title=" ",
three_d_phi=-70 * DEGREES, layout="linear",
three_d_theta=-80 * DEGREES, layout_direction="left_to_right",
): ):
super(Group, self).__init__() super(Group, self).__init__()
self.input_layers = ListGroup(*input_layers) self.input_layers = ListGroup(*input_layers)
@ -46,13 +47,12 @@ class NeuralNetwork(Group):
self.dot_radius = dot_radius self.dot_radius = dot_radius
self.title_text = title self.title_text = title
self.created = False self.created = False
# Make the layer fixed in frame if its not 3D self.layout = layout
ThreeDLayer.three_d_theta = three_d_theta self.layout_direction = layout_direction
ThreeDLayer.three_d_phi = three_d_phi
# TODO take layer_node_count [0, (1, 2), 0] # TODO take layer_node_count [0, (1, 2), 0]
# and make it have explicit distinct subspaces # and make it have explicit distinct subspaces
# Place the layers # Place the layers
self._place_layers() self._place_layers(layout=layout, layout_direction=layout_direction)
self.connective_layers, self.all_layers = self._construct_connective_layers() self.connective_layers, self.all_layers = self._construct_connective_layers()
# Make overhead title # Make overhead title
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2) self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
@ -67,7 +67,7 @@ class NeuralNetwork(Group):
# Print neural network # Print neural network
print(repr(self)) print(repr(self))
def _place_layers(self): def _place_layers(self, layout="linear", layout_direction="top_to_bottom"):
"""Creates the neural network""" """Creates the neural network"""
# TODO implement more sophisticated custom layouts # TODO implement more sophisticated custom layouts
# Default: Linear layout # Default: Linear layout
@ -79,6 +79,7 @@ class NeuralNetwork(Group):
if isinstance(current_layer, EmbeddingLayer) or isinstance( if isinstance(current_layer, EmbeddingLayer) or isinstance(
previous_layer, EmbeddingLayer previous_layer, EmbeddingLayer
): ):
if layout_direction == "left_to_right":
shift_vector = np.array( shift_vector = np.array(
[ [
( (
@ -90,15 +91,53 @@ class NeuralNetwork(Group):
0, 0,
] ]
) )
else: elif layout_direction == "top_to_bottom":
shift_vector = np.array( shift_vector = np.array(
[ [
(previous_layer.get_width() / 2 + current_layer.get_width() / 2) 0,
-(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
else:
if layout_direction == "left_to_right":
shift_vector = np.array(
[
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing, + self.layer_spacing,
0, 0,
0, 0,
] ]
) )
elif layout_direction == "top_to_bottom":
shift_vector = np.array(
[
0,
-(
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
current_layer.shift(shift_vector) current_layer.shift(shift_vector)
def _construct_connective_layers(self): def _construct_connective_layers(self):
@ -178,7 +217,7 @@ class NeuralNetwork(Group):
current_layer_args = { current_layer_args = {
**before_layer_args, **before_layer_args,
**current_layer_args, **current_layer_args,
**after_layer_args **after_layer_args,
} }
else: else:
current_layer_args = {} current_layer_args = {}
@ -229,14 +268,18 @@ class NeuralNetwork(Group):
"""Overriden scale""" """Overriden scale"""
for layer in self.all_layers: for layer in self.all_layers:
layer.scale(scale_factor, **kwargs) layer.scale(scale_factor, **kwargs)
# super().scale(scale_factor) # Place layers with scaled spacing
self.layer_spacing *= scale_factor
self._place_layers(layout=self.layout, layout_direction=self.layout_direction)
def filter_layers(self, function): def filter_layers(self, function):
"""Filters layers of the network given function""" """Filters layers of the network given function"""
layers_to_return = [] layers_to_return = []
for layer in self.all_layers: for layer in self.all_layers:
func_out = function(layer) func_out = function(layer)
assert isinstance(func_out, bool), "Filter layers function returned a non-boolean type." assert isinstance(
func_out, bool
), "Filter layers function returned a non-boolean type."
if func_out: if func_out:
layers_to_return.append(layer) layers_to_return.append(layer)
@ -257,6 +300,7 @@ class NeuralNetwork(Group):
string_repr = "NeuralNetwork([\n" + inner_string + "])" string_repr = "NeuralNetwork([\n" + inner_string + "])"
return string_repr return string_repr
class FeedForwardNeuralNetwork(NeuralNetwork): class FeedForwardNeuralNetwork(NeuralNetwork):
"""NeuralNetwork with just feed forward layers""" """NeuralNetwork with just feed forward layers"""

View File

@ -4,6 +4,7 @@
from manim import * from manim import *
from manim_ml.neural_network.layers.util import get_connective_layer from manim_ml.neural_network.layers.util import get_connective_layer
class RemoveLayer(AnimationGroup): class RemoveLayer(AnimationGroup):
""" """
Animation for removing a layer from a neural network. Animation for removing a layer from a neural network.

View File

@ -1,27 +1,31 @@
from manim import * from manim import *
from manim_ml.decision_tree.decision_tree import DecisionTreeDiagram, DecisionTreeSurface, IrisDatasetPlot from manim_ml.decision_tree.decision_tree import (
DecisionTreeDiagram,
DecisionTreeSurface,
IrisDatasetPlot,
)
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets from sklearn import datasets
import sklearn import sklearn
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def learn_iris_decision_tree(iris): def learn_iris_decision_tree(iris):
decision_tree = DecisionTreeClassifier( decision_tree = DecisionTreeClassifier(
random_state=1, random_state=1, max_depth=3, max_leaf_nodes=6
max_depth=3,
max_leaf_nodes=6
) )
decision_tree = decision_tree.fit(iris.data[:, :2], iris.target) decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
# output the decisioin tree in some format # output the decisioin tree in some format
return decision_tree return decision_tree
def make_sklearn_tree(dataset, max_tree_depth=3): def make_sklearn_tree(dataset, max_tree_depth=3):
tree = learn_iris_decision_tree(dataset) tree = learn_iris_decision_tree(dataset)
feature_names = dataset.feature_names[0:2] feature_names = dataset.feature_names[0:2]
return tree, tree.tree_ return tree, tree.tree_
class DecisionTreeScene(Scene):
class DecisionTreeScene(Scene):
def construct(self): def construct(self):
"""Makes a decision tree object""" """Makes a decision tree object"""
iris_dataset = datasets.load_iris() iris_dataset = datasets.load_iris()
@ -36,15 +40,8 @@ class DecisionTreeScene(Scene):
"images/iris_dataset/VeriscolorFlower.jpeg", "images/iris_dataset/VeriscolorFlower.jpeg",
"images/iris_dataset/VirginicaFlower.jpeg", "images/iris_dataset/VirginicaFlower.jpeg",
], ],
class_names=[ class_names=["Setosa", "Veriscolor", "Virginica"],
"Setosa", feature_names=["Sepal Length", "Sepal Width"],
"Veriscolor",
"Virginica"
],
feature_names=[
"Sepal Length",
"Sepal Width"
]
) )
decision_tree.move_to(ORIGIN) decision_tree.move_to(ORIGIN)
create_decision_tree = Create(decision_tree, traversal_order="bfs") create_decision_tree = Create(decision_tree, traversal_order="bfs")
@ -53,7 +50,6 @@ class DecisionTreeScene(Scene):
class SurfacePlot(Scene): class SurfacePlot(Scene):
def construct(self): def construct(self):
iris_dataset = datasets.load_iris() iris_dataset = datasets.load_iris()
iris_dataset_plot = IrisDatasetPlot(iris_dataset) iris_dataset_plot = IrisDatasetPlot(iris_dataset)
@ -63,13 +59,7 @@ class SurfacePlot(Scene):
# make the decision tree classifier # make the decision tree classifier
decision_tree_classifier, sklearn_tree = make_sklearn_tree(iris_dataset) decision_tree_classifier, sklearn_tree = make_sklearn_tree(iris_dataset)
decision_tree_surface = DecisionTreeSurface( decision_tree_surface = DecisionTreeSurface(
decision_tree_classifier, decision_tree_classifier, iris_dataset.data, iris_dataset_plot.axes_group[0]
iris_dataset.data,
iris_dataset_plot.axes_group[0]
)
self.play(
Create(
decision_tree_surface
)
) )
self.play(Create(decision_tree_surface))
self.wait(1) self.wait(1)

View File

@ -1,5 +1,7 @@
from manim import * from manim import *
from manim_ml.neural_network.animations.dropout import make_neural_network_dropout_animation from manim_ml.neural_network.animations.dropout import (
make_neural_network_dropout_animation,
)
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer from manim_ml.neural_network.layers.image import ImageLayer
from PIL import Image from PIL import Image
@ -11,6 +13,7 @@ config.pixel_width = 1900
config.frame_height = 5.0 config.frame_height = 5.0
config.frame_width = 5.0 config.frame_width = 5.0
def make_code_snippet(): def make_code_snippet():
code_str = """ code_str = """
nn = NeuralNetwork([ nn = NeuralNetwork([
@ -40,17 +43,19 @@ def make_code_snippet():
return code return code
class DropoutNeuralNetworkScene(Scene): class DropoutNeuralNetworkScene(Scene):
def construct(self): def construct(self):
# Make nn # Make nn
nn = NeuralNetwork([ nn = NeuralNetwork(
[
FeedForwardLayer(3, rectangle_color=BLUE), FeedForwardLayer(3, rectangle_color=BLUE),
FeedForwardLayer(5, rectangle_color=BLUE), FeedForwardLayer(5, rectangle_color=BLUE),
FeedForwardLayer(3, rectangle_color=BLUE), FeedForwardLayer(3, rectangle_color=BLUE),
FeedForwardLayer(5, rectangle_color=BLUE), FeedForwardLayer(5, rectangle_color=BLUE),
FeedForwardLayer(4, rectangle_color=BLUE), FeedForwardLayer(4, rectangle_color=BLUE),
], ],
layer_spacing=0.4 layer_spacing=0.4,
) )
# Center the nn # Center the nn
nn.move_to(ORIGIN) nn.move_to(ORIGIN)
@ -63,13 +68,12 @@ class DropoutNeuralNetworkScene(Scene):
# Play animation # Play animation
self.play( self.play(
make_neural_network_dropout_animation( make_neural_network_dropout_animation(
nn, nn, dropout_rate=0.25, do_forward_pass=True
dropout_rate=0.25,
do_forward_pass=True
) )
) )
self.wait(1) self.wait(1)
if __name__ == "__main__": if __name__ == "__main__":
"""Render all scenes""" """Render all scenes"""
dropout_nn_scene = DropoutNeuralNetworkScene() dropout_nn_scene = DropoutNeuralNetworkScene()

44
tests/test_nn_scale.py Normal file
View File

@ -0,0 +1,44 @@
from manim import *
from PIL import Image
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 6.0
config.frame_width = 6.0
class CombinedScene(ThreeDScene):
def construct(self):
image = Image.open("../assets/mnist/digit.jpeg")
numpy_image = np.asarray(image)
# Make nn
nn = NeuralNetwork(
[
ImageLayer(numpy_image, height=1.5),
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
FeedForwardLayer(3),
],
layer_spacing=0.25,
)
# Center the nn
nn.move_to(ORIGIN)
nn.scale(1.3)
self.add(nn)
"""
self.play(
FadeIn(nn)
)
"""
# Play animation
forward_pass = nn.make_forward_pass_animation(
corner_pulses=False, all_filters_at_once=False, highlight_filters=True
)
self.wait(1)
self.play(forward_pass)