Some small performance improvements to VMobject

This commit is contained in:
Grant Sanderson
2020-06-27 12:10:22 -07:00
parent 10c6bfe3ad
commit 54bde86c7b
2 changed files with 116 additions and 71 deletions

View File

@ -13,7 +13,7 @@ from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import partial_bezier_points
from manimlib.utils.bezier import partial_quadratic_bezier_points
from manimlib.utils.color import color_to_rgba
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.iterables import make_even
@ -159,24 +159,36 @@ class VMobject(Mobject):
def set_style(self,
fill_color=None,
fill_opacity=None,
fill_rgbas=None,
stroke_color=None,
stroke_width=None,
stroke_opacity=None,
stroke_rgbas=None,
stroke_width=None,
gloss=None,
shadow=None,
background_image_file=None,
family=True):
self.set_fill(
color=fill_color,
opacity=fill_opacity,
family=family
)
self.set_stroke(
color=stroke_color,
width=stroke_width,
opacity=stroke_opacity,
family=family,
)
if fill_rgbas is not None:
self.fill_rgbas = np.array(fill_rgbas)
else:
self.set_fill(
color=fill_color,
opacity=fill_opacity,
family=family
)
if stroke_rgbas is not None:
self.stroke_rgbas = np.array(stroke_rgbas)
if stroke_width is not None:
self.stroke_width = np.array(listify(stroke_width))
else:
self.set_stroke(
color=stroke_color,
width=stroke_width,
opacity=stroke_opacity,
family=family,
)
if gloss is not None:
self.set_gloss(gloss, family=family)
if shadow is not None:
@ -187,17 +199,24 @@ class VMobject(Mobject):
def get_style(self):
return {
"fill_color": self.get_fill_colors(),
"fill_opacity": self.get_fill_opacities(),
"stroke_color": self.get_stroke_colors(),
"stroke_width": self.get_stroke_width(),
"stroke_opacity": self.get_stroke_opacity(),
"fill_rgbas": self.get_fill_rgbas(),
"stroke_rgbas": self.get_stroke_rgbas(),
"stroke_width": self.stroke_width,
"gloss": self.get_gloss(),
"shadow": self.get_shadow(),
"background_image_file": self.get_background_image_file(),
}
def match_style(self, vmobject, family=True):
self.set_style(**vmobject.get_style(), family=False)
for name, value in vmobject.get_style().items():
if isinstance(value, np.ndarray):
curr = getattr(self, name)
if curr.size == value.size:
curr[:] = value[:]
else:
setattr(self, name, np.array(value))
else:
setattr(self, name, value)
if family:
# Does its best to match up submobject lists, and
@ -294,6 +313,23 @@ class VMobject(Mobject):
return self.get_stroke_color()
return self.get_fill_color()
def has_stroke(self):
if len(self.stroke_width) == 1:
if self.stroke_width == 0:
return False
elif not self.stroke_width.any():
return False
alphas = self.stroke_rgbas[:, 3]
if len(alphas) == 1:
return alphas[0] > 0
return alphas.any()
def has_fill(self):
alphas = self.fill_rgbas[:, 3]
if len(alphas) == 1:
return alphas[0] > 0
return alphas.any()
# TODO, this currently does nothing
def color_using_background_image(self, background_image_file):
self.background_image_file = background_image_file
@ -450,7 +486,7 @@ class VMobject(Mobject):
n = int(np.ceil(angle / angle_threshold))
alphas = np.linspace(0, 1, n + 1)
new_points.extend([
partial_bezier_points(tup, a1, a2)
partial_quadratic_bezier_points(tup, a1, a2)
for a1, a2 in zip(alphas, alphas[1:])
])
else:
@ -535,10 +571,10 @@ class VMobject(Mobject):
nppc = self.n_points_per_curve
remainder = len(points) % nppc
points = points[:len(points) - remainder]
return np.array([
return [
points[i:i + nppc]
for i in range(0, len(points), nppc)
])
]
def get_bezier_tuples(self):
return self.get_bezier_tuples_from_points(self.get_points())
@ -728,14 +764,15 @@ class VMobject(Mobject):
vmobject.set_points(np.vstack(new_subpaths2))
return self
def insert_n_curves(self, n):
new_points = self.insert_n_curves_to_point_list(n, self.get_points())
# TODO, this should happen in insert_n_curves_to_point_list
if self.has_new_path_started():
new_points = np.vstack([new_points, self.get_last_point()])
self.set_points(new_points)
def insert_n_curves(self, n, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
# TODO, this should happen in insert_n_curves_to_point_list
if mob.has_new_path_started():
new_points = np.vstack([new_points, mob.get_last_point()])
mob.set_points(new_points)
return self
def insert_n_curves_to_point_list(self, n, points):
@ -768,7 +805,7 @@ class VMobject(Mobject):
# smaller quadratic curves
alphas = np.linspace(0, 1, n_inserts + 2)
for a1, a2 in zip(alphas, alphas[1:]):
new_points += partial_bezier_points(group, a1, a2)
new_points += partial_quadratic_bezier_points(group, a1, a2)
return np.vstack(new_points)
def align_rgbas(self, vmobject):
@ -789,24 +826,21 @@ class VMobject(Mobject):
"fill_rgbas",
"stroke_rgbas",
"stroke_width",
# "sheen_direction",
# "sheen_factor",
]
for attr in attrs:
arr = getattr(self, attr)
m1a = getattr(mobject1, attr)
m2a = getattr(mobject2, attr)
setattr(self, attr, interpolate(m1a, m2a, alpha))
arr[:] = interpolate(m1a, m2a, alpha)
# TODO, somehow do this using stroke_width changes
# so as to not have to change the point list
def pointwise_become_partial(self, vmobject, a, b):
assert(isinstance(vmobject, VMobject))
assert(len(self.points) >= len(vmobject.points))
if a <= 0 and b >= 1:
self.points[:] = vmobject.points
self.points[:] = vmobject.points[:]
return self
bezier_tuple = vmobject.get_bezier_tuples()
num_curves = len(bezier_tuple)
num_curves = self.get_num_curves()
nppc = self.n_points_per_curve
# Partial curve includes three portions:
# - A middle section, which matches the curve exactly
@ -815,27 +849,28 @@ class VMobject(Mobject):
lower_index, lower_residue = integer_interpolate(0, num_curves, a)
upper_index, upper_residue = integer_interpolate(0, num_curves, b)
i1 = nppc * lower_index
i2 = nppc * (lower_index + 1)
i3 = nppc * upper_index
i4 = nppc * (upper_index + 1)
new_point_list = []
if num_curves == 0:
self.points[:] = 0
return self
if lower_index == upper_index:
new_point_list.append(partial_bezier_points(
bezier_tuple[lower_index], lower_residue, upper_residue
))
tup = partial_quadratic_bezier_points(vmobject.points[i1:i2], lower_residue, upper_residue)
self.points[:i1] = tup[0]
self.points[i1:i4] = tup
self.points[i4:] = tup[2]
self.points[nppc:] = self.points[nppc - 1]
else:
new_point_list.append(partial_bezier_points(
bezier_tuple[lower_index], lower_residue, 1
))
for tup in bezier_tuple[lower_index + 1:upper_index]:
new_point_list.append(tup)
new_point_list.append(partial_bezier_points(
bezier_tuple[upper_index], 0, upper_residue
))
new_points = np.vstack(new_point_list)
self.points[:len(new_points)] = new_points
self.points[len(new_points):] = new_points[-1]
low_tup = partial_quadratic_bezier_points(vmobject.points[i1:i2], lower_residue, 1)
high_tup = partial_quadratic_bezier_points(vmobject.points[i3:i4], 0, upper_residue)
self.points[0:i1] = low_tup[0]
self.points[i1:i2] = low_tup
self.points[i2:i3] = vmobject.points[i2:i3]
self.points[i3:i4] = high_tup
self.points[i4:] = high_tup[2]
return self
def get_subcurve(self, a, b):
@ -881,14 +916,10 @@ class VMobject(Mobject):
stroke_data = []
fill_data = []
for submob in self.family_members_with_points():
stroke_width = submob.get_stroke_width()
stroke_opacity = submob.get_stroke_opacity()
fill_opacity = submob.get_fill_opacity()
if fill_opacity > 0:
fill_data.append(submob.get_fill_shader_data().tobytes())
if stroke_width > 0 and stroke_opacity > 0:
if submob.has_fill():
data = submob.get_fill_shader_data().tobytes()
fill_data.append(data)
if submob.has_stroke():
if submob.draw_stroke_behind_fill:
data = back_stroke_data
else:
@ -1061,8 +1092,7 @@ class VectorizedPoint(VMobject, Point):
class CurvesAsSubmobjects(VGroup):
def __init__(self, vmobject, **kwargs):
VGroup.__init__(self, **kwargs)
tuples = vmobject.get_bezier_tuples()
for tup in tuples:
for tup in vmobject.get_bezier_tuples():
part = VMobject()
part.set_points(tup)
part.match_style(vmobject)

