Files
ManimML/manim_ml/utils/mobjects/connections.py

179 lines
6.7 KiB
Python

import numpy as np
from manim import *
class NetworkConnection(VGroup):
    """
    This class allows for creating connections
    between locations in a network.

    The connection is drawn either as a single straight arrow
    (``arc_direction == "straight"``) or as a three-piece bent path bowing
    out in ``arc_direction``. The three pieces are:

    - a line leaving the start mobject,
    - a middle line,
    - an arrow entering the end mobject.

    In both cases the arrow tip lands on ``end_mobject``.
    """

    # Maps the accepted non-straight ``arc_direction`` strings to manim
    # direction vectors.
    direction_vector_map = {"up": UP, "down": DOWN, "left": LEFT, "right": RIGHT}

    def __init__(
        self,
        start_mobject,
        end_mobject,
        arc_direction="straight",
        buffer=0.0,
        arc_distance=0.2,
        stroke_width=2.0,
        color=WHITE,
        active_color=ORANGE,
    ):
        """Creates an arrow with right angles in it connecting
        two mobjects.

        Parameters
        ----------
        start_mobject : Mobject
            Mobject where the start of the connection is from
        end_mobject : Mobject
            Mobject where the end of the connection goes to (receives the
            arrow tip)
        arc_direction : str, optional
            direction that the connection arcs, one of "straight", "up",
            "down", "left", "right"; by default "straight"
        buffer : float, optional
            amount of space between the connection and mobjects at the end
        arc_distance : float, optional
            Distance from start and end mobject that the arc bends
        stroke_width : float, optional
            Stroke width of the connection's line pieces (the arrow pieces
            use manim's default ``Arrow`` stroke width)
        color : [float], optional
            Color of the connection (lines and arrows)
        active_color : [float], optional
            Color of active animations for this mobject
        """
        super().__init__()
        assert arc_direction in ["straight", "up", "down", "left", "right"], (
            f"Unsupported arc_direction: {arc_direction!r}"
        )
        self.start_mobject = start_mobject
        self.end_mobject = end_mobject
        self.arc_direction = arc_direction
        self.buffer = buffer
        self.arc_distance = arc_distance
        self.stroke_width = stroke_width
        self.color = color
        self.active_color = active_color
        self.make_mobjects()

    def make_mobjects(self):
        """Makes the submobjects.

        Sets ``self.straight_arrow`` for a straight connection. Otherwise
        sets ``self.start_line``, ``self.middle_line`` and
        ``self.end_arrow``. The created pieces are added to this group.
        """
        if self.arc_direction == "straight":
            self._make_straight_arrow()
        else:
            self._make_arc()

    def _make_straight_arrow(self):
        """Add one arrow between the facing sides of start and end mobjects."""
        # Join the sides that face each other. When start is left of end the
        # arrow leaves start's right edge and enters end's left edge;
        # otherwise it leaves start's left edge and enters end's right edge.
        if self.start_mobject.get_center()[0] < self.end_mobject.get_center()[0]:
            start_point = self.start_mobject.get_right() + self.buffer * RIGHT
            end_point = self.end_mobject.get_left() + self.buffer * LEFT
        else:
            start_point = self.start_mobject.get_left() + self.buffer * LEFT
            end_point = self.end_mobject.get_right() + self.buffer * RIGHT
        self.straight_arrow = Arrow(
            start=start_point,
            end=end_point,
            color=self.color,
            fill_color=self.color,
            stroke_opacity=1.0,
            buff=0.0,
        )
        self.add(self.straight_arrow)

    def _make_arc(self):
        """Add the three-piece bent connection from start to end mobject."""
        direction_vector = NetworkConnection.direction_vector_map[
            self.arc_direction
        ]
        # Extreme point of each mobject in the arc direction
        # (e.g. the top-center for "up").
        start_critical_point = self.start_mobject.get_critical_point(
            direction_vector
        )
        end_critical_point = self.end_mobject.get_critical_point(direction_vector)
        # Orthogonal projections onto the arc direction: how far each mobject
        # reaches that way.
        start_extent = np.dot(start_critical_point, direction_vector)
        end_extent = np.dot(end_critical_point, direction_vector)
        # Both outer pieces must reach the same level (projection) so that
        # the middle piece runs perpendicular to the arc direction. The piece
        # on the mobject that reaches less far is lengthened by the gap.
        level = max(start_extent, end_extent)
        start_offset = self.buffer + self.arc_distance + (level - start_extent)
        end_offset = self.buffer + self.arc_distance + (level - end_extent)
        # Piece leaving the start mobject.
        start_line_start = start_critical_point + direction_vector * self.buffer
        start_line_end = start_critical_point + direction_vector * start_offset
        self.start_line = Line(
            start_line_start,
            start_line_end,
            color=self.color,
            stroke_width=self.stroke_width,
        )
        # Piece entering the end mobject, drawn as an arrow.
        end_line_end = end_critical_point + direction_vector * self.buffer
        end_line_start = end_critical_point + direction_vector * end_offset
        self.end_arrow = Arrow(
            start=end_line_start,
            end=end_line_end,
            color=self.color,
            fill_color=self.color,
            stroke_opacity=1.0,
            buff=0.0,
        )
        # Middle piece joining the two outer pieces.
        self.middle_line = Line(
            start_line_end,
            end_line_start,
            color=self.color,
            stroke_width=self.stroke_width,
        )
        self.add(
            self.start_line,
            self.middle_line,
            self.end_arrow,
        )

    @override_animation(ShowPassingFlash)
    def _override_passing_flash(self, run_time=1.0, time_width=0.2):
        """Passing flash animation in ``active_color``.

        Parameters
        ----------
        run_time : float, optional
            Total duration of the animation
        time_width : float, optional
            ``time_width`` forwarded to each ``ShowPassingFlash``

        Returns
        -------
        Animation
            A single ``ShowPassingFlash`` for a straight connection.
            Otherwise an ``AnimationGroup`` that flashes the start line, the
            middle line and the end arrow one after another
            (``lag_ratio=1.0``).
        """
        if self.arc_direction == "straight":
            return ShowPassingFlash(
                self.straight_arrow.copy().set_color(self.active_color),
                time_width=time_width,
                run_time=run_time,
            )
        # Flash the pieces in drawing order, from start mobject to end mobject.
        pieces = (self.start_line, self.middle_line, self.end_arrow)
        return AnimationGroup(
            *(
                ShowPassingFlash(
                    piece.copy().set_color(self.active_color),
                    time_width=time_width,
                )
                for piece in pieces
            ),
            lag_ratio=1.0,
            run_time=run_time,
        )