mirror of
https://github.com/3b1b/manim.git
synced 2025-08-03 04:04:36 +08:00
Some small performance improvements to VMobject
This commit is contained in:
@ -13,7 +13,7 @@ from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points
|
||||
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
|
||||
from manimlib.utils.bezier import interpolate
|
||||
from manimlib.utils.bezier import integer_interpolate
|
||||
from manimlib.utils.bezier import partial_bezier_points
|
||||
from manimlib.utils.bezier import partial_quadratic_bezier_points
|
||||
from manimlib.utils.color import color_to_rgba
|
||||
from manimlib.utils.color import rgb_to_hex
|
||||
from manimlib.utils.iterables import make_even
|
||||
@ -159,24 +159,36 @@ class VMobject(Mobject):
|
||||
def set_style(self,
|
||||
fill_color=None,
|
||||
fill_opacity=None,
|
||||
fill_rgbas=None,
|
||||
stroke_color=None,
|
||||
stroke_width=None,
|
||||
stroke_opacity=None,
|
||||
stroke_rgbas=None,
|
||||
stroke_width=None,
|
||||
gloss=None,
|
||||
shadow=None,
|
||||
background_image_file=None,
|
||||
family=True):
|
||||
self.set_fill(
|
||||
color=fill_color,
|
||||
opacity=fill_opacity,
|
||||
family=family
|
||||
)
|
||||
self.set_stroke(
|
||||
color=stroke_color,
|
||||
width=stroke_width,
|
||||
opacity=stroke_opacity,
|
||||
family=family,
|
||||
)
|
||||
if fill_rgbas is not None:
|
||||
self.fill_rgbas = np.array(fill_rgbas)
|
||||
else:
|
||||
self.set_fill(
|
||||
color=fill_color,
|
||||
opacity=fill_opacity,
|
||||
family=family
|
||||
)
|
||||
|
||||
if stroke_rgbas is not None:
|
||||
self.stroke_rgbas = np.array(stroke_rgbas)
|
||||
if stroke_width is not None:
|
||||
self.stroke_width = np.array(listify(stroke_width))
|
||||
else:
|
||||
self.set_stroke(
|
||||
color=stroke_color,
|
||||
width=stroke_width,
|
||||
opacity=stroke_opacity,
|
||||
family=family,
|
||||
)
|
||||
|
||||
if gloss is not None:
|
||||
self.set_gloss(gloss, family=family)
|
||||
if shadow is not None:
|
||||
@ -187,17 +199,24 @@ class VMobject(Mobject):
|
||||
|
||||
def get_style(self):
|
||||
return {
|
||||
"fill_color": self.get_fill_colors(),
|
||||
"fill_opacity": self.get_fill_opacities(),
|
||||
"stroke_color": self.get_stroke_colors(),
|
||||
"stroke_width": self.get_stroke_width(),
|
||||
"stroke_opacity": self.get_stroke_opacity(),
|
||||
"fill_rgbas": self.get_fill_rgbas(),
|
||||
"stroke_rgbas": self.get_stroke_rgbas(),
|
||||
"stroke_width": self.stroke_width,
|
||||
"gloss": self.get_gloss(),
|
||||
"shadow": self.get_shadow(),
|
||||
"background_image_file": self.get_background_image_file(),
|
||||
}
|
||||
|
||||
def match_style(self, vmobject, family=True):
|
||||
self.set_style(**vmobject.get_style(), family=False)
|
||||
for name, value in vmobject.get_style().items():
|
||||
if isinstance(value, np.ndarray):
|
||||
curr = getattr(self, name)
|
||||
if curr.size == value.size:
|
||||
curr[:] = value[:]
|
||||
else:
|
||||
setattr(self, name, np.array(value))
|
||||
else:
|
||||
setattr(self, name, value)
|
||||
|
||||
if family:
|
||||
# Does its best to match up submobject lists, and
|
||||
@ -294,6 +313,23 @@ class VMobject(Mobject):
|
||||
return self.get_stroke_color()
|
||||
return self.get_fill_color()
|
||||
|
||||
def has_stroke(self):
|
||||
if len(self.stroke_width) == 1:
|
||||
if self.stroke_width == 0:
|
||||
return False
|
||||
elif not self.stroke_width.any():
|
||||
return False
|
||||
alphas = self.stroke_rgbas[:, 3]
|
||||
if len(alphas) == 1:
|
||||
return alphas[0] > 0
|
||||
return alphas.any()
|
||||
|
||||
def has_fill(self):
|
||||
alphas = self.fill_rgbas[:, 3]
|
||||
if len(alphas) == 1:
|
||||
return alphas[0] > 0
|
||||
return alphas.any()
|
||||
|
||||
# TODO, this currently does nothing
|
||||
def color_using_background_image(self, background_image_file):
|
||||
self.background_image_file = background_image_file
|
||||
@ -450,7 +486,7 @@ class VMobject(Mobject):
|
||||
n = int(np.ceil(angle / angle_threshold))
|
||||
alphas = np.linspace(0, 1, n + 1)
|
||||
new_points.extend([
|
||||
partial_bezier_points(tup, a1, a2)
|
||||
partial_quadratic_bezier_points(tup, a1, a2)
|
||||
for a1, a2 in zip(alphas, alphas[1:])
|
||||
])
|
||||
else:
|
||||
@ -535,10 +571,10 @@ class VMobject(Mobject):
|
||||
nppc = self.n_points_per_curve
|
||||
remainder = len(points) % nppc
|
||||
points = points[:len(points) - remainder]
|
||||
return np.array([
|
||||
return [
|
||||
points[i:i + nppc]
|
||||
for i in range(0, len(points), nppc)
|
||||
])
|
||||
]
|
||||
|
||||
def get_bezier_tuples(self):
|
||||
return self.get_bezier_tuples_from_points(self.get_points())
|
||||
@ -728,14 +764,15 @@ class VMobject(Mobject):
|
||||
vmobject.set_points(np.vstack(new_subpaths2))
|
||||
return self
|
||||
|
||||
def insert_n_curves(self, n):
|
||||
new_points = self.insert_n_curves_to_point_list(n, self.get_points())
|
||||
|
||||
# TODO, this should happen in insert_n_curves_to_point_list
|
||||
if self.has_new_path_started():
|
||||
new_points = np.vstack([new_points, self.get_last_point()])
|
||||
|
||||
self.set_points(new_points)
|
||||
def insert_n_curves(self, n, family=True):
|
||||
mobs = self.get_family() if family else [self]
|
||||
for mob in mobs:
|
||||
if mob.get_num_curves() > 0:
|
||||
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
|
||||
# TODO, this should happen in insert_n_curves_to_point_list
|
||||
if mob.has_new_path_started():
|
||||
new_points = np.vstack([new_points, mob.get_last_point()])
|
||||
mob.set_points(new_points)
|
||||
return self
|
||||
|
||||
def insert_n_curves_to_point_list(self, n, points):
|
||||
@ -768,7 +805,7 @@ class VMobject(Mobject):
|
||||
# smaller quadratic curves
|
||||
alphas = np.linspace(0, 1, n_inserts + 2)
|
||||
for a1, a2 in zip(alphas, alphas[1:]):
|
||||
new_points += partial_bezier_points(group, a1, a2)
|
||||
new_points += partial_quadratic_bezier_points(group, a1, a2)
|
||||
return np.vstack(new_points)
|
||||
|
||||
def align_rgbas(self, vmobject):
|
||||
@ -789,24 +826,21 @@ class VMobject(Mobject):
|
||||
"fill_rgbas",
|
||||
"stroke_rgbas",
|
||||
"stroke_width",
|
||||
# "sheen_direction",
|
||||
# "sheen_factor",
|
||||
]
|
||||
for attr in attrs:
|
||||
arr = getattr(self, attr)
|
||||
m1a = getattr(mobject1, attr)
|
||||
m2a = getattr(mobject2, attr)
|
||||
setattr(self, attr, interpolate(m1a, m2a, alpha))
|
||||
arr[:] = interpolate(m1a, m2a, alpha)
|
||||
|
||||
# TODO, somehow do this using stroke_width changes
|
||||
# so as to not have to change the point list
|
||||
def pointwise_become_partial(self, vmobject, a, b):
|
||||
assert(isinstance(vmobject, VMobject))
|
||||
assert(len(self.points) >= len(vmobject.points))
|
||||
if a <= 0 and b >= 1:
|
||||
self.points[:] = vmobject.points
|
||||
self.points[:] = vmobject.points[:]
|
||||
return self
|
||||
bezier_tuple = vmobject.get_bezier_tuples()
|
||||
num_curves = len(bezier_tuple)
|
||||
num_curves = self.get_num_curves()
|
||||
nppc = self.n_points_per_curve
|
||||
|
||||
# Partial curve includes three portions:
|
||||
# - A middle section, which matches the curve exactly
|
||||
@ -815,27 +849,28 @@ class VMobject(Mobject):
|
||||
|
||||
lower_index, lower_residue = integer_interpolate(0, num_curves, a)
|
||||
upper_index, upper_residue = integer_interpolate(0, num_curves, b)
|
||||
i1 = nppc * lower_index
|
||||
i2 = nppc * (lower_index + 1)
|
||||
i3 = nppc * upper_index
|
||||
i4 = nppc * (upper_index + 1)
|
||||
|
||||
new_point_list = []
|
||||
if num_curves == 0:
|
||||
self.points[:] = 0
|
||||
return self
|
||||
if lower_index == upper_index:
|
||||
new_point_list.append(partial_bezier_points(
|
||||
bezier_tuple[lower_index], lower_residue, upper_residue
|
||||
))
|
||||
tup = partial_quadratic_bezier_points(vmobject.points[i1:i2], lower_residue, upper_residue)
|
||||
self.points[:i1] = tup[0]
|
||||
self.points[i1:i4] = tup
|
||||
self.points[i4:] = tup[2]
|
||||
self.points[nppc:] = self.points[nppc - 1]
|
||||
else:
|
||||
new_point_list.append(partial_bezier_points(
|
||||
bezier_tuple[lower_index], lower_residue, 1
|
||||
))
|
||||
for tup in bezier_tuple[lower_index + 1:upper_index]:
|
||||
new_point_list.append(tup)
|
||||
new_point_list.append(partial_bezier_points(
|
||||
bezier_tuple[upper_index], 0, upper_residue
|
||||
))
|
||||
new_points = np.vstack(new_point_list)
|
||||
self.points[:len(new_points)] = new_points
|
||||
self.points[len(new_points):] = new_points[-1]
|
||||
low_tup = partial_quadratic_bezier_points(vmobject.points[i1:i2], lower_residue, 1)
|
||||
high_tup = partial_quadratic_bezier_points(vmobject.points[i3:i4], 0, upper_residue)
|
||||
self.points[0:i1] = low_tup[0]
|
||||
self.points[i1:i2] = low_tup
|
||||
self.points[i2:i3] = vmobject.points[i2:i3]
|
||||
self.points[i3:i4] = high_tup
|
||||
self.points[i4:] = high_tup[2]
|
||||
return self
|
||||
|
||||
def get_subcurve(self, a, b):
|
||||
@ -881,14 +916,10 @@ class VMobject(Mobject):
|
||||
stroke_data = []
|
||||
fill_data = []
|
||||
for submob in self.family_members_with_points():
|
||||
stroke_width = submob.get_stroke_width()
|
||||
stroke_opacity = submob.get_stroke_opacity()
|
||||
fill_opacity = submob.get_fill_opacity()
|
||||
|
||||
if fill_opacity > 0:
|
||||
fill_data.append(submob.get_fill_shader_data().tobytes())
|
||||
|
||||
if stroke_width > 0 and stroke_opacity > 0:
|
||||
if submob.has_fill():
|
||||
data = submob.get_fill_shader_data().tobytes()
|
||||
fill_data.append(data)
|
||||
if submob.has_stroke():
|
||||
if submob.draw_stroke_behind_fill:
|
||||
data = back_stroke_data
|
||||
else:
|
||||
@ -1061,8 +1092,7 @@ class VectorizedPoint(VMobject, Point):
|
||||
class CurvesAsSubmobjects(VGroup):
|
||||
def __init__(self, vmobject, **kwargs):
|
||||
VGroup.__init__(self, **kwargs)
|
||||
tuples = vmobject.get_bezier_tuples()
|
||||
for tup in tuples:
|
||||
for tup in vmobject.get_bezier_tuples():
|
||||
part = VMobject()
|
||||
part.set_points(tup)
|
||||
part.match_style(vmobject)
|
||||
|
@ -1,11 +1,8 @@
|
||||
from scipy import linalg
|
||||
import numpy as np
|
||||
|
||||
from manimlib.constants import PI
|
||||
from manimlib.utils.simple_functions import choose
|
||||
from manimlib.utils.space_ops import rotate_vector
|
||||
from manimlib.utils.space_ops import find_intersection
|
||||
from manimlib.utils.space_ops import cross
|
||||
from manimlib.utils.space_ops import cross2d
|
||||
|
||||
CLOSED_THRESHOLD = 0.001
|
||||
@ -13,10 +10,14 @@ CLOSED_THRESHOLD = 0.001
|
||||
|
||||
def bezier(points):
|
||||
n = len(points) - 1
|
||||
return lambda t: sum([
|
||||
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
|
||||
for k, point in enumerate(points)
|
||||
])
|
||||
|
||||
def result(t):
|
||||
return sum([
|
||||
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
|
||||
for k, point in enumerate(points)
|
||||
])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def partial_bezier_points(points, a, b):
|
||||
@ -43,6 +44,20 @@ def partial_bezier_points(points, a, b):
|
||||
]
|
||||
|
||||
|
||||
# Shortened version of partial_bezier_points just for quadratics,
|
||||
# since this is called a fair amount
|
||||
def partial_quadratic_bezier_points(points, a, b):
|
||||
def curve(t):
|
||||
return points[0] * (1 - t) * (1 - t) + 2 * points[1] * t * (1 - t) + points[2] * t * t
|
||||
# bezier(points)
|
||||
h0 = curve(a) if a > 0 else points[0]
|
||||
h2 = curve(b) if b < 1 else points[2]
|
||||
h1_prime = (1 - a) * points[1] + a * points[2]
|
||||
end_prop = (b - a) / (1. - a)
|
||||
h1 = (1 - end_prop) * h0 + end_prop * h1_prime
|
||||
return [h0, h1, h2]
|
||||
|
||||
|
||||
# Linear interpolation variants
|
||||
|
||||
def interpolate(start, end, alpha):
|
||||
|
Reference in New Issue
Block a user