Reframe Mobject, VMobject and SurfaceMobject with a data map

This commit is contained in:
Grant Sanderson
2021-01-11 10:57:23 -10:00
parent b3335c65fb
commit 9314dfd933
10 changed files with 276 additions and 428 deletions

View File

@ -24,7 +24,7 @@ class Rotating(Animation):
def interpolate_mobject(self, alpha):
for sm1, sm2 in self.get_all_families_zipped():
sm1.points[:] = sm2.points
sm1.set_points(sm2.get_points())
self.mobject.rotate(
alpha * self.angle,
axis=self.axis,

View File

@ -310,8 +310,7 @@ class Dot(Circle):
}
def __init__(self, point=ORIGIN, **kwargs):
Circle.__init__(self, arc_center=point, **kwargs)
self.lock_triangulation()
super().__init__(arc_center=point, **kwargs)
class SmallDot(Dot):
@ -327,7 +326,7 @@ class Ellipse(Circle):
}
def __init__(self, **kwargs):
Circle.__init__(self, **kwargs)
super().__init__(**kwargs)
self.set_width(self.width, stretch=True)
self.set_height(self.height, stretch=True)

View File

@ -15,7 +15,10 @@ from manimlib.utils.color import get_colormap_list
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import batch_by_property
from manimlib.utils.iterables import list_update
from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import resize_preserving_order
from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import set_array_by_interpolation
from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.space_ops import angle_of_vector
@ -34,7 +37,7 @@ class Mobject(object):
"""
CONFIG = {
"color": WHITE,
"dim": 3,
"dim": 3, # TODO, get rid of this
# Lighting parameters
# Positive gloss up to 1 makes it reflect the light.
"gloss": 0.0,
@ -42,9 +45,6 @@ class Mobject(object):
"shadow": 0.0,
# For shaders
"shader_folder": "",
# "vert_shader_file": "",
# "geom_shader_file": "",
# "frag_shader_file": "",
"render_primitive": moderngl.TRIANGLE_STRIP,
"texture_paths": None,
"depth_test": False,
@ -62,8 +62,8 @@ class Mobject(object):
self.parents = []
self.family = [self]
self.init_updaters()
self.init_data()
self.init_updaters()
self.init_points()
self.init_colors()
self.init_shader_data()
@ -84,35 +84,32 @@ class Mobject(object):
# Typically implemented in subclass, unless purposefully left blank
pass
# To sort out later
# Related to data dict
def init_data(self):
self.data = np.zeros(0, dtype=self.shader_dtype)
self.data = {
"points": np.zeros((0, 3)),
}
def resize_data(self, new_length):
if new_length != len(self.data):
self.data = np.resize(self.data, new_length)
def set_data(self, data):
for key in data:
self.data[key] = data[key]
def resize_points(self, new_length, resize_func=resize_array):
if new_length != len(self.data["points"]):
self.data["points"] = resize_func(self.data["points"], new_length)
def set_points(self, points):
self.resize_data(len(points))
self.data["point"] = points
self.resize_points(len(points))
self.data["points"][:] = points
def get_points(self):
return self.data["point"]
def get_all_point_arrays(self):
return [self.data["point"]]
def get_all_data_arrays(self):
return [self.data]
def get_data_array_attrs(self):
return ["data"]
return self.data["points"]
def clear_points(self):
self.resize_data(0)
self.resize_points(0)
def get_num_points(self):
return len(self.data)
return len(self.data["points"])
#
# Family matters
@ -212,8 +209,9 @@ class Mobject(object):
copy_mobject = copy.copy(self)
self.parents = parents
for attr in self.get_data_array_attrs():
setattr(copy_mobject, attr, getattr(self, attr).copy())
copy_mobject.data = dict(self.data)
for key in self.data:
copy_mobject.data[key] = self.data[key].copy()
copy_mobject.submobjects = []
copy_mobject.add(*[sm.copy() for sm in self.submobjects])
copy_mobject.match_updaters(self)
@ -224,7 +222,7 @@ class Mobject(object):
if isinstance(value, Mobject) and value in family and value is not self:
setattr(copy_mobject, attr, value.copy())
if isinstance(value, np.ndarray):
setattr(copy_mobject, attr, np.array(value))
setattr(copy_mobject, attr, value.copy())
if isinstance(value, ShaderWrapper):
setattr(copy_mobject, attr, value.copy())
return copy_mobject
@ -343,8 +341,7 @@ class Mobject(object):
def shift(self, *vectors):
total_vector = reduce(op.add, vectors)
for mob in self.get_family():
for arr in mob.get_all_point_arrays():
arr += total_vector
mob.set_points(mob.get_points() + total_vector)
return self
def scale(self, scale_factor, **kwargs):
@ -458,8 +455,8 @@ class Mobject(object):
about_edge = ORIGIN
about_point = self.get_bounding_box_point(about_edge)
for mob in self.family_members_with_points():
for arr in mob.get_all_point_arrays():
arr[:] = func(arr - about_point) + about_point
points = mob.get_points()
points[:] = func(points - about_point) + about_point
return self
# Positioning methods
@ -512,8 +509,7 @@ class Mobject(object):
else:
aligner = self
point_to_align = aligner.get_bounding_box_point(aligned_edge - direction)
self.shift((target_point - point_to_align +
buff * direction) * coor_mask)
self.shift((target_point - point_to_align + buff * direction) * coor_mask)
return self
def shift_onto_screen(self, **kwargs):
@ -807,11 +803,9 @@ class Mobject(object):
else:
return getattr(self, array_attr)
def get_all_points(self): # TODO, use get_all_point_arrays?
def get_all_points(self):
if self.submobjects:
return np.vstack([
sm.get_points() for sm in self.get_family()
])
return np.vstack([sm.get_points() for sm in self.get_family()])
else:
return self.get_points()
@ -1107,20 +1101,25 @@ class Mobject(object):
self.null_point_align(mobject) # Needed?
self.align_submobjects(mobject)
for mob1, mob2 in zip(self.get_family(), mobject.get_family()):
# Separate out how points are treated so that subclasses
# can handle that case differently if they choose
mob1.align_points(mob2)
for key in mob1.data:
if key == "points":
continue
arr1 = mob1.data[key]
arr2 = mob2.data[key]
if len(arr2) > len(arr1):
mob1.data[key] = resize_preserving_order(arr1, len(arr2))
elif len(arr1) > len(arr2):
mob2.data[key] = resize_preserving_order(arr2, len(arr1))
def align_points(self, mobject):
count1 = self.get_num_points()
count2 = mobject.get_num_points()
if count1 < count2:
self.align_points_with_larger(mobject)
elif count2 < count1:
mobject.align_points_with_larger(self)
max_len = max(self.get_num_points(), mobject.get_num_points())
self.resize_points(max_len, resize_func=resize_preserving_order)
mobject.resize_points(max_len, resize_func=resize_preserving_order)
return self
def align_points_with_larger(self, larger_mobject):
raise Exception("Not implemented")
def align_submobjects(self, mobject):
mob1 = self
mob2 = mobject
@ -1188,19 +1187,16 @@ class Mobject(object):
Turns self into an interpolation between mobject1
and mobject2.
"""
mobs = [self, mobject1, mobject2]
# for arr, arr1, arr2 in zip(*(m.get_all_data_arrays() for m in mobs)):
# arr[:] = interpolate(arr1, arr2, alpha)
# if path_func is not straight_path:
for arr, arr1, arr2 in zip(*(m.get_all_point_arrays() for m in mobs)):
arr[:] = path_func(arr1, arr2, alpha)
# self.interpolate_color(mobject1, mobject2, alpha)
for key in self.data:
func = path_func if key == "points" else interpolate
self.data[key][:] = func(
mobject1.data[key],
mobject2.data[key],
alpha
)
self.interpolate_light_style(mobject1, mobject2, alpha) # TODO, interpolate uniforms instaed
return self
def interpolate_color(self, mobject1, mobject2, alpha):
pass # To implement in subclass
def interpolate_light_style(self, mobject1, mobject2, alpha):
g0 = self.get_gloss()
g1 = mobject1.get_gloss()
@ -1236,8 +1232,7 @@ class Mobject(object):
"""
self.align_submobjects(mobject)
for sm1, sm2 in zip(self.get_family(), mobject.get_family()):
for arr1, arr2 in zip(sm1.get_all_data_arrays(), sm2.get_all_data_arrays()):
arr1[:] = arr2
sm1.set_data(sm2.data)
return self
def cleanup_from_animation(self):
@ -1320,9 +1315,10 @@ class Mobject(object):
# For shader data
def init_shader_data(self):
self.shader_data = np.zeros(len(self.get_points()), dtype=self.shader_dtype)
self.shader_indices = None
self.shader_wrapper = ShaderWrapper(
vert_data=self.data,
vert_data=self.shader_data,
shader_folder=self.shader_folder,
texture_paths=self.texture_paths,
depth_test=self.depth_test,
@ -1333,18 +1329,18 @@ class Mobject(object):
self.shader_wrapper.refresh_id()
return self
# def get_blank_shader_data_array(self, size, name="data"):
# # If possible, try to populate an existing array, rather
# # than recreating it each frame
# arr = getattr(self, name)
# if arr.size != size:
# new_arr = np.resize(arr, size)
# setattr(self, name, new_arr)
# return new_arr
# return arr
def get_blank_shader_data_array(self, size, name="shader_data"):
# If possible, try to populate an existing array, rather
# than recreating it each frame
arr = getattr(self, name)
if arr.size != size:
new_arr = resize_array(arr, size)
setattr(self, name, new_arr)
return new_arr
return arr
def get_shader_wrapper(self):
self.shader_wrapper.vert_data = self.data # TODO
self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.uniforms = self.get_shader_uniforms()
self.shader_wrapper.depth_test = self.depth_test
@ -1367,6 +1363,10 @@ class Mobject(object):
result.append(shader_wrapper)
return result
def get_shader_data(self):
# May be different for subclasses
return self.shader_data
def get_shader_uniforms(self):
return {
"is_fixed_in_frame": float(self.is_fixed_in_frame),
@ -1411,7 +1411,7 @@ class Point(Mobject):
return self.artificial_height
def get_location(self):
return np.array(self.get_points()[0])
return self.get_points()[0].copy()
def get_bounding_box_point(self, *args, **kwargs):
return self.get_location()

View File

@ -281,6 +281,7 @@ class SVGMobject(VMobject):
matrix[:, 1] *= -1
for mob in mobject.family_members_with_points():
# TODO, directly apply matrix?
mob.set_points(np.dot(mob.get_points(), matrix))
mobject.shift(x * RIGHT + y * UP)
except:

View File

@ -7,7 +7,7 @@ from manimlib.constants import ORIGIN
from manimlib.mobject.types.point_cloud_mobject import PMobject
from manimlib.mobject.geometry import DEFAULT_DOT_RADIUS
from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import stretch_array_to_length
from manimlib.utils.iterables import resize_preserving_order
class DotCloud(PMobject):
@ -32,7 +32,7 @@ class DotCloud(PMobject):
def set_points(self, points):
super().set_points(points)
self.radii = stretch_array_to_length(self.radii, len(points))
self.radii = resize_preserving_order(self.radii, len(points))
return self
def set_points_by_grid(self, n_rows, n_cols, height=None, width=None):
@ -58,7 +58,7 @@ class DotCloud(PMobject):
if isinstance(radii, numbers.Number):
self.radii[:] = radii
else:
self.radii = stretch_array_to_length(radii, len(self.points))
self.radii = resize_preserving_order(radii, len(self.points))
return self
def scale(self, scale_factor, scale_radii=True, **kwargs):

View File

@ -4,7 +4,7 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.color import color_gradient
from manimlib.utils.color import color_to_rgba
from manimlib.utils.color import rgba_to_color
from manimlib.utils.iterables import stretch_array_to_length
from manimlib.utils.iterables import resize_preserving_order
class PMobject(Mobject):
@ -18,7 +18,7 @@ class PMobject(Mobject):
def set_points(self, points):
self.points = points
self.rgbas = stretch_array_to_length(self.rgbas, len(points))
self.rgbas = resize_preserving_order(self.rgbas, len(points))
return self
def add_points(self, points, rgbas=None, color=None, alpha=1):
@ -91,7 +91,7 @@ class PMobject(Mobject):
return self
def match_colors(self, pmobject):
self.rgbas[:] = stretch_array_to_length(pmobject.rgbas, len(self.points))
self.rgbas[:] = resize_preserving_order(pmobject.rgbas, len(self.points))
return self
def filter_out(self, condition):
@ -138,14 +138,6 @@ class PMobject(Mobject):
return self.points[index]
# Alignment
def align_points_with_larger(self, larger_mobject):
assert(isinstance(larger_mobject, PMobject))
self.apply_over_attr_arrays(
lambda a: stretch_array_to_length(
a, larger_mobject.get_num_points()
)
)
def interpolate_color(self, mobject1, mobject2, alpha):
self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha)
return self

View File

@ -44,7 +44,12 @@ class ParametricSurface(Mobject):
self.uv_func = uv_func
self.compute_triangle_indices()
super().__init__(**kwargs)
self.sort_faces_back_to_front()
def init_data(self):
self.data = {
"points": np.zeros((0, 3)),
"rgba": np.zeros((1, 4)),
}
def init_points(self):
dim = self.dim
@ -65,7 +70,10 @@ class ParametricSurface(Mobject):
# infinitesimal nudged values alongside the original values. This way, one
# can perform all the manipulations they'd like to the surface, and normals
# are still easily recoverable.
self.points = np.vstack(point_lists)
self.set_points(np.vstack(point_lists))
def init_colors(self):
self.set_color(self.color, self.opacity)
def compute_triangle_indices(self):
# TODO, if there is an event which changes
@ -88,13 +96,10 @@ class ParametricSurface(Mobject):
def get_triangle_indices(self):
return self.triangle_indices
def init_colors(self):
self.rgbas = np.zeros((1, 4))
self.set_color(self.color, self.opacity)
def get_surface_points_and_nudged_points(self):
k = len(self.points) // 3
return self.points[:k], self.points[k:2 * k], self.points[2 * k:]
points = self.get_points()
k = len(points) // 3
return points[:k], points[k:2 * k], points[2 * k:]
def get_unit_normals(self):
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
@ -104,40 +109,38 @@ class ParametricSurface(Mobject):
)
return normalize_along_axis(normals, 1)
def set_color(self, color, opacity=1.0, family=True):
def set_color(self, color, opacity=None, family=True):
# TODO, allow for multiple colors
if opacity is None:
opacity = self.data["rgba"][0, 3]
rgba = color_to_rgba(color, opacity)
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.rgbas[:] = rgba
mob.data["rgba"][:] = rgba
return self
def get_color(self):
return rgb_to_color(self.rgbas[0, :3])
return rgb_to_color(self.data["rgba"][0, :3])
def set_opacity(self, opacity, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.rgbas[:, 3] = opacity
return self
def interpolate_color(self, mobject1, mobject2, alpha):
self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha)
mob.data["rgba"][:, 3] = opacity
return self
def pointwise_become_partial(self, smobject, a, b, axis=None):
if axis is None:
axis = self.prefered_creation_axis
assert(isinstance(smobject, ParametricSurface))
self.points[:] = smobject.points[:]
if a <= 0 and b >= 1:
self.set_points(smobject.points)
return self
nu, nv = smobject.resolution
self.points[:] = np.vstack([
self.set_points(np.vstack([
self.get_partial_points_array(arr, a, b, (nu, nv, 3), axis=axis)
for arr in self.get_surface_points_and_nudged_points()
])
for arr in smobject.get_surface_points_and_nudged_points()
]))
return self
def get_partial_points_array(self, points, a, b, resolution, axis):
@ -178,7 +181,8 @@ class ParametricSurface(Mobject):
return data
def fill_in_shader_color_info(self, data):
data["color"] = self.rgbas
# TODO, what if len(self.data["rgba"]) > 1?
data["color"] = self.data["rgba"]
return data
def get_shader_vert_indices(self):
@ -195,7 +199,7 @@ class SGroup(ParametricSurface):
self.add(*parametric_surfaces)
def init_points(self):
self.points = np.zeros((0, 3))
pass # Needed?
class TexturedSurface(ParametricSurface):
@ -229,40 +233,40 @@ class TexturedSurface(ParametricSurface):
self.u_range = uv_surface.u_range
self.v_range = uv_surface.v_range
self.resolution = uv_surface.resolution
self.gloss = self.uv_surface.gloss
super().__init__(self.uv_func, **kwargs)
def init_points(self):
self.points = self.uv_surface.points
# Init im_coords
def init_data(self):
nu, nv = self.uv_surface.resolution
u_range = np.linspace(0, 1, nu)
v_range = np.linspace(1, 0, nv) # Reverse y-direction
uv_grid = np.array([[u, v] for u in u_range for v in v_range])
self.im_coords = uv_grid
self.data = {
"points": self.uv_surface.get_points(),
"im_coords": np.array([
[u, v]
for u in np.linspace(0, 1, nu)
for v in np.linspace(1, 0, nv) # Reverse y-direction
]),
"opacity": np.array([self.uv_surface.data["rgba"][:, 3]]),
}
def init_colors(self):
self.opacity = self.uv_surface.rgbas[:, 3]
self.gloss = self.uv_surface.gloss
def interpolate_color(self, mobject1, mobject2, alpha):
# TODO, handle multiple textures
self.opacity = interpolate(mobject1.opacity, mobject2.opacity, alpha)
return self
pass
def set_opacity(self, opacity, family=True):
self.opacity = opacity
if family:
for sm in self.submobjects:
sm.set_opacity(opacity, family)
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.data["opacity"][:] = opacity
return self
def pointwise_become_partial(self, tsmobject, a, b, axis=1):
super().pointwise_become_partial(tsmobject, a, b, axis)
self.im_coords[:] = tsmobject.im_coords
im_coords = self.data["im_coords"]
im_coords[:] = tsmobject.data["im_coords"]
if a <= 0 and b >= 1:
return self
nu, nv = tsmobject.resolution
self.im_coords[:] = self.get_partial_points_array(self.im_coords, a, b, (nu, nv, 2), axis)
im_coords[:] = self.get_partial_points_array(
im_coords, a, b, (nu, nv, 2), axis
)
return self
def get_shader_uniforms(self):
@ -271,6 +275,6 @@ class TexturedSurface(ParametricSurface):
return result
def fill_in_shader_color_info(self, data):
data["im_coords"] = self.im_coords
data["opacity"] = self.opacity
data["im_coords"] = self.data["im_coords"]
data["opacity"] = self.data["opacity"]
return data

View File

@ -2,7 +2,6 @@ import itertools as it
import operator as op
import moderngl
from colour import Color
from functools import reduce
from manimlib.constants import *
@ -17,10 +16,12 @@ from manimlib.utils.bezier import set_array_by_interpolation
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import partial_quadratic_bezier_points
from manimlib.utils.color import color_to_rgba
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.iterables import make_even
from manimlib.utils.iterables import stretch_array_to_length
from manimlib.utils.iterables import stretch_array_to_length_with_interpolation
from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import resize_preserving_order
from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.iterables import listify
from manimlib.utils.paths import straight_path
from manimlib.utils.space_ops import angle_between_vectors
@ -77,68 +78,27 @@ class VMobject(Mobject):
def __init__(self, **kwargs):
self.unit_normal_locked = False
self.triangulation_locked = False
self.needs_new_triangulation = True
super().__init__(**kwargs)
self.lock_unit_normal(family=False)
self.lock_triangulation(family=False)
def get_group_class(self):
return VGroup
# To sort out later
def init_data(self):
self.fill_data = np.zeros(0, dtype=self.fill_dtype)
self.stroke_data = np.zeros(0, dtype=self.stroke_dtype)
def resize_data(self, new_length):
self.stroke_data = np.resize(self.stroke_data, new_length)
self.fill_data = np.resize(self.fill_data, new_length)
self.fill_data["vert_index"][:, 0] = range(new_length)
self.data = {
"points": np.zeros((0, 3)),
"fill_rgba": np.zeros((1, 4)),
"stroke_rgba": np.zeros((1, 4)),
"stroke_width": np.zeros((1, 1)),
}
def set_points(self, points):
if len(points) != len(self.stroke_data):
self.resize_data(len(points))
nppc = self.n_points_per_curve
self.stroke_data["point"] = points
self.stroke_data["prev_point"][:nppc] = points[-nppc:]
self.stroke_data["prev_point"][nppc:] = points[:-nppc]
self.stroke_data["next_point"][:-nppc] = points[nppc:]
self.stroke_data["next_point"][-nppc:] = points[:nppc]
self.fill_data["point"] = points
# # TODO, only do conditionally
# unit_normal = self.get_unit_normal()
# self.stroke_data["unit_normal"] = unit_normal
# self.fill_data["unit_normal"] = unit_normal
# self.refresh_triangulation()
def get_points(self):
return self.stroke_data["point"]
def get_all_point_arrays(self):
return [
self.fill_data["point"],
self.stroke_data["point"],
self.stroke_data["prev_point"],
self.stroke_data["next_point"],
]
def get_all_data_arrays(self):
return [self.fill_data, self.stroke_data]
def get_data_array_attrs(self):
return ["fill_data", "stroke_data"]
def get_num_points(self):
return len(self.stroke_data)
super().set_points(points)
self.refresh_triangulation()
# Colors
def init_colors(self):
self.fill_rgbas = np.zeros((1, 4))
self.stroke_rgbas = np.zeros((1, 4))
self.set_fill(
color=self.fill_color or self.color,
opacity=self.fill_opacity,
@ -153,69 +113,45 @@ class VMobject(Mobject):
self.set_flat_stroke(self.flat_stroke)
return self
def generate_rgba_array(self, color, opacity):
"""
First arg can be either a color, or a tuple/list of colors.
Likewise, opacity can either be a float, or a tuple of floats.
"""
colors = listify(color)
opacities = listify(opacity)
return np.array([
color_to_rgba(c, o)
for c, o in zip(*make_even(colors, opacities))
])
def update_rgbas_array(self, array_name, color, opacity):
rgbas = self.generate_rgba_array(color or BLACK, opacity or 0)
# Match up current rgbas array with the newly calculated
# one. 99% of the time they'll be the same.
curr_rgbas = getattr(self, array_name)
if len(curr_rgbas) < len(rgbas):
curr_rgbas = stretch_array_to_length(curr_rgbas, len(rgbas))
setattr(self, array_name, curr_rgbas)
elif len(rgbas) < len(curr_rgbas):
rgbas = stretch_array_to_length(rgbas, len(curr_rgbas))
# Only update rgb if color was not None, and only
# update alpha channel if opacity was passed in
if color is not None:
curr_rgbas[:, :3] = rgbas[:, :3]
def set_rgba_array(self, name, color, opacity, family=True):
# TODO, account for if color or opacity are tuples
rgb = color_to_rgb(color) if color else None
mobs = self.get_family() if family else [self]
for mob in mobs:
if rgb is not None:
mob.data[name][:, :3] = rgb
if opacity is not None:
curr_rgbas[:, 3] = rgbas[:, 3]
mob.data[name][:, 3] = opacity
return self
def set_fill(self, color=None, opacity=None, family=True):
if family:
for sm in self.submobjects:
sm.set_fill(color, opacity, family)
self.update_rgbas_array("fill_rgbas", color, opacity)
return self
self.set_rgba_array('fill_rgba', color, opacity, family)
def set_stroke(self, color=None, width=None, opacity=None,
background=None, family=True):
if family:
for sm in self.submobjects:
sm.set_stroke(color, width, opacity, background, family)
self.update_rgbas_array("stroke_rgbas", color, opacity)
def set_stroke(self, color=None, width=None, opacity=None, background=None, family=True):
self.set_rgba_array('stroke_rgba', color, opacity, family)
mobs = self.get_family() if family else [self]
for mob in mobs:
if width is not None:
self.stroke_width = np.array(listify(width), dtype=float)
# TODO, account for if width is an array
mob.data['stroke_width'][:] = width
if background is not None:
self.draw_stroke_behind_fill = background
mob.draw_stroke_behind_fill = background
return self
def set_style(self,
fill_color=None,
fill_opacity=None,
fill_rgbas=None,
fill_rgba=None,
stroke_color=None,
stroke_opacity=None,
stroke_rgbas=None,
stroke_rgba=None,
stroke_width=None,
gloss=None,
shadow=None,
background_image_file=None,
family=True):
if fill_rgbas is not None:
self.fill_rgbas = np.array(fill_rgbas)
if fill_rgba is not None:
self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
else:
self.set_fill(
color=fill_color,
@ -223,10 +159,9 @@ class VMobject(Mobject):
family=family
)
if stroke_rgbas is not None:
self.stroke_rgbas = np.array(stroke_rgbas)
if stroke_width is not None:
self.stroke_width = np.array(listify(stroke_width))
if stroke_rgba is not None:
self.data['stroke_rgba'] = resize_with_interpolation(stroke_rgba, len(fill_rgba))
self.set_stroke(width=stroke_width)
else:
self.set_stroke(
color=stroke_color,
@ -239,31 +174,19 @@ class VMobject(Mobject):
self.set_gloss(gloss, family=family)
if shadow is not None:
self.set_shadow(shadow, family=family)
if background_image_file:
self.color_using_background_image(background_image_file)
return self
def get_style(self):
return {
"fill_rgbas": self.get_fill_rgbas(),
"stroke_rgbas": self.get_stroke_rgbas(),
"stroke_width": self.stroke_width,
"fill_rgba": self.data['fill_rgba'],
"stroke_rgba": self.data['stroke_rgba'],
"stroke_width": self.data['stroke_width'],
"gloss": self.get_gloss(),
"shadow": self.get_shadow(),
"background_image_file": self.get_background_image_file(),
}
def match_style(self, vmobject, family=True):
for name, value in vmobject.get_style().items():
if isinstance(value, np.ndarray):
curr = getattr(self, name)
if curr.size == value.size:
curr[:] = value[:]
else:
setattr(self, name, np.array(value))
else:
setattr(self, name, value)
self.set_style(**vmobject.get_style(), family=False)
if family:
# Does its best to match up submobject lists, and
# match styles accordingly
@ -299,12 +222,29 @@ class VMobject(Mobject):
super().fade(darkness, family)
return self
def get_fill_rgbas(self):
try:
return self.fill_rgbas
except AttributeError:
return np.zeros((1, 4))
def get_fill_colors(self):
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['fill_rgba']
]
def get_fill_opacities(self):
return self.data['fill_rgba'][:, 3]
def get_stroke_colors(self):
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['stroke_rgba']
]
def get_stroke_opacities(self):
return self.data['stroke_rgba'][:, 3]
def get_stroke_widths(self):
return self.data['stroke_width']
# TODO, it's weird for these to return the first of various lists
# rather than the full information
def get_fill_color(self):
"""
If there are multiple colors (for gradient)
@ -319,62 +259,25 @@ class VMobject(Mobject):
"""
return self.get_fill_opacities()[0]
def get_fill_colors(self):
return [
Color(rgb=rgba[:3])
for rgba in self.get_fill_rgbas()
]
def get_fill_opacities(self):
return self.get_fill_rgbas()[:, 3]
def get_stroke_rgbas(self):
try:
return self.stroke_rgbas
except AttributeError:
return np.zeros((1, 4))
# TODO, it's weird for these to return the first of various lists
# rather than the full information
def get_stroke_color(self):
return self.get_stroke_colors()[0]
def get_stroke_width(self):
return self.stroke_width[0]
return self.get_stroke_widths()[0]
def get_stroke_opacity(self):
return self.get_stroke_opacities()[0]
def get_stroke_colors(self):
return [
rgb_to_hex(rgba[:3])
for rgba in self.get_stroke_rgbas()
]
def get_stroke_opacities(self):
return self.get_stroke_rgbas()[:, 3]
def get_color(self):
if np.all(self.get_fill_opacities() == 0):
if self.has_stroke():
return self.get_stroke_color()
return self.get_fill_color()
def has_stroke(self):
if len(self.stroke_width) == 1:
if self.stroke_width == 0:
return False
elif not self.stroke_width.any():
return False
alphas = self.stroke_rgbas[:, 3]
if len(alphas) == 1:
return alphas[0] > 0
return alphas.any()
return any(self.get_stroke_widths()) and any(self.get_stroke_opacities())
def has_fill(self):
alphas = self.fill_rgbas[:, 3]
if len(alphas) == 1:
return alphas[0] > 0
return alphas.any()
return any(self.get_fill_opacities())
def get_opacity(self):
if self.has_fill():
@ -382,45 +285,14 @@ class VMobject(Mobject):
return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke=True, family=True):
self.flat_stroke = flat_stroke
if family:
for submob in self.submobjects:
submob.set_flat_stroke(flat_stroke, family)
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.flat_stroke = flat_stroke
return self
def get_flat_stroke(self):
return self.flat_stroke
# TODO, this currently does nothing
def color_using_background_image(self, background_image_file):
self.background_image_file = background_image_file
self.set_color(WHITE)
for submob in self.submobjects:
submob.color_using_background_image(background_image_file)
return self
def get_background_image_file(self):
return self.background_image_file
def match_background_image_file(self, vmobject):
self.color_using_background_image(vmobject.get_background_image_file())
return self
def stretched_style_array_matching_points(self, array):
new_len = self.get_num_points()
long_arr = stretch_array_to_length_with_interpolation(
array, 1 + 2 * (new_len // 3)
)
shape = array.shape
if len(shape) > 1:
result = np.zeros((new_len, shape[1]))
else:
result = np.zeros(new_len)
result[0::3] = long_arr[0:-1:2]
result[1::3] = long_arr[1::2]
result[2::3] = long_arr[2::2]
return result
# Points
def set_anchors_and_handles(self, anchors1, handles, anchors2):
assert(len(anchors1) == len(handles) == len(anchors2))
@ -436,7 +308,8 @@ class VMobject(Mobject):
# TODO, check that number new points is a multiple of 4?
# or else that if self.get_num_points() % 4 == 1, then
# len(new_points) % 4 == 3?
self.set_points(np.vstack([self.get_points(), new_points]))
self.resize_points(self.get_num_points() + len(new_points))
self.data["points"][-len(new_points):] = new_points
return self
def start_new_path(self, point):
@ -784,7 +657,6 @@ class VMobject(Mobject):
# Alignment
def align_points(self, vmobject):
self.align_rgbas(vmobject)
if self.get_num_points() == len(vmobject.get_points()):
return
@ -871,39 +743,20 @@ class VMobject(Mobject):
new_points += partial_quadratic_bezier_points(group, a1, a2)
return np.vstack(new_points)
def align_rgbas(self, vmobject):
attrs = ["fill_rgbas", "stroke_rgbas"]
for attr in attrs:
a1 = getattr(self, attr)
a2 = getattr(vmobject, attr)
if len(a1) > len(a2):
new_a2 = stretch_array_to_length(a2, len(a1))
setattr(vmobject, attr, new_a2)
elif len(a2) > len(a1):
new_a1 = stretch_array_to_length(a1, len(a2))
setattr(self, attr, new_a1)
def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs):
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill():
tri1 = mobject1.get_triangulation()
tri2 = mobject2.get_triangulation()
if len(tri1) != len(tri1) or not all(tri1 == tri2):
self.refresh_triangulation()
return self
def interpolate_color(self, mobject1, mobject2, alpha):
attrs = [
"fill_rgbas",
"stroke_rgbas",
"stroke_width",
]
for attr in attrs:
set_array_by_interpolation(
getattr(self, attr),
getattr(mobject1, attr),
getattr(mobject2, attr),
alpha
)
def pointwise_become_partial(self, vmobject, a, b):
assert(isinstance(vmobject, VMobject))
self.set_points(vmobject.get_points())
if a <= 0 and b >= 1:
return self
num_curves = self.get_num_curves()
num_curves = vmobject.get_num_curves()
nppc = self.n_points_per_curve
# Partial curve includes three portions:
@ -918,26 +771,26 @@ class VMobject(Mobject):
i3 = nppc * upper_index
i4 = nppc * (upper_index + 1)
points = self.get_points()
vm_points = vmobject.get_points()
new_points = vm_points.copy()
if num_curves == 0:
points[:] = 0
new_points[:] = 0
return self
if lower_index == upper_index:
tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, upper_residue)
points[:i1] = tup[0]
points[i1:i4] = tup
points[i4:] = tup[2]
points[nppc:] = points[nppc - 1]
new_points[:i1] = tup[0]
new_points[i1:i4] = tup
new_points[i4:] = tup[2]
new_points[nppc:] = new_points[nppc - 1]
else:
low_tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, 1)
high_tup = partial_quadratic_bezier_points(vm_points[i3:i4], 0, upper_residue)
points[0:i1] = low_tup[0]
points[i1:i2] = low_tup
# Keep points i2:i3 as they are
points[i3:i4] = high_tup
points[i4:] = high_tup[2]
self.set_points(points)
new_points[0:i1] = low_tup[0]
new_points[i1:i2] = low_tup
# Keep new_points i2:i3 as they are
new_points[i3:i4] = high_tup
new_points[i4:] = high_tup[2]
self.set_points(new_points)
return self
def get_subcurve(self, a, b):
@ -945,14 +798,10 @@ class VMobject(Mobject):
vmob.pointwise_become_partial(self, a, b)
return vmob
def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path):
super().interpolate(mobject1, mobject2, alpha, path_func)
if not np.all(mobject1.get_triangulation() == mobject2.get_triangulation()):
self.refresh_triangulation()
return self
# For shaders
def init_shader_data(self):
self.fill_data = np.zeros(0, dtype=self.fill_dtype)
self.stroke_data = np.zeros(0, dtype=self.stroke_dtype)
self.fill_shader_wrapper = ShaderWrapper(
vert_data=self.fill_data,
vert_indices=np.zeros(0, dtype='i4'),
@ -1019,39 +868,27 @@ class VMobject(Mobject):
return result
def get_stroke_shader_data(self):
# TODO, make even simpler after fixing colors
rgbas = self.get_stroke_rgbas()
if len(rgbas) > 1:
rgbas = self.stretched_style_array_matching_points(rgbas)
points = self.get_points()
if len(self.stroke_data) != len(points):
self.stroke_data = resize_array(self.stroke_data, len(points))
# TODO, account for when self.data["stroke_width"] and self.data["stroke_rgba"]
# have length greater than 1
stroke_width = self.stroke_width
if len(stroke_width) > 1:
stroke_width = self.stretched_style_array_matching_points(stroke_width)
nppc = self.n_points_per_curve
self.stroke_data["point"] = points
self.stroke_data["prev_point"][:nppc] = points[-nppc:]
self.stroke_data["prev_point"][nppc:] = points[:-nppc]
self.stroke_data["next_point"][:-nppc] = points[nppc:]
self.stroke_data["next_point"][-nppc:] = points[:nppc]
data = self.stroke_data
data["stroke_width"][:, 0] = stroke_width
data["color"] = rgbas
return data
def lock_triangulation(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.triangulation_locked = False
mob.saved_triangulation = mob.get_triangulation()
mob.triangulation_locked = True
return self
def unlock_triangulation(self):
for sm in self.get_family():
sm.triangulation_locked = False
return self
self.stroke_data["unit_normal"] = self.get_unit_normal()
self.stroke_data["stroke_width"] = self.data["stroke_width"]
self.stroke_data["color"] = self.data["stroke_rgba"]
return self.stroke_data
def refresh_triangulation(self):
for mob in self.get_family():
if mob.triangulation_locked:
mob.triangulation_locked = False
mob.saved_triangulation = mob.get_triangulation()
mob.triangulation_locked = True
mob.needs_new_triangulation = True
return self
def get_triangulation(self, normal_vector=None):
@ -1061,13 +898,14 @@ class VMobject(Mobject):
if normal_vector is None:
normal_vector = self.get_unit_normal()
if self.triangulation_locked:
return self.saved_triangulation
if not self.needs_new_triangulation:
return self.saved_traignulation
points = self.get_points()
if len(points) <= 1:
return np.zeros(0, dtype='i4')
self.saved_traignulation == np.zeros(0, dtype='i4')
return self.saved_traignulation
# Rotate points such that unit normal vector is OUT
# TODO, 99% of the time this does nothing. Do a check for that?
@ -1104,14 +942,20 @@ class VMobject(Mobject):
inner_tri_indices = inner_vert_indices[earclip_triangulation(inner_verts, rings)]
tri_indices = np.hstack([indices, inner_tri_indices])
self.saved_traignulation = tri_indices
self.needs_new_triangulation = False
return tri_indices
def get_fill_shader_data(self):
# TODO, make simpler
rgbas = self.get_fill_rgbas()[:1]
data = self.fill_data
data["color"] = rgbas
return data
points = self.get_points()
if len(self.fill_data) != len(points):
self.fill_data = resize_array(self.fill_data, len(points))
self.fill_data["vert_index"][:, 0] = range(len(points))
self.fill_data["point"] = points
self.fill_data["unit_normal"] = self.get_unit_normal()
self.fill_data["color"] = self.data["fill_rgba"]
return self.fill_data
def get_fill_shader_vert_indices(self):
return self.get_triangulation()
@ -1121,11 +965,11 @@ class VGroup(VMobject):
def __init__(self, *vmobjects, **kwargs):
if not all([isinstance(m, VMobject) for m in vmobjects]):
raise Exception("All submobjects must be of type VMobject")
VMobject.__init__(self, **kwargs)
super().__init__(**kwargs)
self.add(*vmobjects)
class VectorizedPoint(VMobject, Point):
class VectorizedPoint(Point, VMobject):
CONFIG = {
"color": BLACK,
"fill_opacity": 0,
@ -1135,13 +979,13 @@ class VectorizedPoint(VMobject, Point):
}
def __init__(self, location=ORIGIN, **kwargs):
VMobject.__init__(self, **kwargs)
super().__init__(**kwargs)
self.set_points(np.array([location]))
class CurvesAsSubmobjects(VGroup):
def __init__(self, vmobject, **kwargs):
VGroup.__init__(self, **kwargs)
super().__init__(**kwargs)
for tup in vmobject.get_bezier_tuples():
part = VMobject()
part.set_points(tup)
@ -1157,7 +1001,7 @@ class DashedVMobject(VMobject):
}
def __init__(self, vmobject, **kwargs):
VMobject.__init__(self, **kwargs)
super().__init__(**kwargs)
num_dashes = self.num_dashes
ps_ratio = self.positive_space_ratio
if num_dashes > 0:

View File

@ -74,8 +74,8 @@ def interpolate(start, end, alpha):
sys.exit(2)
def set_array_by_interpolation(arr, arr1, arr2, alpha):
arr[:] = interpolate(arr1, arr2, alpha)
def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate):
arr[:] = interp_func(arr1, arr2, alpha)
return arr

View File

@ -80,15 +80,23 @@ def listify(obj):
return [obj]
def stretch_array_to_length(nparray, length):
# TODO, rename to "resize"?
def resize_array(nparray, length):
return np.resize(nparray, (length, *nparray.shape[1:]))
def resize_preserving_order(nparray, length):
if len(nparray) == 0:
return np.zeros((length, *nparray.shape[1:]))
if len(nparray) == length:
return nparray
indices = np.arange(length) * len(nparray) // length
return nparray[indices]
def stretch_array_to_length_with_interpolation(nparray, length):
curr_len = len(nparray)
cont_indices = np.linspace(0, curr_len - 1, length)
def resize_with_interpolation(nparray, length):
if len(nparray) == length:
return nparray
cont_indices = np.linspace(0, len(nparray) - 1, length)
return np.array([
(1 - a) * nparray[lh] + a * nparray[rh]
for ci in cont_indices