From 9314dfd933d19e6bc8e1884553db615fff3a901b Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 11 Jan 2021 10:57:23 -1000 Subject: [PATCH] Reframe Mobject, VMobject and SurfaceMobject with a data map --- manimlib/animation/rotation.py | 2 +- manimlib/mobject/geometry.py | 5 +- manimlib/mobject/mobject.py | 140 +++--- manimlib/mobject/svg/svg_mobject.py | 1 + manimlib/mobject/types/dot_cloud.py | 6 +- manimlib/mobject/types/point_cloud_mobject.py | 14 +- manimlib/mobject/types/surface.py | 92 ++-- manimlib/mobject/types/vectorized_mobject.py | 422 ++++++------------ manimlib/utils/bezier.py | 4 +- manimlib/utils/iterables.py | 18 +- 10 files changed, 276 insertions(+), 428 deletions(-) diff --git a/manimlib/animation/rotation.py b/manimlib/animation/rotation.py index e668bba2..ad9f2b7d 100644 --- a/manimlib/animation/rotation.py +++ b/manimlib/animation/rotation.py @@ -24,7 +24,7 @@ class Rotating(Animation): def interpolate_mobject(self, alpha): for sm1, sm2 in self.get_all_families_zipped(): - sm1.points[:] = sm2.points + sm1.set_points(sm2.get_points()) self.mobject.rotate( alpha * self.angle, axis=self.axis, diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index 43496c6a..47392494 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -310,8 +310,7 @@ class Dot(Circle): } def __init__(self, point=ORIGIN, **kwargs): - Circle.__init__(self, arc_center=point, **kwargs) - self.lock_triangulation() + super().__init__(arc_center=point, **kwargs) class SmallDot(Dot): @@ -327,7 +326,7 @@ class Ellipse(Circle): } def __init__(self, **kwargs): - Circle.__init__(self, **kwargs) + super().__init__(**kwargs) self.set_width(self.width, stretch=True) self.set_height(self.height, stretch=True) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index a81705e7..93f0e31c 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -15,7 +15,10 @@ from manimlib.utils.color import get_colormap_list from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import list_update +from manimlib.utils.iterables import resize_array +from manimlib.utils.iterables import resize_preserving_order from manimlib.utils.bezier import interpolate +from manimlib.utils.bezier import set_array_by_interpolation from manimlib.utils.paths import straight_path from manimlib.utils.simple_functions import get_parameters from manimlib.utils.space_ops import angle_of_vector @@ -34,7 +37,7 @@ class Mobject(object): """ CONFIG = { "color": WHITE, - "dim": 3, + "dim": 3, # TODO, get rid of this # Lighting parameters # Positive gloss up to 1 makes it reflect the light. "gloss": 0.0, @@ -42,9 +45,6 @@ class Mobject(object): "shadow": 0.0, # For shaders "shader_folder": "", - # "vert_shader_file": "", - # "geom_shader_file": "", - # "frag_shader_file": "", "render_primitive": moderngl.TRIANGLE_STRIP, "texture_paths": None, "depth_test": False, @@ -62,8 +62,8 @@ class Mobject(object): self.parents = [] self.family = [self] - self.init_updaters() self.init_data() + self.init_updaters() self.init_points() self.init_colors() self.init_shader_data() @@ -84,35 +84,32 @@ class Mobject(object): # Typically implemented in subclass, unless purposefully left blank pass - # To sort out later + # Related to data dict def init_data(self): - self.data = np.zeros(0, dtype=self.shader_dtype) + self.data = { + "points": np.zeros((0, 3)), + } - def resize_data(self, new_length): - if new_length != len(self.data): - self.data = np.resize(self.data, new_length) + def set_data(self, data): + for key in data: + self.data[key] = data[key] + + def resize_points(self, new_length, resize_func=resize_array): + if new_length != len(self.data["points"]): + self.data["points"] = resize_func(self.data["points"], new_length) def set_points(self, points): - self.resize_data(len(points)) - self.data["point"] = points + self.resize_points(len(points)) + self.data["points"][:] = points def get_points(self): - return self.data["point"] - - def get_all_point_arrays(self): - return [self.data["point"]] - - def get_all_data_arrays(self): - return [self.data] - - def get_data_array_attrs(self): - return ["data"] + return self.data["points"] def clear_points(self): - self.resize_data(0) + self.resize_points(0) def get_num_points(self): - return len(self.data) + return len(self.data["points"]) # # Family matters @@ -212,8 +209,9 @@ class Mobject(object): copy_mobject = copy.copy(self) self.parents = parents - for attr in self.get_data_array_attrs(): - setattr(copy_mobject, attr, getattr(self, attr).copy()) + copy_mobject.data = dict(self.data) + for key in self.data: + copy_mobject.data[key] = self.data[key].copy() copy_mobject.submobjects = [] copy_mobject.add(*[sm.copy() for sm in self.submobjects]) copy_mobject.match_updaters(self) @@ -224,7 +222,7 @@ class Mobject(object): if isinstance(value, Mobject) and value in family and value is not self: setattr(copy_mobject, attr, value.copy()) if isinstance(value, np.ndarray): - setattr(copy_mobject, attr, np.array(value)) + setattr(copy_mobject, attr, value.copy()) if isinstance(value, ShaderWrapper): setattr(copy_mobject, attr, value.copy()) return copy_mobject @@ -343,8 +341,7 @@ class Mobject(object): def shift(self, *vectors): total_vector = reduce(op.add, vectors) for mob in self.get_family(): - for arr in mob.get_all_point_arrays(): - arr += total_vector + mob.set_points(mob.get_points() + total_vector) return self def scale(self, scale_factor, **kwargs): @@ -458,8 +455,8 @@ class Mobject(object): about_edge = ORIGIN about_point = self.get_bounding_box_point(about_edge) for mob in self.family_members_with_points(): - for arr in mob.get_all_point_arrays(): - arr[:] = func(arr - about_point) + about_point + points = mob.get_points() + points[:] = func(points - about_point) + about_point return self # Positioning methods @@ -512,8 +509,7 @@ class Mobject(object): else: aligner = self point_to_align = aligner.get_bounding_box_point(aligned_edge - direction) - self.shift((target_point - point_to_align + - buff * direction) * coor_mask) + self.shift((target_point - point_to_align + buff * direction) * coor_mask) return self def shift_onto_screen(self, **kwargs): @@ -807,11 +803,9 @@ class Mobject(object): else: return getattr(self, array_attr) - def get_all_points(self): # TODO, use get_all_point_arrays? + def get_all_points(self): if self.submobjects: - return np.vstack([ - sm.get_points() for sm in self.get_family() - ]) + return np.vstack([sm.get_points() for sm in self.get_family()]) else: return self.get_points() @@ -1107,20 +1101,25 @@ class Mobject(object): self.null_point_align(mobject) # Needed? self.align_submobjects(mobject) for mob1, mob2 in zip(self.get_family(), mobject.get_family()): + # Separate out how points are treated so that subclasses + # can handle that case differently if they choose mob1.align_points(mob2) + for key in mob1.data: + if key == "points": + continue + arr1 = mob1.data[key] + arr2 = mob2.data[key] + if len(arr2) > len(arr1): + mob1.data[key] = resize_preserving_order(arr1, len(arr2)) + elif len(arr1) > len(arr2): + mob2.data[key] = resize_preserving_order(arr2, len(arr1)) def align_points(self, mobject): - count1 = self.get_num_points() - count2 = mobject.get_num_points() - if count1 < count2: - self.align_points_with_larger(mobject) - elif count2 < count1: - mobject.align_points_with_larger(self) + max_len = max(self.get_num_points(), mobject.get_num_points()) + self.resize_points(max_len, resize_func=resize_preserving_order) + mobject.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_points_with_larger(self, larger_mobject): - raise Exception("Not implemented") - def align_submobjects(self, mobject): mob1 = self mob2 = mobject @@ -1188,19 +1187,16 @@ class Mobject(object): Turns self into an interpolation between mobject1 and mobject2. """ - mobs = [self, mobject1, mobject2] - # for arr, arr1, arr2 in zip(*(m.get_all_data_arrays() for m in mobs)): - # arr[:] = interpolate(arr1, arr2, alpha) - # if path_func is not straight_path: - for arr, arr1, arr2 in zip(*(m.get_all_point_arrays() for m in mobs)): - arr[:] = path_func(arr1, arr2, alpha) - # self.interpolate_color(mobject1, mobject2, alpha) + for key in self.data: + func = path_func if key == "points" else interpolate + self.data[key][:] = func( + mobject1.data[key], + mobject2.data[key], + alpha + ) self.interpolate_light_style(mobject1, mobject2, alpha) # TODO, interpolate uniforms instaed return self - def interpolate_color(self, mobject1, mobject2, alpha): - pass # To implement in subclass - def interpolate_light_style(self, mobject1, mobject2, alpha): g0 = self.get_gloss() g1 = mobject1.get_gloss() @@ -1236,8 +1232,7 @@ class Mobject(object): """ self.align_submobjects(mobject) for sm1, sm2 in zip(self.get_family(), mobject.get_family()): - for arr1, arr2 in zip(sm1.get_all_data_arrays(), sm2.get_all_data_arrays()): - arr1[:] = arr2 + sm1.set_data(sm2.data) return self def cleanup_from_animation(self): @@ -1320,9 +1315,10 @@ class Mobject(object): # For shader data def init_shader_data(self): + self.shader_data = np.zeros(len(self.get_points()), dtype=self.shader_dtype) self.shader_indices = None self.shader_wrapper = ShaderWrapper( - vert_data=self.data, + vert_data=self.shader_data, shader_folder=self.shader_folder, texture_paths=self.texture_paths, depth_test=self.depth_test, @@ -1333,18 +1329,18 @@ class Mobject(object): self.shader_wrapper.refresh_id() return self - # def get_blank_shader_data_array(self, size, name="data"): - # # If possible, try to populate an existing array, rather - # # than recreating it each frame - # arr = getattr(self, name) - # if arr.size != size: - # new_arr = np.resize(arr, size) - # setattr(self, name, new_arr) - # return new_arr - # return arr + def get_blank_shader_data_array(self, size, name="shader_data"): + # If possible, try to populate an existing array, rather + # than recreating it each frame + arr = getattr(self, name) + if arr.size != size: + new_arr = resize_array(arr, size) + setattr(self, name, new_arr) + return new_arr + return arr def get_shader_wrapper(self): - self.shader_wrapper.vert_data = self.data # TODO + self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.uniforms = self.get_shader_uniforms() self.shader_wrapper.depth_test = self.depth_test @@ -1367,6 +1363,10 @@ class Mobject(object): result.append(shader_wrapper) return result + def get_shader_data(self): + # May be different for subclasses + return self.shader_data + def get_shader_uniforms(self): return { "is_fixed_in_frame": float(self.is_fixed_in_frame), @@ -1411,7 +1411,7 @@ class Point(Mobject): return self.artificial_height def get_location(self): - return np.array(self.get_points()[0]) + return self.get_points()[0].copy() def get_bounding_box_point(self, *args, **kwargs): return self.get_location() diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 09d69f8d..3eaf86e0 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -281,6 +281,7 @@ class SVGMobject(VMobject): matrix[:, 1] *= -1 for mob in mobject.family_members_with_points(): + # TODO, directly apply matrix? mob.set_points(np.dot(mob.get_points(), matrix)) mobject.shift(x * RIGHT + y * UP) except: diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 7cf93f31..e9c1b80c 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -7,7 +7,7 @@ from manimlib.constants import ORIGIN from manimlib.mobject.types.point_cloud_mobject import PMobject from manimlib.mobject.geometry import DEFAULT_DOT_RADIUS from manimlib.utils.bezier import interpolate -from manimlib.utils.iterables import stretch_array_to_length +from manimlib.utils.iterables import resize_preserving_order class DotCloud(PMobject): @@ -32,7 +32,7 @@ class DotCloud(PMobject): def set_points(self, points): super().set_points(points) - self.radii = stretch_array_to_length(self.radii, len(points)) + self.radii = resize_preserving_order(self.radii, len(points)) return self def set_points_by_grid(self, n_rows, n_cols, height=None, width=None): @@ -58,7 +58,7 @@ class DotCloud(PMobject): if isinstance(radii, numbers.Number): self.radii[:] = radii else: - self.radii = stretch_array_to_length(radii, len(self.points)) + self.radii = resize_preserving_order(radii, len(self.points)) return self def scale(self, scale_factor, scale_radii=True, **kwargs): diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index d4a12997..c5fd8ac2 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -4,7 +4,7 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.color import color_gradient from manimlib.utils.color import color_to_rgba from manimlib.utils.color import rgba_to_color -from manimlib.utils.iterables import stretch_array_to_length +from manimlib.utils.iterables import resize_preserving_order class PMobject(Mobject): @@ -18,7 +18,7 @@ class PMobject(Mobject): def set_points(self, points): self.points = points - self.rgbas = stretch_array_to_length(self.rgbas, len(points)) + self.rgbas = resize_preserving_order(self.rgbas, len(points)) return self def add_points(self, points, rgbas=None, color=None, alpha=1): @@ -91,7 +91,7 @@ class PMobject(Mobject): return self def match_colors(self, pmobject): - self.rgbas[:] = stretch_array_to_length(pmobject.rgbas, len(self.points)) + self.rgbas[:] = resize_preserving_order(pmobject.rgbas, len(self.points)) return self def filter_out(self, condition): @@ -138,14 +138,6 @@ class PMobject(Mobject): return self.points[index] # Alignment - def align_points_with_larger(self, larger_mobject): - assert(isinstance(larger_mobject, PMobject)) - self.apply_over_attr_arrays( - lambda a: stretch_array_to_length( - a, larger_mobject.get_num_points() - ) - ) - def interpolate_color(self, mobject1, mobject2, alpha): self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha) return self diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index ec88f1d1..99f59849 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -44,7 +44,12 @@ class ParametricSurface(Mobject): self.uv_func = uv_func self.compute_triangle_indices() super().__init__(**kwargs) - self.sort_faces_back_to_front() + + def init_data(self): + self.data = { + "points": np.zeros((0, 3)), + "rgba": np.zeros((1, 4)), + } def init_points(self): dim = self.dim @@ -65,7 +70,10 @@ class ParametricSurface(Mobject): # infinitesimal nudged values alongside the original values. This way, one # can perform all the manipulations they'd like to the surface, and normals # are still easily recoverable. - self.points = np.vstack(point_lists) + self.set_points(np.vstack(point_lists)) + + def init_colors(self): + self.set_color(self.color, self.opacity) def compute_triangle_indices(self): # TODO, if there is an event which changes @@ -88,13 +96,10 @@ class ParametricSurface(Mobject): def get_triangle_indices(self): return self.triangle_indices - def init_colors(self): - self.rgbas = np.zeros((1, 4)) - self.set_color(self.color, self.opacity) - def get_surface_points_and_nudged_points(self): - k = len(self.points) // 3 - return self.points[:k], self.points[k:2 * k], self.points[2 * k:] + points = self.get_points() + k = len(points) // 3 + return points[:k], points[k:2 * k], points[2 * k:] def get_unit_normals(self): s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() @@ -104,40 +109,38 @@ class ParametricSurface(Mobject): ) return normalize_along_axis(normals, 1) - def set_color(self, color, opacity=1.0, family=True): + def set_color(self, color, opacity=None, family=True): # TODO, allow for multiple colors + if opacity is None: + opacity = self.data["rgba"][0, 3] rgba = color_to_rgba(color, opacity) mobs = self.get_family() if family else [self] for mob in mobs: - mob.rgbas[:] = rgba + mob.data["rgba"][:] = rgba return self def get_color(self): - return rgb_to_color(self.rgbas[0, :3]) + return rgb_to_color(self.data["rgba"][0, :3]) def set_opacity(self, opacity, family=True): mobs = self.get_family() if family else [self] for mob in mobs: - mob.rgbas[:, 3] = opacity - return self - - def interpolate_color(self, mobject1, mobject2, alpha): - self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha) + mob.data["rgba"][:, 3] = opacity return self def pointwise_become_partial(self, smobject, a, b, axis=None): if axis is None: axis = self.prefered_creation_axis assert(isinstance(smobject, ParametricSurface)) - self.points[:] = smobject.points[:] if a <= 0 and b >= 1: + self.set_points(smobject.points) return self nu, nv = smobject.resolution - self.points[:] = np.vstack([ + self.set_points(np.vstack([ self.get_partial_points_array(arr, a, b, (nu, nv, 3), axis=axis) - for arr in self.get_surface_points_and_nudged_points() - ]) + for arr in smobject.get_surface_points_and_nudged_points() + ])) return self def get_partial_points_array(self, points, a, b, resolution, axis): @@ -178,7 +181,8 @@ class ParametricSurface(Mobject): return data def fill_in_shader_color_info(self, data): - data["color"] = self.rgbas + # TODO, what if len(self.data["rgba"]) > 1? + data["color"] = self.data["rgba"] return data def get_shader_vert_indices(self): @@ -195,7 +199,7 @@ class SGroup(ParametricSurface): self.add(*parametric_surfaces) def init_points(self): - self.points = np.zeros((0, 3)) + pass # Needed? class TexturedSurface(ParametricSurface): @@ -229,40 +233,40 @@ class TexturedSurface(ParametricSurface): self.u_range = uv_surface.u_range self.v_range = uv_surface.v_range self.resolution = uv_surface.resolution + self.gloss = self.uv_surface.gloss super().__init__(self.uv_func, **kwargs) - def init_points(self): - self.points = self.uv_surface.points - # Init im_coords + def init_data(self): nu, nv = self.uv_surface.resolution - u_range = np.linspace(0, 1, nu) - v_range = np.linspace(1, 0, nv) # Reverse y-direction - uv_grid = np.array([[u, v] for u in u_range for v in v_range]) - self.im_coords = uv_grid + self.data = { + "points": self.uv_surface.get_points(), + "im_coords": np.array([ + [u, v] + for u in np.linspace(0, 1, nu) + for v in np.linspace(1, 0, nv) # Reverse y-direction + ]), + "opacity": np.array([self.uv_surface.data["rgba"][:, 3]]), + } def init_colors(self): - self.opacity = self.uv_surface.rgbas[:, 3] - self.gloss = self.uv_surface.gloss - - def interpolate_color(self, mobject1, mobject2, alpha): - # TODO, handle multiple textures - self.opacity = interpolate(mobject1.opacity, mobject2.opacity, alpha) - return self + pass def set_opacity(self, opacity, family=True): - self.opacity = opacity - if family: - for sm in self.submobjects: - sm.set_opacity(opacity, family) + mobs = self.get_family() if family else [self] + for mob in mobs: + mob.data["opacity"][:] = opacity return self def pointwise_become_partial(self, tsmobject, a, b, axis=1): super().pointwise_become_partial(tsmobject, a, b, axis) - self.im_coords[:] = tsmobject.im_coords + im_coords = self.data["im_coords"] + im_coords[:] = tsmobject.data["im_coords"] if a <= 0 and b >= 1: return self nu, nv = tsmobject.resolution - self.im_coords[:] = self.get_partial_points_array(self.im_coords, a, b, (nu, nv, 2), axis) + im_coords[:] = self.get_partial_points_array( + im_coords, a, b, (nu, nv, 2), axis + ) return self def get_shader_uniforms(self): @@ -271,6 +275,6 @@ class TexturedSurface(ParametricSurface): return result def fill_in_shader_color_info(self, data): - data["im_coords"] = self.im_coords - data["opacity"] = self.opacity + data["im_coords"] = self.data["im_coords"] + data["opacity"] = self.data["opacity"] return data diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 002c32c2..b291a8eb 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -2,7 +2,6 @@ import itertools as it import operator as op import moderngl -from colour import Color from functools import reduce from manimlib.constants import * @@ -17,10 +16,12 @@ from manimlib.utils.bezier import set_array_by_interpolation from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import partial_quadratic_bezier_points from manimlib.utils.color import color_to_rgba +from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.iterables import make_even -from manimlib.utils.iterables import stretch_array_to_length -from manimlib.utils.iterables import stretch_array_to_length_with_interpolation +from manimlib.utils.iterables import resize_array +from manimlib.utils.iterables import resize_preserving_order +from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import listify from manimlib.utils.paths import straight_path from manimlib.utils.space_ops import angle_between_vectors @@ -77,68 +78,27 @@ class VMobject(Mobject): def __init__(self, **kwargs): self.unit_normal_locked = False - self.triangulation_locked = False + self.needs_new_triangulation = True super().__init__(**kwargs) self.lock_unit_normal(family=False) - self.lock_triangulation(family=False) def get_group_class(self): return VGroup - # To sort out later def init_data(self): - self.fill_data = np.zeros(0, dtype=self.fill_dtype) - self.stroke_data = np.zeros(0, dtype=self.stroke_dtype) - - def resize_data(self, new_length): - self.stroke_data = np.resize(self.stroke_data, new_length) - self.fill_data = np.resize(self.fill_data, new_length) - self.fill_data["vert_index"][:, 0] = range(new_length) + self.data = { + "points": np.zeros((0, 3)), + "fill_rgba": np.zeros((1, 4)), + "stroke_rgba": np.zeros((1, 4)), + "stroke_width": np.zeros((1, 1)), + } def set_points(self, points): - if len(points) != len(self.stroke_data): - self.resize_data(len(points)) - - nppc = self.n_points_per_curve - self.stroke_data["point"] = points - self.stroke_data["prev_point"][:nppc] = points[-nppc:] - self.stroke_data["prev_point"][nppc:] = points[:-nppc] - self.stroke_data["next_point"][:-nppc] = points[nppc:] - self.stroke_data["next_point"][-nppc:] = points[:nppc] - - self.fill_data["point"] = points - - # # TODO, only do conditionally - # unit_normal = self.get_unit_normal() - # self.stroke_data["unit_normal"] = unit_normal - # self.fill_data["unit_normal"] = unit_normal - - # self.refresh_triangulation() - - def get_points(self): - return self.stroke_data["point"] - - def get_all_point_arrays(self): - return [ - self.fill_data["point"], - self.stroke_data["point"], - self.stroke_data["prev_point"], - self.stroke_data["next_point"], - ] - - def get_all_data_arrays(self): - return [self.fill_data, self.stroke_data] - - def get_data_array_attrs(self): - return ["fill_data", "stroke_data"] - - def get_num_points(self): - return len(self.stroke_data) + super().set_points(points) + self.refresh_triangulation() # Colors def init_colors(self): - self.fill_rgbas = np.zeros((1, 4)) - self.stroke_rgbas = np.zeros((1, 4)) self.set_fill( color=self.fill_color or self.color, opacity=self.fill_opacity, @@ -153,69 +113,45 @@ class VMobject(Mobject): self.set_flat_stroke(self.flat_stroke) return self - def generate_rgba_array(self, color, opacity): - """ - First arg can be either a color, or a tuple/list of colors. - Likewise, opacity can either be a float, or a tuple of floats. - """ - colors = listify(color) - opacities = listify(opacity) - return np.array([ - color_to_rgba(c, o) - for c, o in zip(*make_even(colors, opacities)) - ]) - - def update_rgbas_array(self, array_name, color, opacity): - rgbas = self.generate_rgba_array(color or BLACK, opacity or 0) - # Match up current rgbas array with the newly calculated - # one. 99% of the time they'll be the same. - curr_rgbas = getattr(self, array_name) - if len(curr_rgbas) < len(rgbas): - curr_rgbas = stretch_array_to_length(curr_rgbas, len(rgbas)) - setattr(self, array_name, curr_rgbas) - elif len(rgbas) < len(curr_rgbas): - rgbas = stretch_array_to_length(rgbas, len(curr_rgbas)) - # Only update rgb if color was not None, and only - # update alpha channel if opacity was passed in - if color is not None: - curr_rgbas[:, :3] = rgbas[:, :3] - if opacity is not None: - curr_rgbas[:, 3] = rgbas[:, 3] + def set_rgba_array(self, name, color, opacity, family=True): + # TODO, account for if color or opacity are tuples + rgb = color_to_rgb(color) if color else None + mobs = self.get_family() if family else [self] + for mob in mobs: + if rgb is not None: + mob.data[name][:, :3] = rgb + if opacity is not None: + mob.data[name][:, 3] = opacity return self def set_fill(self, color=None, opacity=None, family=True): - if family: - for sm in self.submobjects: - sm.set_fill(color, opacity, family) - self.update_rgbas_array("fill_rgbas", color, opacity) - return self + self.set_rgba_array('fill_rgba', color, opacity, family) - def set_stroke(self, color=None, width=None, opacity=None, - background=None, family=True): - if family: - for sm in self.submobjects: - sm.set_stroke(color, width, opacity, background, family) - self.update_rgbas_array("stroke_rgbas", color, opacity) - if width is not None: - self.stroke_width = np.array(listify(width), dtype=float) - if background is not None: - self.draw_stroke_behind_fill = background + def set_stroke(self, color=None, width=None, opacity=None, background=None, family=True): + self.set_rgba_array('stroke_rgba', color, opacity, family) + + mobs = self.get_family() if family else [self] + for mob in mobs: + if width is not None: + # TODO, account for if width is an array + mob.data['stroke_width'][:] = width + if background is not None: + mob.draw_stroke_behind_fill = background return self def set_style(self, fill_color=None, fill_opacity=None, - fill_rgbas=None, + fill_rgba=None, stroke_color=None, stroke_opacity=None, - stroke_rgbas=None, + stroke_rgba=None, stroke_width=None, gloss=None, shadow=None, - background_image_file=None, family=True): - if fill_rgbas is not None: - self.fill_rgbas = np.array(fill_rgbas) + if fill_rgba is not None: + self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba)) else: self.set_fill( color=fill_color, @@ -223,10 +159,9 @@ class VMobject(Mobject): family=family ) - if stroke_rgbas is not None: - self.stroke_rgbas = np.array(stroke_rgbas) - if stroke_width is not None: - self.stroke_width = np.array(listify(stroke_width)) + if stroke_rgba is not None: + self.data['stroke_rgba'] = resize_with_interpolation(stroke_rgba, len(fill_rgba)) + self.set_stroke(width=stroke_width) else: self.set_stroke( color=stroke_color, @@ -239,31 +174,19 @@ class VMobject(Mobject): self.set_gloss(gloss, family=family) if shadow is not None: self.set_shadow(shadow, family=family) - if background_image_file: - self.color_using_background_image(background_image_file) return self def get_style(self): return { - "fill_rgbas": self.get_fill_rgbas(), - "stroke_rgbas": self.get_stroke_rgbas(), - "stroke_width": self.stroke_width, + "fill_rgba": self.data['fill_rgba'], + "stroke_rgba": self.data['stroke_rgba'], + "stroke_width": self.data['stroke_width'], "gloss": self.get_gloss(), "shadow": self.get_shadow(), - "background_image_file": self.get_background_image_file(), } def match_style(self, vmobject, family=True): - for name, value in vmobject.get_style().items(): - if isinstance(value, np.ndarray): - curr = getattr(self, name) - if curr.size == value.size: - curr[:] = value[:] - else: - setattr(self, name, np.array(value)) - else: - setattr(self, name, value) - + self.set_style(**vmobject.get_style(), family=False) if family: # Does its best to match up submobject lists, and # match styles accordingly @@ -299,12 +222,29 @@ class VMobject(Mobject): super().fade(darkness, family) return self - def get_fill_rgbas(self): - try: - return self.fill_rgbas - except AttributeError: - return np.zeros((1, 4)) + def get_fill_colors(self): + return [ + rgb_to_hex(rgba[:3]) + for rgba in self.data['fill_rgba'] + ] + def get_fill_opacities(self): + return self.data['fill_rgba'][:, 3] + + def get_stroke_colors(self): + return [ + rgb_to_hex(rgba[:3]) + for rgba in self.data['stroke_rgba'] + ] + + def get_stroke_opacities(self): + return self.data['stroke_rgba'][:, 3] + + def get_stroke_widths(self): + return self.data['stroke_width'] + + # TODO, it's weird for these to return the first of various lists + # rather than the full information def get_fill_color(self): """ If there are multiple colors (for gradient) @@ -319,62 +259,25 @@ class VMobject(Mobject): """ return self.get_fill_opacities()[0] - def get_fill_colors(self): - return [ - Color(rgb=rgba[:3]) - for rgba in self.get_fill_rgbas() - ] - - def get_fill_opacities(self): - return self.get_fill_rgbas()[:, 3] - - def get_stroke_rgbas(self): - try: - return self.stroke_rgbas - except AttributeError: - return np.zeros((1, 4)) - - # TODO, it's weird for these to return the first of various lists - # rather than the full information def get_stroke_color(self): return self.get_stroke_colors()[0] def get_stroke_width(self): - return self.stroke_width[0] + return self.get_stroke_widths()[0] def get_stroke_opacity(self): return self.get_stroke_opacities()[0] - def get_stroke_colors(self): - return [ - rgb_to_hex(rgba[:3]) - for rgba in self.get_stroke_rgbas() - ] - - def get_stroke_opacities(self): - return self.get_stroke_rgbas()[:, 3] - def get_color(self): - if np.all(self.get_fill_opacities() == 0): + if self.has_stroke(): return self.get_stroke_color() return self.get_fill_color() def has_stroke(self): - if len(self.stroke_width) == 1: - if self.stroke_width == 0: - return False - elif not self.stroke_width.any(): - return False - alphas = self.stroke_rgbas[:, 3] - if len(alphas) == 1: - return alphas[0] > 0 - return alphas.any() + return any(self.get_stroke_widths()) and any(self.get_stroke_opacities()) def has_fill(self): - alphas = self.fill_rgbas[:, 3] - if len(alphas) == 1: - return alphas[0] > 0 - return alphas.any() + return any(self.get_fill_opacities()) def get_opacity(self): if self.has_fill(): @@ -382,45 +285,14 @@ class VMobject(Mobject): return self.get_stroke_opacity() def set_flat_stroke(self, flat_stroke=True, family=True): - self.flat_stroke = flat_stroke - if family: - for submob in self.submobjects: - submob.set_flat_stroke(flat_stroke, family) + mobs = self.get_family() if family else [self] + for mob in mobs: + mob.flat_stroke = flat_stroke return self def get_flat_stroke(self): return self.flat_stroke - # TODO, this currently does nothing - def color_using_background_image(self, background_image_file): - self.background_image_file = background_image_file - self.set_color(WHITE) - for submob in self.submobjects: - submob.color_using_background_image(background_image_file) - return self - - def get_background_image_file(self): - return self.background_image_file - - def match_background_image_file(self, vmobject): - self.color_using_background_image(vmobject.get_background_image_file()) - return self - - def stretched_style_array_matching_points(self, array): - new_len = self.get_num_points() - long_arr = stretch_array_to_length_with_interpolation( - array, 1 + 2 * (new_len // 3) - ) - shape = array.shape - if len(shape) > 1: - result = np.zeros((new_len, shape[1])) - else: - result = np.zeros(new_len) - result[0::3] = long_arr[0:-1:2] - result[1::3] = long_arr[1::2] - result[2::3] = long_arr[2::2] - return result - # Points def set_anchors_and_handles(self, anchors1, handles, anchors2): assert(len(anchors1) == len(handles) == len(anchors2)) @@ -436,7 +308,8 @@ class VMobject(Mobject): # TODO, check that number new points is a multiple of 4? # or else that if self.get_num_points() % 4 == 1, then # len(new_points) % 4 == 3? - self.set_points(np.vstack([self.get_points(), new_points])) + self.resize_points(self.get_num_points() + len(new_points)) + self.data["points"][-len(new_points):] = new_points return self def start_new_path(self, point): @@ -784,7 +657,6 @@ class VMobject(Mobject): # Alignment def align_points(self, vmobject): - self.align_rgbas(vmobject) if self.get_num_points() == len(vmobject.get_points()): return @@ -871,39 +743,20 @@ class VMobject(Mobject): new_points += partial_quadratic_bezier_points(group, a1, a2) return np.vstack(new_points) - def align_rgbas(self, vmobject): - attrs = ["fill_rgbas", "stroke_rgbas"] - for attr in attrs: - a1 = getattr(self, attr) - a2 = getattr(vmobject, attr) - if len(a1) > len(a2): - new_a2 = stretch_array_to_length(a2, len(a1)) - setattr(vmobject, attr, new_a2) - elif len(a2) > len(a1): - new_a1 = stretch_array_to_length(a1, len(a2)) - setattr(self, attr, new_a1) + def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs): + super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) + if self.has_fill(): + tri1 = mobject1.get_triangulation() + tri2 = mobject2.get_triangulation() + if len(tri1) != len(tri1) or not all(tri1 == tri2): + self.refresh_triangulation() return self - def interpolate_color(self, mobject1, mobject2, alpha): - attrs = [ - "fill_rgbas", - "stroke_rgbas", - "stroke_width", - ] - for attr in attrs: - set_array_by_interpolation( - getattr(self, attr), - getattr(mobject1, attr), - getattr(mobject2, attr), - alpha - ) - def pointwise_become_partial(self, vmobject, a, b): assert(isinstance(vmobject, VMobject)) - self.set_points(vmobject.get_points()) if a <= 0 and b >= 1: return self - num_curves = self.get_num_curves() + num_curves = vmobject.get_num_curves() nppc = self.n_points_per_curve # Partial curve includes three portions: @@ -918,26 +771,26 @@ class VMobject(Mobject): i3 = nppc * upper_index i4 = nppc * (upper_index + 1) - points = self.get_points() vm_points = vmobject.get_points() + new_points = vm_points.copy() if num_curves == 0: - points[:] = 0 + new_points[:] = 0 return self if lower_index == upper_index: tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, upper_residue) - points[:i1] = tup[0] - points[i1:i4] = tup - points[i4:] = tup[2] - points[nppc:] = points[nppc - 1] + new_points[:i1] = tup[0] + new_points[i1:i4] = tup + new_points[i4:] = tup[2] + new_points[nppc:] = new_points[nppc - 1] else: low_tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, 1) high_tup = partial_quadratic_bezier_points(vm_points[i3:i4], 0, upper_residue) - points[0:i1] = low_tup[0] - points[i1:i2] = low_tup - # Keep points i2:i3 as they are - points[i3:i4] = high_tup - points[i4:] = high_tup[2] - self.set_points(points) + new_points[0:i1] = low_tup[0] + new_points[i1:i2] = low_tup + # Keep new_points i2:i3 as they are + new_points[i3:i4] = high_tup + new_points[i4:] = high_tup[2] + self.set_points(new_points) return self def get_subcurve(self, a, b): @@ -945,14 +798,10 @@ class VMobject(Mobject): vmob.pointwise_become_partial(self, a, b) return vmob - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path): - super().interpolate(mobject1, mobject2, alpha, path_func) - if not np.all(mobject1.get_triangulation() == mobject2.get_triangulation()): - self.refresh_triangulation() - return self - # For shaders def init_shader_data(self): + self.fill_data = np.zeros(0, dtype=self.fill_dtype) + self.stroke_data = np.zeros(0, dtype=self.stroke_dtype) self.fill_shader_wrapper = ShaderWrapper( vert_data=self.fill_data, vert_indices=np.zeros(0, dtype='i4'), @@ -1019,39 +868,27 @@ class VMobject(Mobject): return result def get_stroke_shader_data(self): - # TODO, make even simpler after fixing colors - rgbas = self.get_stroke_rgbas() - if len(rgbas) > 1: - rgbas = self.stretched_style_array_matching_points(rgbas) + points = self.get_points() + if len(self.stroke_data) != len(points): + self.stroke_data = resize_array(self.stroke_data, len(points)) + # TODO, account for when self.data["stroke_width"] and self.data["stroke_rgba"] + # have length greater than 1 - stroke_width = self.stroke_width - if len(stroke_width) > 1: - stroke_width = self.stretched_style_array_matching_points(stroke_width) + nppc = self.n_points_per_curve + self.stroke_data["point"] = points + self.stroke_data["prev_point"][:nppc] = points[-nppc:] + self.stroke_data["prev_point"][nppc:] = points[:-nppc] + self.stroke_data["next_point"][:-nppc] = points[nppc:] + self.stroke_data["next_point"][-nppc:] = points[:nppc] - data = self.stroke_data - data["stroke_width"][:, 0] = stroke_width - data["color"] = rgbas - return data - - def lock_triangulation(self, family=True): - mobs = self.get_family() if family else [self] - for mob in mobs: - mob.triangulation_locked = False - mob.saved_triangulation = mob.get_triangulation() - mob.triangulation_locked = True - return self - - def unlock_triangulation(self): - for sm in self.get_family(): - sm.triangulation_locked = False - return self + self.stroke_data["unit_normal"] = self.get_unit_normal() + self.stroke_data["stroke_width"] = self.data["stroke_width"] + self.stroke_data["color"] = self.data["stroke_rgba"] + return self.stroke_data def refresh_triangulation(self): for mob in self.get_family(): - if mob.triangulation_locked: - mob.triangulation_locked = False - mob.saved_triangulation = mob.get_triangulation() - mob.triangulation_locked = True + mob.needs_new_triangulation = True return self def get_triangulation(self, normal_vector=None): @@ -1061,13 +898,14 @@ class VMobject(Mobject): if normal_vector is None: normal_vector = self.get_unit_normal() - if self.triangulation_locked: - return self.saved_triangulation + if not self.needs_new_triangulation: + return self.saved_traignulation points = self.get_points() if len(points) <= 1: - return np.zeros(0, dtype='i4') + self.saved_traignulation == np.zeros(0, dtype='i4') + return self.saved_traignulation # Rotate points such that unit normal vector is OUT # TODO, 99% of the time this does nothing. Do a check for that? @@ -1104,14 +942,20 @@ class VMobject(Mobject): inner_tri_indices = inner_vert_indices[earclip_triangulation(inner_verts, rings)] tri_indices = np.hstack([indices, inner_tri_indices]) + self.saved_traignulation = tri_indices + self.needs_new_triangulation = False return tri_indices def get_fill_shader_data(self): - # TODO, make simpler - rgbas = self.get_fill_rgbas()[:1] - data = self.fill_data - data["color"] = rgbas - return data + points = self.get_points() + if len(self.fill_data) != len(points): + self.fill_data = resize_array(self.fill_data, len(points)) + self.fill_data["vert_index"][:, 0] = range(len(points)) + + self.fill_data["point"] = points + self.fill_data["unit_normal"] = self.get_unit_normal() + self.fill_data["color"] = self.data["fill_rgba"] + return self.fill_data def get_fill_shader_vert_indices(self): return self.get_triangulation() @@ -1121,11 +965,11 @@ class VGroup(VMobject): def __init__(self, *vmobjects, **kwargs): if not all([isinstance(m, VMobject) for m in vmobjects]): raise Exception("All submobjects must be of type VMobject") - VMobject.__init__(self, **kwargs) + super().__init__(**kwargs) self.add(*vmobjects) -class VectorizedPoint(VMobject, Point): +class VectorizedPoint(Point, VMobject): CONFIG = { "color": BLACK, "fill_opacity": 0, @@ -1135,13 +979,13 @@ class VectorizedPoint(VMobject, Point): } def __init__(self, location=ORIGIN, **kwargs): - VMobject.__init__(self, **kwargs) + super().__init__(**kwargs) self.set_points(np.array([location])) class CurvesAsSubmobjects(VGroup): def __init__(self, vmobject, **kwargs): - VGroup.__init__(self, **kwargs) + super().__init__(**kwargs) for tup in vmobject.get_bezier_tuples(): part = VMobject() part.set_points(tup) @@ -1157,7 +1001,7 @@ class DashedVMobject(VMobject): } def __init__(self, vmobject, **kwargs): - VMobject.__init__(self, **kwargs) + super().__init__(**kwargs) num_dashes = self.num_dashes ps_ratio = self.positive_space_ratio if num_dashes > 0: diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 50feefa4..575bf91c 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -74,8 +74,8 @@ def interpolate(start, end, alpha): sys.exit(2) -def set_array_by_interpolation(arr, arr1, arr2, alpha): - arr[:] = interpolate(arr1, arr2, alpha) +def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate): + arr[:] = interp_func(arr1, arr2, alpha) return arr diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 4118642c..155fa7f2 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -80,15 +80,23 @@ def listify(obj): return [obj] -def stretch_array_to_length(nparray, length): - # TODO, rename to "resize"? +def resize_array(nparray, length): + return np.resize(nparray, (length, *nparray.shape[1:])) + + +def resize_preserving_order(nparray, length): + if len(nparray) == 0: + return np.zeros((length, *nparray.shape[1:])) + if len(nparray) == length: + return nparray indices = np.arange(length) * len(nparray) // length return nparray[indices] -def stretch_array_to_length_with_interpolation(nparray, length): - curr_len = len(nparray) - cont_indices = np.linspace(0, curr_len - 1, length) +def resize_with_interpolation(nparray, length): + if len(nparray) == length: + return nparray + cont_indices = np.linspace(0, len(nparray) - 1, length) return np.array([ (1 - a) * nparray[lh] + a * nparray[rh] for ci in cont_indices