View File

@ -1,11 +1,8 @@
from scipy import linalg
import numpy as np
from manimlib.constants import PI
from manimlib.utils.simple_functions import choose
from manimlib.utils.space_ops import rotate_vector
from manimlib.utils.space_ops import find_intersection
from manimlib.utils.space_ops import cross
from manimlib.utils.space_ops import cross2d
CLOSED_THRESHOLD = 0.001
@ -13,10 +10,14 @@ CLOSED_THRESHOLD = 0.001
def bezier(points):
n = len(points) - 1
return lambda t: sum([
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
for k, point in enumerate(points)
])
def result(t):
return sum([
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
for k, point in enumerate(points)
])
return result
def partial_bezier_points(points, a, b):
@ -43,6 +44,20 @@ def partial_bezier_points(points, a, b):
]
# Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount
def partial_quadratic_bezier_points(points, a, b):
def curve(t):
return points[0] * (1 - t) * (1 - t) + 2 * points[1] * t * (1 - t) + points[2] * t * t
# bezier(points)
h0 = curve(a) if a > 0 else points[0]
h2 = curve(b) if b < 1 else points[2]
h1_prime = (1 - a) * points[1] + a * points[2]
end_prop = (b - a) / (1. - a)
h1 = (1 - end_prop) * h0 + end_prop * h1_prime
return [h0, h1, h2]
# Linear interpolation variants
def interpolate(start, end, alpha):