Mirror of https://github.com/helblazer811/ManimML.git (synced 2025-06-17 19:49:18 +08:00)
Reformatted the code using black, allowed for different orientation NNs, and added an option for highlighting the active filter in a CNN forward pass.
@@ -0,0 +1,45 @@
from manim import *
from PIL import Image

from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.layers.parent_layers import ThreeDLayer
from manim_ml.neural_network.neural_network import NeuralNetwork

# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 800
config.frame_height = 6.0
config.frame_width = 6.0


class CombinedScene(ThreeDScene):
    def construct(self):
        image = Image.open("../../assets/doggo.jpeg")
        numpy_image = np.asarray(image)
        # Rotate the Three D layer position
        ThreeDLayer.rotation_angle = 15 * DEGREES
        ThreeDLayer.rotation_axis = [1, -1.0, 0]
        # Make nn
        nn = NeuralNetwork(
            [
                ImageLayer(numpy_image, height=1.5),
                Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
                Convolutional3DLayer(3, 5, 5, 1, 1, filter_spacing=0.18),
            ],
            layer_spacing=0.25,
            layout_direction="top_to_bottom",
        )
        # Center the nn
        nn.move_to(ORIGIN)
        nn.scale(1.5)
        self.add(nn)
        # Play animation
        forward_pass = nn.make_forward_pass_animation(
            corner_pulses=False,
            all_filters_at_once=False,
            highlight_active_feature_map=True,
        )
        self.wait(1)
        self.play(forward_pass)
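As a point of comparison, a minimal sketch (not part of the commit) of how the same network could presumably be built with the default left-to-right orientation and without the new highlight option; it reuses the imports and numpy_image from the example above, and the flag values here are illustrative assumptions:

# Sketch only: default orientation, no active-feature-map highlighting.
nn = NeuralNetwork(
    [
        ImageLayer(numpy_image, height=1.5),
        Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
        Convolutional3DLayer(3, 5, 5, 1, 1, filter_spacing=0.18),
    ],
    layer_spacing=0.25,
    layout_direction="left_to_right",
)
forward_pass = nn.make_forward_pass_animation(
    corner_pulses=False,
    all_filters_at_once=True,  # animate every filter simultaneously
    highlight_active_feature_map=False,
)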
@@ -3,8 +3,10 @@ import numpy as np
from collections import deque
from sklearn.tree import _tree as ctree


class AABB:
    """Axis-aligned bounding box"""

    def __init__(self, n_features):
        self.limits = np.array([[-np.inf, np.inf]] * n_features)

@@ -18,6 +20,7 @@ class AABB:

        return left, right


def tree_bounds(tree, n_features=None):
    """Compute final decision rule for each node in tree"""
    if n_features is None:

@@ -33,6 +36,7 @@ def tree_bounds(tree, n_features=None):
        queue.extend([l, r])
    return aabbs


def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
    """Extract decision areas.

@@ -69,21 +73,27 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None)
    rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
    return rectangles


def plot_areas(rectangles):
    for rect in rectangles:
        color = ["b", "r"][int(rect[4])]
        print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
        rp = Rectangle(
            [rect[0], rect[2]],
            rect[1] - rect[0],
            rect[3] - rect[2],
            color=color,
            alpha=0.3,
        )
        plt.gca().add_artist(rp)


def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
    # get all polygons of each color
    polygon_dict = {
        str(BLUE).lower(): [],
        str(GREEN).lower(): [],
        str(ORANGE).lower(): [],
    }
    for polygon in all_polygons:
        print(polygon_dict)

@@ -143,8 +153,10 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
        # Remove implementation-markers from the polygon.
        poly = [point for point, _ in polygon]
        for vertex in poly:
            if vertex in edges_h:
                edges_h.pop(vertex)
            if vertex in edges_v:
                edges_v.pop(vertex)
        polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
        return_polygons.append(polygon)
    return return_polygons
@@ -6,18 +6,23 @@
TODO reimplement the decision 2D decision tree surface drawing.
"""
from manim import *
from manim_ml.decision_tree.classification_areas import (
    compute_decision_areas,
    merge_overlapping_polygons,
)
import manim_ml.decision_tree.helpers as helpers
from manim_ml.one_to_one_sync import OneToOneSync

import numpy as np
from PIL import Image


class LeafNode(Group):
    """Leaf node in tree"""

    def __init__(
        self, class_index, display_type="image", class_image_paths=[], class_colors=[]
    ):
        super().__init__()
        self.display_type = display_type
        self.class_image_paths = class_image_paths

@@ -39,13 +44,14 @@ class LeafNode(Group):
            width=node.width + 0.05,
            height=node.height + 0.05,
            color=self.class_colors[class_index],
            stroke_width=6,
        )
        rectangle.move_to(node.get_center())
        rectangle.shift([-0.02, 0.02, 0])
        self.add(rectangle)
        self.add(node)


class SplitNode(VGroup):
    """Node for splitting decision in tree"""

@@ -53,25 +59,24 @@ class SplitNode(VGroup):
        super().__init__()
        node_text = f"{feature}\n<= {threshold:.2f} cm"
        # Draw decision text
        decision_text = Text(node_text, color=WHITE)
        # Draw the surrounding box
        bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
        self.add(bounding_box)
        self.add(decision_text)


class DecisionTreeDiagram(Group):
    """Decision Tree Diagram Class for Manim"""

    def __init__(
        self,
        sklearn_tree,
        feature_names=None,
        class_names=None,
        class_images_paths=None,
        class_colors=[RED, GREEN, BLUE],
    ):
        super().__init__()
        self.tree = sklearn_tree
        self.feature_names = feature_names

@@ -87,14 +92,13 @@ class DecisionTreeDiagram(Group):
        node_index,
    ):
        """Make node"""
        is_split_node = (
            self.tree.children_left[node_index] != self.tree.children_right[node_index]
        )
        if is_split_node:
            node_feature = self.tree.feature[node_index]
            node_threshold = self.tree.threshold[node_index]
            node = SplitNode(self.feature_names[node_feature], node_threshold)
        else:
            # Get the most abundant class for the given leaf node
            # Make the leaf node object

@@ -102,7 +106,7 @@ class DecisionTreeDiagram(Group):
            node = LeafNode(
                class_index=tree_class_index,
                class_colors=self.class_colors,
                class_image_paths=self.class_image_paths,
            )
        return node

@@ -113,11 +117,7 @@ class DecisionTreeDiagram(Group):
        bottom_node_top_location = bottom.get_center()
        bottom_node_top_location[1] += bottom.height / 2

        line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)

        return line

@@ -143,31 +143,49 @@ class DecisionTreeDiagram(Group):
        # traverse tree
        def recurse(node_index, depth, direction, parent_object, parent_node):
            # make the node object
            is_leaf = (
                self.tree.children_left[node_index]
                == self.tree.children_right[node_index]
            )
            node_object = self._make_node(node_index=node_index)
            nodes_map[node_index] = node_object
            node_height = node_object.height
            # set the node position
            direction_factor = -1 if direction == "left" else 1
            shift_right_amount = (
                0.9 * direction_factor * scale_factor * tree_width / (2**depth) / 2
            )
            if is_leaf:
                shift_down_amount = -1.0 * scale_factor * node_height
            else:
                shift_down_amount = -1.8 * scale_factor * node_height
            node_object.match_x(parent_object).match_y(parent_object).shift(
                [shift_right_amount, shift_down_amount, 0]
            )
            tree_group.add(node_object)
            # make a connection
            connection = self._make_connection(
                parent_object, node_object, is_leaf=is_leaf
            )
            edge_name = str(parent_node) + "," + str(node_index)
            edge_map[edge_name] = connection
            tree_group.add(connection)
            # recurse
            if not is_leaf:
                recurse(
                    self.tree.children_left[node_index],
                    depth + 1,
                    "left",
                    node_object,
                    node_index,
                )
                recurse(
                    self.tree.children_right[node_index],
                    depth + 1,
                    "right",
                    node_object,
                    node_index,
                )

        recurse(self.tree.children_left[0], 1, "left", root_node, 0)
        recurse(self.tree.children_right[0], 1, "right", root_node, 0)

@@ -185,9 +203,7 @@ class DecisionTreeDiagram(Group):
        # Compute parent mapping
        parent_mapping = helpers.compute_node_to_parent_mapping(self.tree)
        # Create the root node
        animations.append(Create(self.nodes_map[0]))
        # Iterate through the nodes
        queue = [0]
        while len(queue) > 0:

@@ -211,21 +227,16 @@ class DecisionTreeDiagram(Group):
                FadeIn(right_node),
                Create(left_parent_edge),
                Create(right_parent_edge),
                lag_ratio=0.0,
            )
            animations.append(split_animation)
            # Add the children to the queue
            if left_child != -1:
                queue.append(left_child)
            if right_child != -1:
                queue.append(right_child)

        return AnimationGroup(*animations, lag_ratio=1.0)

    @override_animation(Create)
    def create_decision_tree(self, traversal_order="bfs"):

@@ -237,8 +248,8 @@ class DecisionTreeDiagram(Group):
        else:
            raise Exception(f"Uncrecognized traversal: {traversal_order}")


class IrisDatasetPlot(VGroup):
    def __init__(self, iris):
        points = iris.data[:, 0:2]
        labels = iris.feature_names

@@ -249,19 +260,13 @@ class IrisDatasetPlot(VGroup):
        self.axes_group = self._make_axes_group(points, labels)
        # Make legend
        self.legend_group = self._make_legend(
            [BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
        )
        # Make title
        # title_text = "Iris Dataset Plot"
        # self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
        # Make all group
        self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
        # scale the groups
        self.point_group.scale(1.6)
        self.point_group.match_x(self.axes_group)

@@ -279,7 +284,7 @@ class IrisDatasetPlot(VGroup):
            Create(self.axes_group, run_time=2),
            # add title
            # Create(self.title),
            Create(self.legend_group),
        )
        return animation_group

@@ -300,7 +305,9 @@ class IrisDatasetPlot(VGroup):
        setosa = Text("Setosa", color=BLUE)
        verisicolor = Text("Verisicolor", color=ORANGE)
        virginica = Text("Virginica", color=GREEN)
        labels = VGroup(setosa, verisicolor, virginica).arrange(
            direction=RIGHT, aligned_edge=LEFT, buff=2.0
        )
        labels.scale(0.5)
        legend_group.add(labels)
        # surrounding rectangle

@@ -314,10 +321,14 @@ class IrisDatasetPlot(VGroup):

        return legend_group

    def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
        axes_group = VGroup()
        # make the axes
        x_range = [
            np.amin(points, axis=0)[0] - 0.2,
            np.amax(points, axis=0)[0] - 0.2,
            0.5,
        ]
        y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
        axes = Axes(
            x_range=x_range,

@@ -330,23 +341,27 @@ class IrisDatasetPlot(VGroup):
        axes_group.add(axes)
        # make axis labels
        # x_label
        x_label = (
            Text(labels[0], font=font)
            .match_y(axes.get_axes()[0])
            .shift([0.5, -0.75, 0])
            .scale(font_scale)
        )
        axes_group.add(x_label)
        # y_label
        y_label = (
            Text(labels[1], font=font)
            .match_x(axes.get_axes()[1])
            .shift([-0.75, 0, 0])
            .rotate(np.pi / 2)
            .scale(font_scale)
        )
        axes_group.add(y_label)

        return axes_group


class DecisionTreeSurface(VGroup):
    def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
        # take the tree and construct the surface from it
        self.tree_clf = tree_clf

@@ -363,11 +378,7 @@ class DecisionTreeSurface(VGroup):
        bottom = np.amin(self.data[:, 1]) - 0.2
        maxrange = [left, right, bottom, top]
        rectangles = compute_decision_areas(
            self.tree_clf, maxrange, x=0, y=1, n_features=2
        )
        # turn the rectangle objects into manim rectangles
        def convert_rectangle_to_polygon(rect):

@@ -381,9 +392,16 @@ class DecisionTreeSurface(VGroup):
            bottom_right_coord = self.axes.coords_to_point(*bottom_right)
            top_right_coord = self.axes.coords_to_point(*top_right)
            top_left_coord = self.axes.coords_to_point(*top_left)
            points = [
                bottom_left_coord,
                bottom_right_coord,
                top_right_coord,
                top_left_coord,
            ]
            # construct a polygon object from those manim coordinates
            rectangle = Polygon(
                *points, color=color, fill_opacity=0.3, stroke_opacity=0.0
            )
            return rectangle

        manim_rectangles = []

@@ -392,7 +410,9 @@ class DecisionTreeSurface(VGroup):
            rectangle = convert_rectangle_to_polygon(rect)
            manim_rectangles.append(rectangle)

        manim_rectangles = merge_overlapping_polygons(
            manim_rectangles, colors=[BLUE, GREEN, ORANGE]
        )

        return manim_rectangles

@@ -416,6 +436,7 @@ class DecisionTreeSurface(VGroup):

        return animation_group


class DecisionTreeContainer(OneToOneSync):
    """Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""

@@ -1,11 +1,14 @@
def compute_node_depths(tree):
    """Computes the depths of nodes for level order traversal"""

    def depth(node_index, current_node_index=0):
        """Compute the height of a node"""
        if current_node_index == node_index:
            return 0
        elif (
            tree.children_left[current_node_index]
            == tree.children_right[current_node_index]
        ):
            return -1
        else:
            # Compute the height of each subtree

@@ -23,13 +26,18 @@ def compute_node_depths(tree):

    return node_depths


def compute_level_order_traversal(tree):
    """Computes level order traversal of a sklearn tree"""

    def depth(node_index, current_node_index=0):
        """Compute the height of a node"""
        if current_node_index == node_index:
            return 0
        elif (
            tree.children_left[current_node_index]
            == tree.children_right[current_node_index]
        ):
            return -1
        else:
            # Compute the height of each subtree

@@ -49,6 +57,7 @@ def compute_level_order_traversal(tree):

    return sorted_inds


def compute_node_to_parent_mapping(tree):
    """Returns a hashmap mapping node indices to their parent indices"""
    node_to_parent = {0: -1}  # Root has no parent
@@ -24,6 +24,7 @@ class GriddedRectangle(VGroup):
    ):
        super().__init__()
        # Fields
+        self.color = color
        self.mark_paths_closed = mark_paths_closed
        self.close_new_points = close_new_points
        self.grid_xstep = grid_xstep

@@ -33,8 +34,6 @@ class GriddedRectangle(VGroup):
        self.grid_stroke_opacity = grid_stroke_opacity if show_grid_lines else 0.0
        self.stroke_width = stroke_width
        self.rotation_angles = [0, 0, 0]
-        self.rectangle_width = width
-        self.rectangle_height = height
        self.show_grid_lines = show_grid_lines
        # Make rectangle
        self.rectangle = Rectangle(

@@ -129,3 +128,9 @@ class GriddedRectangle(VGroup):
        normal_vector = np.cross((vertex_1 - vertex_2), (vertex_1 - vertex_3))

        return normal_vector
+
+    def set_color(self, color):
+        """Sets the color of the gridded rectangle"""
+        self.color = color
+        self.rectangle.set_color(color)
+        self.rectangle.set_stroke_color(color)
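A small usage sketch for the new set_color helper; the constructor argument names here are assumptions based on the fields visible in the hunks above, not documented API:

# Illustrative only; argument names are assumed from the fields above.
rect = GriddedRectangle(color=BLUE, width=2.0, height=1.0, show_grid_lines=True)
rect.set_color(ORANGE)  # updates self.color and recolors the inner Rectangle's fill and stroke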
@@ -3,22 +3,22 @@ import numpy as np
from PIL import Image


-class GrayscaleImageMobject(ImageMobject):
+class GrayscaleImageMobject(Group):
    """Mobject for creating images in Manim from numpy arrays"""

    def __init__(self, numpy_image, height=2.3):
+        super().__init__()
        self.numpy_image = numpy_image

        assert len(np.shape(self.numpy_image)) == 2

        input_image = self.numpy_image[None, :, :]
        # Convert grayscale to rgb version of grayscale
        input_image = np.repeat(input_image, 3, axis=0)
        input_image = np.rollaxis(input_image, 0, start=3)
-        super().__init__(input_image, image_mode="RGB")
-        self.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        self.scale_to_fit_height(height)
+        self.image_mobject = ImageMobject(input_image, image_mode="RBG")
+        self.add(self.image_mobject)
+        self.image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+        self.image_mobject.scale_to_fit_height(height)

    @classmethod
    def from_path(cls, path, height=2.3):

@@ -32,6 +32,20 @@ class GrayscaleImageMobject(ImageMobject):
    def create(self, run_time=2):
        return FadeIn(self)

+    def scale(self, scale_factor, **kwargs):
+        """Scales the image mobject"""
+        # super().scale(scale_factor)
+        # height = self.height
+        self.image_mobject.scale(scale_factor)
+        # self.scale_to_fit_height(2)
+        # self.apply_points_function_about_point(
+        #     lambda points: scale_factor * points, **kwargs
+        # )
+
+    def set_opacity(self, opacity):
+        """Set the opacity"""
+        self.image_mobject.set_opacity(opacity)
+

class LabeledColorImage(Group):
    """Labeled Color Image"""
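A usage sketch for the reworked GrayscaleImageMobject (a sketch, not repository code): the class is now a Group wrapping an ImageMobject, so scale() and set_opacity() delegate to the inner image. The array below is synthetic, and the class's import path is not visible in this diff:

import numpy as np

# A 2-D grayscale array; the constructor asserts len(shape) == 2.
numpy_image = np.random.randint(0, 255, (28, 28)).astype(np.uint8)

image = GrayscaleImageMobject(numpy_image, height=2.3)
image.scale(0.5)        # forwarded to the wrapped ImageMobject
image.set_opacity(0.8)  # likewise forwarded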
@@ -6,10 +6,12 @@ from manim import *
import random

from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.feed_forward_to_feed_forward import (
    FeedForwardToFeedForward,
)


class XMark(VGroup):
    def __init__(self, stroke_width=1.0, color=GRAY):
        super().__init__()
        line_one = Line(

@@ -17,7 +19,7 @@ class XMark(VGroup):
            [0.1, -0.1, 0],
            stroke_width=1.0,
            stroke_color=color,
            z_index=4,
        )
        self.add(line_one)
        line_two = Line(

@@ -25,10 +27,11 @@ class XMark(VGroup):
            [0.1, 0.1, 0],
            stroke_width=1.0,
            stroke_color=color,
            z_index=4,
        )
        self.add(line_two)


def get_edges_to_drop_out(layer: FeedForwardToFeedForward, layers_to_nodes_to_drop_out):
    """Returns edges to drop out for a given FeedForwardToFeedForward layer"""
    prev_layer = layer.input_layer

@@ -43,18 +46,21 @@ def get_edges_to_drop_out(layer: FeedForwardToFeedForward, layers_to_nodes_to_dr
        prev_node_index = int(edge_index / next_layer.num_nodes)
        next_node_index = edge_index % next_layer.num_nodes
        # Check if the edges should be dropped out
        if (
            prev_node_index in prev_layer_nodes_to_dropout
            or next_node_index in next_layer_nodes_to_dropout
        ):
            edges_to_dropout.append(edge)
            edge_indices_to_dropout.append(edge_index)

    return edges_to_dropout, edge_indices_to_dropout


def make_pre_dropout_animation(
    neural_network,
    layers_to_nodes_to_drop_out,
    dropped_out_color=GRAY,
    dropped_out_opacity=0.2,
):
    """Makes an animation that sets up the NN layer for dropout"""
    animations = []

@@ -70,8 +76,7 @@ def make_pre_dropout_animation(
    layers_to_edges_to_dropout = {}
    for layer in feed_forward_to_feed_forward_layers:
        layers_to_edges_to_dropout[layer], _ = get_edges_to_drop_out(
            layer, layers_to_nodes_to_drop_out
        )
    # Dim the colors of the edges
    dim_edge_colors_animations = []

@@ -92,12 +97,9 @@ def make_pre_dropout_animation(
            )
            """

            dim_edge_colors_animations.append(FadeOut(edge))
    dim_edge_colors_animation = AnimationGroup(
        *dim_edge_colors_animations, lag_ratio=0.0
    )
    # Dim the colors of the nodes
    dim_nodes_animations = []

@@ -113,10 +115,7 @@ def make_pre_dropout_animation(
            create_x = Create(x_mark)
            dim_nodes_animations.append(create_x)

    dim_nodes_animation = AnimationGroup(*dim_nodes_animations, lag_ratio=0.0)

    animation_group = AnimationGroup(
        dim_edge_colors_animation,

@@ -125,6 +124,7 @@ def make_pre_dropout_animation(

    return animation_group, x_marks


def make_post_dropout_animation(
    neural_network,
    layers_to_nodes_to_drop_out,

@@ -143,8 +143,7 @@ def make_post_dropout_animation(
    layers_to_edges_to_dropout = {}
    for layer in feed_forward_to_feed_forward_layers:
        layers_to_edges_to_dropout[layer], _ = get_edges_to_drop_out(
            layer, layers_to_nodes_to_drop_out
        )
    # Remove the x marks
    uncreate_animations = []

@@ -152,10 +151,7 @@ def make_post_dropout_animation(
        uncreate_x_mark = Uncreate(x_mark)
        uncreate_animations.append(uncreate_x_mark)

    uncreate_x_marks = AnimationGroup(*uncreate_animations, lag_ratio=0.0)
    # Add the edges back
    create_edge_animations = []
    for layer in layers_to_edges_to_dropout.keys():

@@ -164,20 +160,12 @@ def make_post_dropout_animation(
        for edge_index, edge in enumerate(edges_to_drop_out):
            edge_copy = edge.copy()
            edges_to_drop_out[edge_index] = edge_copy
            create_edge_animations.append(FadeIn(edge_copy))

    create_edge_animation = AnimationGroup(*create_edge_animations, lag_ratio=0.0)

    return AnimationGroup(uncreate_x_marks, create_edge_animation, lag_ratio=0.0)


def make_forward_pass_with_dropout_animation(
    neural_network,

@@ -195,26 +183,16 @@ def make_forward_pass_with_dropout_animation(
    )
    # Iterate through network and get feed forward layers
    for layer in feed_forward_layers:
        layer_args[layer] = {"dropout_node_indices": layers_to_nodes_to_drop_out[layer]}
    for layer in feed_forward_to_feed_forward_layers:
        _, edge_indices = get_edges_to_drop_out(layer, layers_to_nodes_to_drop_out)
        layer_args[layer] = {"edge_indices_to_dropout": edge_indices}

    return neural_network.make_forward_pass_animation(layer_args=layer_args)


def make_neural_network_dropout_animation(
    neural_network, dropout_rate=0.5, do_forward_pass=True
):
    """
    Makes a dropout animation for a given neural network.

@@ -247,26 +225,19 @@ def make_neural_network_dropout_animation(
        layers_to_nodes_to_drop_out[feed_forward_layer] = nodes_to_drop_out
    # Make the animation
    pre_dropout_animation, x_marks = make_pre_dropout_animation(
        neural_network, layers_to_nodes_to_drop_out
    )
    if do_forward_pass:
        forward_pass_animation = make_forward_pass_with_dropout_animation(
            neural_network, layers_to_nodes_to_drop_out
        )
    else:
        forward_pass_animation = AnimationGroup()

    post_dropout_animation = make_post_dropout_animation(
        neural_network, layers_to_nodes_to_drop_out, x_marks
    )
    # Combine the animations into one
    return Succession(
        pre_dropout_animation, forward_pass_animation, post_dropout_animation
    )
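A minimal scene using the dropout animation above (a sketch under stated assumptions, not repository code): the layer sizes, spacing, and dropout rate are arbitrary, and make_neural_network_dropout_animation is imported from the dropout module shown above, whose file path is not visible in this diff:

from manim import *
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.neural_network import NeuralNetwork


class DropoutScene(Scene):
    def construct(self):
        # Small feed-forward network; sizes are illustration values.
        nn = NeuralNetwork(
            [FeedForwardLayer(3), FeedForwardLayer(5), FeedForwardLayer(3)],
            layer_spacing=0.4,
        )
        nn.move_to(ORIGIN)
        self.add(nn)
        # Dim random nodes/edges, run a forward pass, then restore them.
        self.play(
            make_neural_network_dropout_animation(
                nn, dropout_rate=0.25, do_forward_pass=True
            )
        )
        self.wait(1)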
@@ -51,13 +51,6 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
            about_point=self.get_center(),
            axis=ThreeDLayer.rotation_axis,
        )
-        """
-        self.rotate(
-            ThreeDLayer.three_d_y_rotation,
-            about_point=self.get_center(),
-            axis=[0, 1, 0]
-        )
-        """

    def construct_feature_maps(self):
        """Creates the neural network layer"""

@@ -260,6 +260,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
        pulse_color=ORANGE,
        cell_width=0.2,
        show_grid_lines=True,
+        highlight_color=ORANGE,
        **kwargs,
    ):
        super().__init__(

@@ -284,6 +285,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
        self.line_color = line_color
        self.pulse_color = pulse_color
        self.show_grid_lines = show_grid_lines
+        self.highlight_color = highlight_color

    def get_rotated_shift_vectors(self):
        """

@@ -344,11 +346,11 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
        animations.append(FadeOut(filters))
        return Succession(*animations, lag_ratio=1.0)

-    def animate_filters_one_at_a_time(self):
+    def animate_filters_one_at_a_time(self, highlight_active_feature_map=False):
        """Animates each of the filters one at a time"""
        animations = []
        output_feature_maps = self.output_layer.feature_maps
-        for filter_index in range(len(output_feature_maps)):
+        for feature_map_index in range(len(output_feature_maps)):
            # Make filters
            filters = Filters(
                self.input_layer,

@@ -356,9 +358,28 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
                line_color=self.color,
                cell_width=self.cell_width,
                show_grid_lines=self.show_grid_lines,
-                output_feature_map_to_connect=filter_index,  # None means all at once
+                output_feature_map_to_connect=feature_map_index,  # None means all at once
            )
            animations.append(Create(filters))
+            # Highlight the feature map
+            if highlight_active_feature_map:
+                feature_map = output_feature_maps[feature_map_index]
+                original_feature_map_color = feature_map.color
+                # Change the output feature map colors
+                change_color_animations = []
+                change_color_animations.append(
+                    ApplyMethod(feature_map.set_color, self.highlight_color)
+                )
+                # Change the input feature map colors
+                input_feature_maps = self.input_layer.feature_maps
+                for input_feature_map in input_feature_maps:
+                    change_color_animations.append(
+                        ApplyMethod(input_feature_map.set_color, self.highlight_color)
+                    )
+                # Combine the animations
+                animations.append(
+                    AnimationGroup(*change_color_animations, lag_ratio=0.0)
+                )
            # Get the rotated shift vectors
            right_shift, down_shift = self.get_rotated_shift_vectors()
            left_shift = -1 * right_shift

@@ -403,11 +424,36 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
            animations.append(shift_animation)
            # Remove the filters
            animations.append(FadeOut(filters))
+            # Un-highlight the feature map
+            if highlight_active_feature_map:
+                feature_map = output_feature_maps[feature_map_index]
+                # Change the output feature map colors
+                change_color_animations = []
+                change_color_animations.append(
+                    ApplyMethod(feature_map.set_color, original_feature_map_color)
+                )
+                # Change the input feature map colors
+                input_feature_maps = self.input_layer.feature_maps
+                for input_feature_map in input_feature_maps:
+                    change_color_animations.append(
+                        ApplyMethod(
+                            input_feature_map.set_color, original_feature_map_color
+                        )
+                    )
+                # Combine the animations
+                animations.append(
+                    AnimationGroup(*change_color_animations, lag_ratio=0.0)
+                )

        return Succession(*animations, lag_ratio=1.0)

    def make_forward_pass_animation(
-        self, layer_args={}, all_filters_at_once=False, run_time=10.5, **kwargs
+        self,
+        layer_args={},
+        all_filters_at_once=False,
+        highlight_active_feature_map=False,
+        run_time=10.5,
+        **kwargs,
    ):
        """Forward pass animation from conv2d to conv2d"""
        print(f"All filters at once: {all_filters_at_once}")

@@ -415,7 +461,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
        if all_filters_at_once:
            return self.animate_filters_all_at_once()
        else:
-            return self.animate_filters_one_at_a_time()
+            return self.animate_filters_one_at_a_time(
+                highlight_active_feature_map=highlight_active_feature_map
+            )

    def scale(self, scale_factor, **kwargs):
        self.cell_width *= scale_factor
@@ -1,6 +1,7 @@
from manim import *
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer


class FeedForwardLayer(VGroupNeuralNetworkLayer):
    """Handles rendering a layer for a neural network"""

@@ -78,16 +79,10 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
        # Make highlight animation
        succession = Succession(
            ApplyMethod(
                nodes_to_highlight.set_color, self.animation_dot_color, run_time=0.25
            ),
            Wait(1.0),
            ApplyMethod(nodes_to_highlight.set_color, self.node_color, run_time=0.25),
        )

        return succession

@@ -97,8 +92,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
        if "dropout_node_indices" in layer_args:
            # Drop out certain nodes
            return self.make_dropout_forward_pass_animation(
                layer_args=layer_args, **kwargs
            )
        else:
            # Make highlight animation
@@ -68,19 +68,20 @@ class FeedForwardToFeedForward(ConnectiveLayer):
        return animation_group

    def make_forward_pass_animation(
        self, layer_args={}, run_time=1, feed_forward_dropout=0.0, **kwargs
    ):
        """Animation for passing information from one FeedForwardLayer to the next"""
        path_animations = []
        dots = []
        for edge_index, edge in enumerate(self.edges):
-            if not edge_index in layer_args["edge_indices_to_dropout"]:
+            if (
+                not "edge_indices_to_dropout" in layer_args
+                or not edge_index in layer_args["edge_indices_to_dropout"]
+            ):
                dot = Dot(
                    color=self.animation_dot_color,
                    fill_opacity=1.0,
                    radius=self.dot_radius,
                )
                # Add to dots group
                dots.append(dot)
@@ -56,6 +56,10 @@ class ImageLayer(NeuralNetworkLayer):
        """Override get right"""
        return self.image_mobject.get_right()

+    def scale(self, scale_factor, **kwargs):
+        """Scales the image mobject"""
+        self.image_mobject.scale(scale_factor)
+
    @property
    def width(self):
        return self.image_mobject.width
@@ -27,10 +27,10 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
        """Maps image to convolutional layer"""
        # Transform the image from the input layer to the
        num_image_channels = self.input_layer.num_channels
-        if num_image_channels == 3:
-            return self.rbg_image_animation()
-        elif num_image_channels == 1:
-            return self.grayscale_image_animation()
+        if num_image_channels == 1 or num_image_channels == 3:  # TODO fix this later
+            return self.grayscale_image_animation()
+        elif num_image_channels == 3:
+            return self.rbg_image_animation()
        else:
            raise Exception(
                f"Unrecognized number of image channels: {num_image_channels}"

@@ -43,7 +43,6 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
        # TODO create image mobjects for each channel and transform
        # it to the feature maps of the output_layer
        raise NotImplementedError()
-        pass

    def grayscale_image_animation(self):
        """Handles animation for 1 channel image"""

@@ -80,7 +79,7 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
        # Scale the max of width or height to the
        # width of the feature_map
        max_width_height = max(image_mobject.width, image_mobject.height)
-        scale_factor = target_feature_map.rectangle_width / max_width_height
+        scale_factor = target_feature_map.width / max_width_height
        scale_image = ApplyMethod(image_mobject.scale, scale_factor, run_time=0.5)
        # Move the image
        move_image = ApplyMethod(image_mobject.move_to, target_feature_map)
@@ -69,10 +69,13 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
        return super()._create_override()

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            + f"input_layer={self.input_layer.__class__.__name__},"
            + f"output_layer={self.output_layer.__class__.__name__},"
            + ")"
        )


class BlankConnective(ConnectiveLayer):
    """Connective layer to be used when the given pair of layers is undefined"""

@ -22,6 +22,7 @@ from manim_ml.neural_network.neural_network_transformations import (
    RemoveLayer,
)

+
class NeuralNetwork(Group):
    """Neural Network Visualization Container Class"""

@ -34,8 +35,8 @@ class NeuralNetwork(Group):
        edge_width=2.5,
        dot_radius=0.03,
        title=" ",
-        three_d_phi=-70 * DEGREES,
-        three_d_theta=-80 * DEGREES,
+        layout="linear",
+        layout_direction="left_to_right",
    ):
        super(Group, self).__init__()
        self.input_layers = ListGroup(*input_layers)
@ -46,13 +47,12 @@ class NeuralNetwork(Group):
        self.dot_radius = dot_radius
        self.title_text = title
        self.created = False
-        # Make the layer fixed in frame if its not 3D
-        ThreeDLayer.three_d_theta = three_d_theta
-        ThreeDLayer.three_d_phi = three_d_phi
+        self.layout = layout
+        self.layout_direction = layout_direction
        # TODO take layer_node_count [0, (1, 2), 0]
        # and make it have explicit distinct subspaces
        # Place the layers
-        self._place_layers()
+        self._place_layers(layout=layout, layout_direction=layout_direction)
        self.connective_layers, self.all_layers = self._construct_connective_layers()
        # Make overhead title
        self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
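
For context, a minimal usage sketch of the constructor options touched above; the layer sizes and spacing here are illustrative values, not taken from this commit:

    from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
    from manim_ml.neural_network.neural_network import NeuralNetwork

    # Lay the network out vertically instead of the default left-to-right flow.
    nn = NeuralNetwork(
        [FeedForwardLayer(3), FeedForwardLayer(5), FeedForwardLayer(3)],
        layer_spacing=0.4,
        layout="linear",
        layout_direction="top_to_bottom",
    )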
@ -67,7 +67,7 @@ class NeuralNetwork(Group):
        # Print neural network
        print(repr(self))

-    def _place_layers(self):
+    def _place_layers(self, layout="linear", layout_direction="top_to_bottom"):
        """Creates the neural network"""
        # TODO implement more sophisticated custom layouts
        # Default: Linear layout
@ -79,6 +79,7 @@ class NeuralNetwork(Group):
            if isinstance(current_layer, EmbeddingLayer) or isinstance(
                previous_layer, EmbeddingLayer
            ):
+                if layout_direction == "left_to_right":
                    shift_vector = np.array(
                        [
                            (
@ -90,15 +91,53 @@ class NeuralNetwork(Group):
                            0,
                        ]
                    )
-            else:
+                elif layout_direction == "top_to_bottom":
                    shift_vector = np.array(
                        [
-                        (previous_layer.get_width() / 2 + current_layer.get_width() / 2)
+                            0,
+                            -(
+                                previous_layer.get_width() / 2
+                                + current_layer.get_width() / 2
+                                - 0.2
+                            ),
+                            0,
+                        ]
+                    )
+                else:
+                    raise Exception(
+                        f"Unrecognized layout direction: {layout_direction}"
+                    )
+            else:
+                if layout_direction == "left_to_right":
+                    shift_vector = np.array(
+                        [
+                            (
+                                previous_layer.get_width() / 2
+                                + current_layer.get_width() / 2
+                            )
                            + self.layer_spacing,
                            0,
                            0,
                        ]
                    )
+                elif layout_direction == "top_to_bottom":
+                    shift_vector = np.array(
+                        [
+                            0,
+                            -(
+                                (
+                                    previous_layer.get_width() / 2
+                                    + current_layer.get_width() / 2
+                                )
+                                + self.layer_spacing
+                            ),
+                            0,
+                        ]
+                    )
+                else:
+                    raise Exception(
+                        f"Unrecognized layout direction: {layout_direction}"
+                    )
            current_layer.shift(shift_vector)

    def _construct_connective_layers(self):
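
To make the spacing arithmetic above concrete, a small worked example for the non-embedding "top_to_bottom" branch; the widths and spacing are made-up values, not from this commit:

    # Illustrative numbers only.
    previous_width = 1.0   # previous_layer.get_width()
    current_width = 0.5    # current_layer.get_width()
    layer_spacing = 0.25
    shift_y = -((previous_width / 2 + current_width / 2) + layer_spacing)
    shift_vector = [0, shift_y, 0]
    # shift_y == -1.0, i.e. the current layer ends up roughly one Manim unit
    # below the previous one in the linear layout.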
@ -178,7 +217,7 @@ class NeuralNetwork(Group):
                current_layer_args = {
                    **before_layer_args,
                    **current_layer_args,
-                    **after_layer_args
+                    **after_layer_args,
                }
            else:
                current_layer_args = {}
@ -229,14 +268,18 @@ class NeuralNetwork(Group):
        """Overriden scale"""
        for layer in self.all_layers:
            layer.scale(scale_factor, **kwargs)
-        # super().scale(scale_factor)
+        # Place layers with scaled spacing
+        self.layer_spacing *= scale_factor
+        self._place_layers(layout=self.layout, layout_direction=self.layout_direction)

    def filter_layers(self, function):
        """Filters layers of the network given function"""
        layers_to_return = []
        for layer in self.all_layers:
            func_out = function(layer)
-            assert isinstance(func_out, bool), "Filter layers function returned a non-boolean type."
+            assert isinstance(
+                func_out, bool
+            ), "Filter layers function returned a non-boolean type."
            if func_out:
                layers_to_return.append(layer)

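
A rough usage sketch of the two methods touched above; the network object and scale factor are illustrative, and filter_layers is assumed to return the matching layers:

    # Scaling now also rescales layer_spacing and re-runs the layout,
    # so the gaps between layers grow with the layers themselves.
    nn.scale(1.5)

    # The filter predicate must return a bool, e.g. keep only feed forward layers.
    feed_forward_layers = nn.filter_layers(
        lambda layer: isinstance(layer, FeedForwardLayer)
    )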
@ -257,6 +300,7 @@ class NeuralNetwork(Group):
        string_repr = "NeuralNetwork([\n" + inner_string + "])"
        return string_repr

+
class FeedForwardNeuralNetwork(NeuralNetwork):
    """NeuralNetwork with just feed forward layers"""


@ -4,6 +4,7 @@
from manim import *
from manim_ml.neural_network.layers.util import get_connective_layer

+
class RemoveLayer(AnimationGroup):
    """
    Animation for removing a layer from a neural network.

@ -1,27 +1,31 @@
from manim import *
-from manim_ml.decision_tree.decision_tree import DecisionTreeDiagram, DecisionTreeSurface, IrisDatasetPlot
+from manim_ml.decision_tree.decision_tree import (
+    DecisionTreeDiagram,
+    DecisionTreeSurface,
+    IrisDatasetPlot,
+)
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
import sklearn
import matplotlib.pyplot as plt


def learn_iris_decision_tree(iris):
    decision_tree = DecisionTreeClassifier(
-        random_state=1,
-        max_depth=3,
-        max_leaf_nodes=6
+        random_state=1, max_depth=3, max_leaf_nodes=6
    )
    decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
    # output the decisioin tree in some format
    return decision_tree

+
def make_sklearn_tree(dataset, max_tree_depth=3):
    tree = learn_iris_decision_tree(dataset)
    feature_names = dataset.feature_names[0:2]
    return tree, tree.tree_

+
class DecisionTreeScene(Scene):
    def construct(self):
        """Makes a decision tree object"""
        iris_dataset = datasets.load_iris()
@ -36,15 +40,8 @@ class DecisionTreeScene(Scene):
                "images/iris_dataset/VeriscolorFlower.jpeg",
                "images/iris_dataset/VirginicaFlower.jpeg",
            ],
-            class_names=[
-                "Setosa",
-                "Veriscolor",
-                "Virginica"
-            ],
-            feature_names=[
-                "Sepal Length",
-                "Sepal Width"
-            ]
+            class_names=["Setosa", "Veriscolor", "Virginica"],
+            feature_names=["Sepal Length", "Sepal Width"],
        )
        decision_tree.move_to(ORIGIN)
        create_decision_tree = Create(decision_tree, traversal_order="bfs")
@ -53,7 +50,6 @@ class DecisionTreeScene(Scene):


-
class SurfacePlot(Scene):
    def construct(self):
        iris_dataset = datasets.load_iris()
        iris_dataset_plot = IrisDatasetPlot(iris_dataset)
@ -63,13 +59,7 @@ class SurfacePlot(Scene):
        # make the decision tree classifier
        decision_tree_classifier, sklearn_tree = make_sklearn_tree(iris_dataset)
        decision_tree_surface = DecisionTreeSurface(
-            decision_tree_classifier,
-            iris_dataset.data,
-            iris_dataset_plot.axes_group[0]
+            decision_tree_classifier, iris_dataset.data, iris_dataset_plot.axes_group[0]
        )
-        self.play(
-            Create(
-                decision_tree_surface
-            )
-        )
+        self.play(Create(decision_tree_surface))
        self.wait(1)

@ -1,5 +1,7 @@
from manim import *
-from manim_ml.neural_network.animations.dropout import make_neural_network_dropout_animation
+from manim_ml.neural_network.animations.dropout import (
+    make_neural_network_dropout_animation,
+)
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from PIL import Image
@ -11,6 +13,7 @@ config.pixel_width = 1900
config.frame_height = 5.0
config.frame_width = 5.0

+
def make_code_snippet():
    code_str = """
    nn = NeuralNetwork([
@ -40,17 +43,19 @@ def make_code_snippet():

    return code

+
class DropoutNeuralNetworkScene(Scene):
    def construct(self):
        # Make nn
-        nn = NeuralNetwork([
+        nn = NeuralNetwork(
+            [
                FeedForwardLayer(3, rectangle_color=BLUE),
                FeedForwardLayer(5, rectangle_color=BLUE),
                FeedForwardLayer(3, rectangle_color=BLUE),
                FeedForwardLayer(5, rectangle_color=BLUE),
                FeedForwardLayer(4, rectangle_color=BLUE),
            ],
-            layer_spacing=0.4
+            layer_spacing=0.4,
        )
        # Center the nn
        nn.move_to(ORIGIN)
@ -63,13 +68,12 @@ class DropoutNeuralNetworkScene(Scene):
        # Play animation
        self.play(
            make_neural_network_dropout_animation(
-                nn,
-                dropout_rate=0.25,
-                do_forward_pass=True
+                nn, dropout_rate=0.25, do_forward_pass=True
            )
        )
        self.wait(1)

+
if __name__ == "__main__":
    """Render all scenes"""
    dropout_nn_scene = DropoutNeuralNetworkScene()

tests/test_nn_scale.py (new file, 44 lines)
@ -0,0 +1,44 @@
+from manim import *
+from PIL import Image
+
+from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
+from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
+from manim_ml.neural_network.layers.image import ImageLayer
+from manim_ml.neural_network.neural_network import NeuralNetwork
+
+# Make the specific scene
+config.pixel_height = 1200
+config.pixel_width = 1900
+config.frame_height = 6.0
+config.frame_width = 6.0
+
+
+class CombinedScene(ThreeDScene):
+    def construct(self):
+        image = Image.open("../assets/mnist/digit.jpeg")
+        numpy_image = np.asarray(image)
+        # Make nn
+        nn = NeuralNetwork(
+            [
+                ImageLayer(numpy_image, height=1.5),
+                Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
+                Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
+                FeedForwardLayer(3),
+            ],
+            layer_spacing=0.25,
+        )
+        # Center the nn
+        nn.move_to(ORIGIN)
+        nn.scale(1.3)
+        self.add(nn)
+        """
+        self.play(
+            FadeIn(nn)
+        )
+        """
+        # Play animation
+        forward_pass = nn.make_forward_pass_animation(
+            corner_pulses=False, all_filters_at_once=False, highlight_filters=True
+        )
+        self.wait(1)
+        self.play(forward_pass)
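
One possible way to exercise this new test scene without the manim CLI; a sketch, assuming the working directory is tests/ so the relative digit.jpeg path resolves:

    if __name__ == "__main__":
        # Scene.render() should build and write the animation, similar to
        # invoking the command-line renderer on this file.
        CombinedScene().render()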