Reframe Mobject, VMobject and SurfaceMobject with a data map

This commit is contained in:
Grant Sanderson
2021-01-11 10:57:23 -10:00
parent b3335c65fb
commit 9314dfd933
10 changed files with 276 additions and 428 deletions

View File

@ -24,7 +24,7 @@ class Rotating(Animation):
def interpolate_mobject(self, alpha): def interpolate_mobject(self, alpha):
for sm1, sm2 in self.get_all_families_zipped(): for sm1, sm2 in self.get_all_families_zipped():
sm1.points[:] = sm2.points sm1.set_points(sm2.get_points())
self.mobject.rotate( self.mobject.rotate(
alpha * self.angle, alpha * self.angle,
axis=self.axis, axis=self.axis,

View File

@ -310,8 +310,7 @@ class Dot(Circle):
} }
def __init__(self, point=ORIGIN, **kwargs): def __init__(self, point=ORIGIN, **kwargs):
Circle.__init__(self, arc_center=point, **kwargs) super().__init__(arc_center=point, **kwargs)
self.lock_triangulation()
class SmallDot(Dot): class SmallDot(Dot):
@ -327,7 +326,7 @@ class Ellipse(Circle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
Circle.__init__(self, **kwargs) super().__init__(**kwargs)
self.set_width(self.width, stretch=True) self.set_width(self.width, stretch=True)
self.set_height(self.height, stretch=True) self.set_height(self.height, stretch=True)

View File

@ -15,7 +15,10 @@ from manimlib.utils.color import get_colormap_list
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import batch_by_property
from manimlib.utils.iterables import list_update from manimlib.utils.iterables import list_update
from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import resize_preserving_order
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import set_array_by_interpolation
from manimlib.utils.paths import straight_path from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
@ -34,7 +37,7 @@ class Mobject(object):
""" """
CONFIG = { CONFIG = {
"color": WHITE, "color": WHITE,
"dim": 3, "dim": 3, # TODO, get rid of this
# Lighting parameters # Lighting parameters
# Positive gloss up to 1 makes it reflect the light. # Positive gloss up to 1 makes it reflect the light.
"gloss": 0.0, "gloss": 0.0,
@ -42,9 +45,6 @@ class Mobject(object):
"shadow": 0.0, "shadow": 0.0,
# For shaders # For shaders
"shader_folder": "", "shader_folder": "",
# "vert_shader_file": "",
# "geom_shader_file": "",
# "frag_shader_file": "",
"render_primitive": moderngl.TRIANGLE_STRIP, "render_primitive": moderngl.TRIANGLE_STRIP,
"texture_paths": None, "texture_paths": None,
"depth_test": False, "depth_test": False,
@ -62,8 +62,8 @@ class Mobject(object):
self.parents = [] self.parents = []
self.family = [self] self.family = [self]
self.init_updaters()
self.init_data() self.init_data()
self.init_updaters()
self.init_points() self.init_points()
self.init_colors() self.init_colors()
self.init_shader_data() self.init_shader_data()
@ -84,35 +84,32 @@ class Mobject(object):
# Typically implemented in subclass, unless purposefully left blank # Typically implemented in subclass, unless purposefully left blank
pass pass
# To sort out later # Related to data dict
def init_data(self): def init_data(self):
self.data = np.zeros(0, dtype=self.shader_dtype) self.data = {
"points": np.zeros((0, 3)),
}
def resize_data(self, new_length): def set_data(self, data):
if new_length != len(self.data): for key in data:
self.data = np.resize(self.data, new_length) self.data[key] = data[key]
def resize_points(self, new_length, resize_func=resize_array):
if new_length != len(self.data["points"]):
self.data["points"] = resize_func(self.data["points"], new_length)
def set_points(self, points): def set_points(self, points):
self.resize_data(len(points)) self.resize_points(len(points))
self.data["point"] = points self.data["points"][:] = points
def get_points(self): def get_points(self):
return self.data["point"] return self.data["points"]
def get_all_point_arrays(self):
return [self.data["point"]]
def get_all_data_arrays(self):
return [self.data]
def get_data_array_attrs(self):
return ["data"]
def clear_points(self): def clear_points(self):
self.resize_data(0) self.resize_points(0)
def get_num_points(self): def get_num_points(self):
return len(self.data) return len(self.data["points"])
# #
# Family matters # Family matters
@ -212,8 +209,9 @@ class Mobject(object):
copy_mobject = copy.copy(self) copy_mobject = copy.copy(self)
self.parents = parents self.parents = parents
for attr in self.get_data_array_attrs(): copy_mobject.data = dict(self.data)
setattr(copy_mobject, attr, getattr(self, attr).copy()) for key in self.data:
copy_mobject.data[key] = self.data[key].copy()
copy_mobject.submobjects = [] copy_mobject.submobjects = []
copy_mobject.add(*[sm.copy() for sm in self.submobjects]) copy_mobject.add(*[sm.copy() for sm in self.submobjects])
copy_mobject.match_updaters(self) copy_mobject.match_updaters(self)
@ -224,7 +222,7 @@ class Mobject(object):
if isinstance(value, Mobject) and value in family and value is not self: if isinstance(value, Mobject) and value in family and value is not self:
setattr(copy_mobject, attr, value.copy()) setattr(copy_mobject, attr, value.copy())
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
setattr(copy_mobject, attr, np.array(value)) setattr(copy_mobject, attr, value.copy())
if isinstance(value, ShaderWrapper): if isinstance(value, ShaderWrapper):
setattr(copy_mobject, attr, value.copy()) setattr(copy_mobject, attr, value.copy())
return copy_mobject return copy_mobject
@ -343,8 +341,7 @@ class Mobject(object):
def shift(self, *vectors): def shift(self, *vectors):
total_vector = reduce(op.add, vectors) total_vector = reduce(op.add, vectors)
for mob in self.get_family(): for mob in self.get_family():
for arr in mob.get_all_point_arrays(): mob.set_points(mob.get_points() + total_vector)
arr += total_vector
return self return self
def scale(self, scale_factor, **kwargs): def scale(self, scale_factor, **kwargs):
@ -458,8 +455,8 @@ class Mobject(object):
about_edge = ORIGIN about_edge = ORIGIN
about_point = self.get_bounding_box_point(about_edge) about_point = self.get_bounding_box_point(about_edge)
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
for arr in mob.get_all_point_arrays(): points = mob.get_points()
arr[:] = func(arr - about_point) + about_point points[:] = func(points - about_point) + about_point
return self return self
# Positioning methods # Positioning methods
@ -512,8 +509,7 @@ class Mobject(object):
else: else:
aligner = self aligner = self
point_to_align = aligner.get_bounding_box_point(aligned_edge - direction) point_to_align = aligner.get_bounding_box_point(aligned_edge - direction)
self.shift((target_point - point_to_align + self.shift((target_point - point_to_align + buff * direction) * coor_mask)
buff * direction) * coor_mask)
return self return self
def shift_onto_screen(self, **kwargs): def shift_onto_screen(self, **kwargs):
@ -807,11 +803,9 @@ class Mobject(object):
else: else:
return getattr(self, array_attr) return getattr(self, array_attr)
def get_all_points(self): # TODO, use get_all_point_arrays? def get_all_points(self):
if self.submobjects: if self.submobjects:
return np.vstack([ return np.vstack([sm.get_points() for sm in self.get_family()])
sm.get_points() for sm in self.get_family()
])
else: else:
return self.get_points() return self.get_points()
@ -1107,20 +1101,25 @@ class Mobject(object):
self.null_point_align(mobject) # Needed? self.null_point_align(mobject) # Needed?
self.align_submobjects(mobject) self.align_submobjects(mobject)
for mob1, mob2 in zip(self.get_family(), mobject.get_family()): for mob1, mob2 in zip(self.get_family(), mobject.get_family()):
# Separate out how points are treated so that subclasses
# can handle that case differently if they choose
mob1.align_points(mob2) mob1.align_points(mob2)
for key in mob1.data:
if key == "points":
continue
arr1 = mob1.data[key]
arr2 = mob2.data[key]
if len(arr2) > len(arr1):
mob1.data[key] = resize_preserving_order(arr1, len(arr2))
elif len(arr1) > len(arr2):
mob2.data[key] = resize_preserving_order(arr2, len(arr1))
def align_points(self, mobject): def align_points(self, mobject):
count1 = self.get_num_points() max_len = max(self.get_num_points(), mobject.get_num_points())
count2 = mobject.get_num_points() self.resize_points(max_len, resize_func=resize_preserving_order)
if count1 < count2: mobject.resize_points(max_len, resize_func=resize_preserving_order)
self.align_points_with_larger(mobject)
elif count2 < count1:
mobject.align_points_with_larger(self)
return self return self
def align_points_with_larger(self, larger_mobject):
raise Exception("Not implemented")
def align_submobjects(self, mobject): def align_submobjects(self, mobject):
mob1 = self mob1 = self
mob2 = mobject mob2 = mobject
@ -1188,19 +1187,16 @@ class Mobject(object):
Turns self into an interpolation between mobject1 Turns self into an interpolation between mobject1
and mobject2. and mobject2.
""" """
mobs = [self, mobject1, mobject2] for key in self.data:
# for arr, arr1, arr2 in zip(*(m.get_all_data_arrays() for m in mobs)): func = path_func if key == "points" else interpolate
# arr[:] = interpolate(arr1, arr2, alpha) self.data[key][:] = func(
# if path_func is not straight_path: mobject1.data[key],
for arr, arr1, arr2 in zip(*(m.get_all_point_arrays() for m in mobs)): mobject2.data[key],
arr[:] = path_func(arr1, arr2, alpha) alpha
# self.interpolate_color(mobject1, mobject2, alpha) )
self.interpolate_light_style(mobject1, mobject2, alpha) # TODO, interpolate uniforms instaed self.interpolate_light_style(mobject1, mobject2, alpha) # TODO, interpolate uniforms instaed
return self return self
def interpolate_color(self, mobject1, mobject2, alpha):
pass # To implement in subclass
def interpolate_light_style(self, mobject1, mobject2, alpha): def interpolate_light_style(self, mobject1, mobject2, alpha):
g0 = self.get_gloss() g0 = self.get_gloss()
g1 = mobject1.get_gloss() g1 = mobject1.get_gloss()
@ -1236,8 +1232,7 @@ class Mobject(object):
""" """
self.align_submobjects(mobject) self.align_submobjects(mobject)
for sm1, sm2 in zip(self.get_family(), mobject.get_family()): for sm1, sm2 in zip(self.get_family(), mobject.get_family()):
for arr1, arr2 in zip(sm1.get_all_data_arrays(), sm2.get_all_data_arrays()): sm1.set_data(sm2.data)
arr1[:] = arr2
return self return self
def cleanup_from_animation(self): def cleanup_from_animation(self):
@ -1320,9 +1315,10 @@ class Mobject(object):
# For shader data # For shader data
def init_shader_data(self): def init_shader_data(self):
self.shader_data = np.zeros(len(self.get_points()), dtype=self.shader_dtype)
self.shader_indices = None self.shader_indices = None
self.shader_wrapper = ShaderWrapper( self.shader_wrapper = ShaderWrapper(
vert_data=self.data, vert_data=self.shader_data,
shader_folder=self.shader_folder, shader_folder=self.shader_folder,
texture_paths=self.texture_paths, texture_paths=self.texture_paths,
depth_test=self.depth_test, depth_test=self.depth_test,
@ -1333,18 +1329,18 @@ class Mobject(object):
self.shader_wrapper.refresh_id() self.shader_wrapper.refresh_id()
return self return self
# def get_blank_shader_data_array(self, size, name="data"): def get_blank_shader_data_array(self, size, name="shader_data"):
# # If possible, try to populate an existing array, rather # If possible, try to populate an existing array, rather
# # than recreating it each frame # than recreating it each frame
# arr = getattr(self, name) arr = getattr(self, name)
# if arr.size != size: if arr.size != size:
# new_arr = np.resize(arr, size) new_arr = resize_array(arr, size)
# setattr(self, name, new_arr) setattr(self, name, new_arr)
# return new_arr return new_arr
# return arr return arr
def get_shader_wrapper(self): def get_shader_wrapper(self):
self.shader_wrapper.vert_data = self.data # TODO self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.uniforms = self.get_shader_uniforms() self.shader_wrapper.uniforms = self.get_shader_uniforms()
self.shader_wrapper.depth_test = self.depth_test self.shader_wrapper.depth_test = self.depth_test
@ -1367,6 +1363,10 @@ class Mobject(object):
result.append(shader_wrapper) result.append(shader_wrapper)
return result return result
def get_shader_data(self):
# May be different for subclasses
return self.shader_data
def get_shader_uniforms(self): def get_shader_uniforms(self):
return { return {
"is_fixed_in_frame": float(self.is_fixed_in_frame), "is_fixed_in_frame": float(self.is_fixed_in_frame),
@ -1411,7 +1411,7 @@ class Point(Mobject):
return self.artificial_height return self.artificial_height
def get_location(self): def get_location(self):
return np.array(self.get_points()[0]) return self.get_points()[0].copy()
def get_bounding_box_point(self, *args, **kwargs): def get_bounding_box_point(self, *args, **kwargs):
return self.get_location() return self.get_location()

View File

@ -281,6 +281,7 @@ class SVGMobject(VMobject):
matrix[:, 1] *= -1 matrix[:, 1] *= -1
for mob in mobject.family_members_with_points(): for mob in mobject.family_members_with_points():
# TODO, directly apply matrix?
mob.set_points(np.dot(mob.get_points(), matrix)) mob.set_points(np.dot(mob.get_points(), matrix))
mobject.shift(x * RIGHT + y * UP) mobject.shift(x * RIGHT + y * UP)
except: except:

View File

@ -7,7 +7,7 @@ from manimlib.constants import ORIGIN
from manimlib.mobject.types.point_cloud_mobject import PMobject from manimlib.mobject.types.point_cloud_mobject import PMobject
from manimlib.mobject.geometry import DEFAULT_DOT_RADIUS from manimlib.mobject.geometry import DEFAULT_DOT_RADIUS
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import stretch_array_to_length from manimlib.utils.iterables import resize_preserving_order
class DotCloud(PMobject): class DotCloud(PMobject):
@ -32,7 +32,7 @@ class DotCloud(PMobject):
def set_points(self, points): def set_points(self, points):
super().set_points(points) super().set_points(points)
self.radii = stretch_array_to_length(self.radii, len(points)) self.radii = resize_preserving_order(self.radii, len(points))
return self return self
def set_points_by_grid(self, n_rows, n_cols, height=None, width=None): def set_points_by_grid(self, n_rows, n_cols, height=None, width=None):
@ -58,7 +58,7 @@ class DotCloud(PMobject):
if isinstance(radii, numbers.Number): if isinstance(radii, numbers.Number):
self.radii[:] = radii self.radii[:] = radii
else: else:
self.radii = stretch_array_to_length(radii, len(self.points)) self.radii = resize_preserving_order(radii, len(self.points))
return self return self
def scale(self, scale_factor, scale_radii=True, **kwargs): def scale(self, scale_factor, scale_radii=True, **kwargs):

View File

@ -4,7 +4,7 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.color import color_gradient from manimlib.utils.color import color_gradient
from manimlib.utils.color import color_to_rgba from manimlib.utils.color import color_to_rgba
from manimlib.utils.color import rgba_to_color from manimlib.utils.color import rgba_to_color
from manimlib.utils.iterables import stretch_array_to_length from manimlib.utils.iterables import resize_preserving_order
class PMobject(Mobject): class PMobject(Mobject):
@ -18,7 +18,7 @@ class PMobject(Mobject):
def set_points(self, points): def set_points(self, points):
self.points = points self.points = points
self.rgbas = stretch_array_to_length(self.rgbas, len(points)) self.rgbas = resize_preserving_order(self.rgbas, len(points))
return self return self
def add_points(self, points, rgbas=None, color=None, alpha=1): def add_points(self, points, rgbas=None, color=None, alpha=1):
@ -91,7 +91,7 @@ class PMobject(Mobject):
return self return self
def match_colors(self, pmobject): def match_colors(self, pmobject):
self.rgbas[:] = stretch_array_to_length(pmobject.rgbas, len(self.points)) self.rgbas[:] = resize_preserving_order(pmobject.rgbas, len(self.points))
return self return self
def filter_out(self, condition): def filter_out(self, condition):
@ -138,14 +138,6 @@ class PMobject(Mobject):
return self.points[index] return self.points[index]
# Alignment # Alignment
def align_points_with_larger(self, larger_mobject):
assert(isinstance(larger_mobject, PMobject))
self.apply_over_attr_arrays(
lambda a: stretch_array_to_length(
a, larger_mobject.get_num_points()
)
)
def interpolate_color(self, mobject1, mobject2, alpha): def interpolate_color(self, mobject1, mobject2, alpha):
self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha) self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha)
return self return self

View File

@ -44,7 +44,12 @@ class ParametricSurface(Mobject):
self.uv_func = uv_func self.uv_func = uv_func
self.compute_triangle_indices() self.compute_triangle_indices()
super().__init__(**kwargs) super().__init__(**kwargs)
self.sort_faces_back_to_front()
def init_data(self):
self.data = {
"points": np.zeros((0, 3)),
"rgba": np.zeros((1, 4)),
}
def init_points(self): def init_points(self):
dim = self.dim dim = self.dim
@ -65,7 +70,10 @@ class ParametricSurface(Mobject):
# infinitesimal nudged values alongside the original values. This way, one # infinitesimal nudged values alongside the original values. This way, one
# can perform all the manipulations they'd like to the surface, and normals # can perform all the manipulations they'd like to the surface, and normals
# are still easily recoverable. # are still easily recoverable.
self.points = np.vstack(point_lists) self.set_points(np.vstack(point_lists))
def init_colors(self):
self.set_color(self.color, self.opacity)
def compute_triangle_indices(self): def compute_triangle_indices(self):
# TODO, if there is an event which changes # TODO, if there is an event which changes
@ -88,13 +96,10 @@ class ParametricSurface(Mobject):
def get_triangle_indices(self): def get_triangle_indices(self):
return self.triangle_indices return self.triangle_indices
def init_colors(self):
self.rgbas = np.zeros((1, 4))
self.set_color(self.color, self.opacity)
def get_surface_points_and_nudged_points(self): def get_surface_points_and_nudged_points(self):
k = len(self.points) // 3 points = self.get_points()
return self.points[:k], self.points[k:2 * k], self.points[2 * k:] k = len(points) // 3
return points[:k], points[k:2 * k], points[2 * k:]
def get_unit_normals(self): def get_unit_normals(self):
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
@ -104,40 +109,38 @@ class ParametricSurface(Mobject):
) )
return normalize_along_axis(normals, 1) return normalize_along_axis(normals, 1)
def set_color(self, color, opacity=1.0, family=True): def set_color(self, color, opacity=None, family=True):
# TODO, allow for multiple colors # TODO, allow for multiple colors
if opacity is None:
opacity = self.data["rgba"][0, 3]
rgba = color_to_rgba(color, opacity) rgba = color_to_rgba(color, opacity)
mobs = self.get_family() if family else [self] mobs = self.get_family() if family else [self]
for mob in mobs: for mob in mobs:
mob.rgbas[:] = rgba mob.data["rgba"][:] = rgba
return self return self
def get_color(self): def get_color(self):
return rgb_to_color(self.rgbas[0, :3]) return rgb_to_color(self.data["rgba"][0, :3])
def set_opacity(self, opacity, family=True): def set_opacity(self, opacity, family=True):
mobs = self.get_family() if family else [self] mobs = self.get_family() if family else [self]
for mob in mobs: for mob in mobs:
mob.rgbas[:, 3] = opacity mob.data["rgba"][:, 3] = opacity
return self
def interpolate_color(self, mobject1, mobject2, alpha):
self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha)
return self return self
def pointwise_become_partial(self, smobject, a, b, axis=None): def pointwise_become_partial(self, smobject, a, b, axis=None):
if axis is None: if axis is None:
axis = self.prefered_creation_axis axis = self.prefered_creation_axis
assert(isinstance(smobject, ParametricSurface)) assert(isinstance(smobject, ParametricSurface))
self.points[:] = smobject.points[:]
if a <= 0 and b >= 1: if a <= 0 and b >= 1:
self.set_points(smobject.points)
return self return self
nu, nv = smobject.resolution nu, nv = smobject.resolution
self.points[:] = np.vstack([ self.set_points(np.vstack([
self.get_partial_points_array(arr, a, b, (nu, nv, 3), axis=axis) self.get_partial_points_array(arr, a, b, (nu, nv, 3), axis=axis)
for arr in self.get_surface_points_and_nudged_points() for arr in smobject.get_surface_points_and_nudged_points()
]) ]))
return self return self
def get_partial_points_array(self, points, a, b, resolution, axis): def get_partial_points_array(self, points, a, b, resolution, axis):
@ -178,7 +181,8 @@ class ParametricSurface(Mobject):
return data return data
def fill_in_shader_color_info(self, data): def fill_in_shader_color_info(self, data):
data["color"] = self.rgbas # TODO, what if len(self.data["rgba"]) > 1?
data["color"] = self.data["rgba"]
return data return data
def get_shader_vert_indices(self): def get_shader_vert_indices(self):
@ -195,7 +199,7 @@ class SGroup(ParametricSurface):
self.add(*parametric_surfaces) self.add(*parametric_surfaces)
def init_points(self): def init_points(self):
self.points = np.zeros((0, 3)) pass # Needed?
class TexturedSurface(ParametricSurface): class TexturedSurface(ParametricSurface):
@ -229,40 +233,40 @@ class TexturedSurface(ParametricSurface):
self.u_range = uv_surface.u_range self.u_range = uv_surface.u_range
self.v_range = uv_surface.v_range self.v_range = uv_surface.v_range
self.resolution = uv_surface.resolution self.resolution = uv_surface.resolution
self.gloss = self.uv_surface.gloss
super().__init__(self.uv_func, **kwargs) super().__init__(self.uv_func, **kwargs)
def init_points(self): def init_data(self):
self.points = self.uv_surface.points
# Init im_coords
nu, nv = self.uv_surface.resolution nu, nv = self.uv_surface.resolution
u_range = np.linspace(0, 1, nu) self.data = {
v_range = np.linspace(1, 0, nv) # Reverse y-direction "points": self.uv_surface.get_points(),
uv_grid = np.array([[u, v] for u in u_range for v in v_range]) "im_coords": np.array([
self.im_coords = uv_grid [u, v]
for u in np.linspace(0, 1, nu)
for v in np.linspace(1, 0, nv) # Reverse y-direction
]),
"opacity": np.array([self.uv_surface.data["rgba"][:, 3]]),
}
def init_colors(self): def init_colors(self):
self.opacity = self.uv_surface.rgbas[:, 3] pass
self.gloss = self.uv_surface.gloss
def interpolate_color(self, mobject1, mobject2, alpha):
# TODO, handle multiple textures
self.opacity = interpolate(mobject1.opacity, mobject2.opacity, alpha)
return self
def set_opacity(self, opacity, family=True): def set_opacity(self, opacity, family=True):
self.opacity = opacity mobs = self.get_family() if family else [self]
if family: for mob in mobs:
for sm in self.submobjects: mob.data["opacity"][:] = opacity
sm.set_opacity(opacity, family)
return self return self
def pointwise_become_partial(self, tsmobject, a, b, axis=1): def pointwise_become_partial(self, tsmobject, a, b, axis=1):
super().pointwise_become_partial(tsmobject, a, b, axis) super().pointwise_become_partial(tsmobject, a, b, axis)
self.im_coords[:] = tsmobject.im_coords im_coords = self.data["im_coords"]
im_coords[:] = tsmobject.data["im_coords"]
if a <= 0 and b >= 1: if a <= 0 and b >= 1:
return self return self
nu, nv = tsmobject.resolution nu, nv = tsmobject.resolution
self.im_coords[:] = self.get_partial_points_array(self.im_coords, a, b, (nu, nv, 2), axis) im_coords[:] = self.get_partial_points_array(
im_coords, a, b, (nu, nv, 2), axis
)
return self return self
def get_shader_uniforms(self): def get_shader_uniforms(self):
@ -271,6 +275,6 @@ class TexturedSurface(ParametricSurface):
return result return result
def fill_in_shader_color_info(self, data): def fill_in_shader_color_info(self, data):
data["im_coords"] = self.im_coords data["im_coords"] = self.data["im_coords"]
data["opacity"] = self.opacity data["opacity"] = self.data["opacity"]
return data return data

View File

@ -2,7 +2,6 @@ import itertools as it
import operator as op import operator as op
import moderngl import moderngl
from colour import Color
from functools import reduce from functools import reduce
from manimlib.constants import * from manimlib.constants import *
@ -17,10 +16,12 @@ from manimlib.utils.bezier import set_array_by_interpolation
from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import partial_quadratic_bezier_points from manimlib.utils.bezier import partial_quadratic_bezier_points
from manimlib.utils.color import color_to_rgba from manimlib.utils.color import color_to_rgba
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex from manimlib.utils.color import rgb_to_hex
from manimlib.utils.iterables import make_even from manimlib.utils.iterables import make_even
from manimlib.utils.iterables import stretch_array_to_length from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import stretch_array_to_length_with_interpolation from manimlib.utils.iterables import resize_preserving_order
from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.iterables import listify from manimlib.utils.iterables import listify
from manimlib.utils.paths import straight_path from manimlib.utils.paths import straight_path
from manimlib.utils.space_ops import angle_between_vectors from manimlib.utils.space_ops import angle_between_vectors
@ -77,68 +78,27 @@ class VMobject(Mobject):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.unit_normal_locked = False self.unit_normal_locked = False
self.triangulation_locked = False self.needs_new_triangulation = True
super().__init__(**kwargs) super().__init__(**kwargs)
self.lock_unit_normal(family=False) self.lock_unit_normal(family=False)
self.lock_triangulation(family=False)
def get_group_class(self): def get_group_class(self):
return VGroup return VGroup
# To sort out later
def init_data(self): def init_data(self):
self.fill_data = np.zeros(0, dtype=self.fill_dtype) self.data = {
self.stroke_data = np.zeros(0, dtype=self.stroke_dtype) "points": np.zeros((0, 3)),
"fill_rgba": np.zeros((1, 4)),
def resize_data(self, new_length): "stroke_rgba": np.zeros((1, 4)),
self.stroke_data = np.resize(self.stroke_data, new_length) "stroke_width": np.zeros((1, 1)),
self.fill_data = np.resize(self.fill_data, new_length) }
self.fill_data["vert_index"][:, 0] = range(new_length)
def set_points(self, points): def set_points(self, points):
if len(points) != len(self.stroke_data): super().set_points(points)
self.resize_data(len(points)) self.refresh_triangulation()
nppc = self.n_points_per_curve
self.stroke_data["point"] = points
self.stroke_data["prev_point"][:nppc] = points[-nppc:]
self.stroke_data["prev_point"][nppc:] = points[:-nppc]
self.stroke_data["next_point"][:-nppc] = points[nppc:]
self.stroke_data["next_point"][-nppc:] = points[:nppc]
self.fill_data["point"] = points
# # TODO, only do conditionally
# unit_normal = self.get_unit_normal()
# self.stroke_data["unit_normal"] = unit_normal
# self.fill_data["unit_normal"] = unit_normal
# self.refresh_triangulation()
def get_points(self):
return self.stroke_data["point"]
def get_all_point_arrays(self):
return [
self.fill_data["point"],
self.stroke_data["point"],
self.stroke_data["prev_point"],
self.stroke_data["next_point"],
]
def get_all_data_arrays(self):
return [self.fill_data, self.stroke_data]
def get_data_array_attrs(self):
return ["fill_data", "stroke_data"]
def get_num_points(self):
return len(self.stroke_data)
# Colors # Colors
def init_colors(self): def init_colors(self):
self.fill_rgbas = np.zeros((1, 4))
self.stroke_rgbas = np.zeros((1, 4))
self.set_fill( self.set_fill(
color=self.fill_color or self.color, color=self.fill_color or self.color,
opacity=self.fill_opacity, opacity=self.fill_opacity,
@ -153,69 +113,45 @@ class VMobject(Mobject):
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
return self return self
def generate_rgba_array(self, color, opacity): def set_rgba_array(self, name, color, opacity, family=True):
""" # TODO, account for if color or opacity are tuples
First arg can be either a color, or a tuple/list of colors. rgb = color_to_rgb(color) if color else None
Likewise, opacity can either be a float, or a tuple of floats. mobs = self.get_family() if family else [self]
""" for mob in mobs:
colors = listify(color) if rgb is not None:
opacities = listify(opacity) mob.data[name][:, :3] = rgb
return np.array([
color_to_rgba(c, o)
for c, o in zip(*make_even(colors, opacities))
])
def update_rgbas_array(self, array_name, color, opacity):
rgbas = self.generate_rgba_array(color or BLACK, opacity or 0)
# Match up current rgbas array with the newly calculated
# one. 99% of the time they'll be the same.
curr_rgbas = getattr(self, array_name)
if len(curr_rgbas) < len(rgbas):
curr_rgbas = stretch_array_to_length(curr_rgbas, len(rgbas))
setattr(self, array_name, curr_rgbas)
elif len(rgbas) < len(curr_rgbas):
rgbas = stretch_array_to_length(rgbas, len(curr_rgbas))
# Only update rgb if color was not None, and only
# update alpha channel if opacity was passed in
if color is not None:
curr_rgbas[:, :3] = rgbas[:, :3]
if opacity is not None: if opacity is not None:
curr_rgbas[:, 3] = rgbas[:, 3] mob.data[name][:, 3] = opacity
return self return self
def set_fill(self, color=None, opacity=None, family=True): def set_fill(self, color=None, opacity=None, family=True):
if family: self.set_rgba_array('fill_rgba', color, opacity, family)
for sm in self.submobjects:
sm.set_fill(color, opacity, family)
self.update_rgbas_array("fill_rgbas", color, opacity)
return self
def set_stroke(self, color=None, width=None, opacity=None, def set_stroke(self, color=None, width=None, opacity=None, background=None, family=True):
background=None, family=True): self.set_rgba_array('stroke_rgba', color, opacity, family)
if family:
for sm in self.submobjects: mobs = self.get_family() if family else [self]
sm.set_stroke(color, width, opacity, background, family) for mob in mobs:
self.update_rgbas_array("stroke_rgbas", color, opacity)
if width is not None: if width is not None:
self.stroke_width = np.array(listify(width), dtype=float) # TODO, account for if width is an array
mob.data['stroke_width'][:] = width
if background is not None: if background is not None:
self.draw_stroke_behind_fill = background mob.draw_stroke_behind_fill = background
return self return self
def set_style(self, def set_style(self,
fill_color=None, fill_color=None,
fill_opacity=None, fill_opacity=None,
fill_rgbas=None, fill_rgba=None,
stroke_color=None, stroke_color=None,
stroke_opacity=None, stroke_opacity=None,
stroke_rgbas=None, stroke_rgba=None,
stroke_width=None, stroke_width=None,
gloss=None, gloss=None,
shadow=None, shadow=None,
background_image_file=None,
family=True): family=True):
if fill_rgbas is not None: if fill_rgba is not None:
self.fill_rgbas = np.array(fill_rgbas) self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
else: else:
self.set_fill( self.set_fill(
color=fill_color, color=fill_color,
@ -223,10 +159,9 @@ class VMobject(Mobject):
family=family family=family
) )
if stroke_rgbas is not None: if stroke_rgba is not None:
self.stroke_rgbas = np.array(stroke_rgbas) self.data['stroke_rgba'] = resize_with_interpolation(stroke_rgba, len(fill_rgba))
if stroke_width is not None: self.set_stroke(width=stroke_width)
self.stroke_width = np.array(listify(stroke_width))
else: else:
self.set_stroke( self.set_stroke(
color=stroke_color, color=stroke_color,
@ -239,31 +174,19 @@ class VMobject(Mobject):
self.set_gloss(gloss, family=family) self.set_gloss(gloss, family=family)
if shadow is not None: if shadow is not None:
self.set_shadow(shadow, family=family) self.set_shadow(shadow, family=family)
if background_image_file:
self.color_using_background_image(background_image_file)
return self return self
def get_style(self): def get_style(self):
return { return {
"fill_rgbas": self.get_fill_rgbas(), "fill_rgba": self.data['fill_rgba'],
"stroke_rgbas": self.get_stroke_rgbas(), "stroke_rgba": self.data['stroke_rgba'],
"stroke_width": self.stroke_width, "stroke_width": self.data['stroke_width'],
"gloss": self.get_gloss(), "gloss": self.get_gloss(),
"shadow": self.get_shadow(), "shadow": self.get_shadow(),
"background_image_file": self.get_background_image_file(),
} }
def match_style(self, vmobject, family=True): def match_style(self, vmobject, family=True):
for name, value in vmobject.get_style().items(): self.set_style(**vmobject.get_style(), family=False)
if isinstance(value, np.ndarray):
curr = getattr(self, name)
if curr.size == value.size:
curr[:] = value[:]
else:
setattr(self, name, np.array(value))
else:
setattr(self, name, value)
if family: if family:
# Does its best to match up submobject lists, and # Does its best to match up submobject lists, and
# match styles accordingly # match styles accordingly
@ -299,12 +222,29 @@ class VMobject(Mobject):
super().fade(darkness, family) super().fade(darkness, family)
return self return self
def get_fill_rgbas(self): def get_fill_colors(self):
try: return [
return self.fill_rgbas rgb_to_hex(rgba[:3])
except AttributeError: for rgba in self.data['fill_rgba']
return np.zeros((1, 4)) ]
def get_fill_opacities(self):
return self.data['fill_rgba'][:, 3]
def get_stroke_colors(self):
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['stroke_rgba']
]
def get_stroke_opacities(self):
return self.data['stroke_rgba'][:, 3]
def get_stroke_widths(self):
return self.data['stroke_width']
# TODO, it's weird for these to return the first of various lists
# rather than the full information
def get_fill_color(self): def get_fill_color(self):
""" """
If there are multiple colors (for gradient) If there are multiple colors (for gradient)
@ -319,62 +259,25 @@ class VMobject(Mobject):
""" """
return self.get_fill_opacities()[0] return self.get_fill_opacities()[0]
def get_fill_colors(self):
return [
Color(rgb=rgba[:3])
for rgba in self.get_fill_rgbas()
]
def get_fill_opacities(self):
return self.get_fill_rgbas()[:, 3]
def get_stroke_rgbas(self):
try:
return self.stroke_rgbas
except AttributeError:
return np.zeros((1, 4))
# TODO, it's weird for these to return the first of various lists
# rather than the full information
def get_stroke_color(self): def get_stroke_color(self):
return self.get_stroke_colors()[0] return self.get_stroke_colors()[0]
def get_stroke_width(self): def get_stroke_width(self):
return self.stroke_width[0] return self.get_stroke_widths()[0]
def get_stroke_opacity(self): def get_stroke_opacity(self):
return self.get_stroke_opacities()[0] return self.get_stroke_opacities()[0]
def get_stroke_colors(self):
return [
rgb_to_hex(rgba[:3])
for rgba in self.get_stroke_rgbas()
]
def get_stroke_opacities(self):
return self.get_stroke_rgbas()[:, 3]
def get_color(self): def get_color(self):
if np.all(self.get_fill_opacities() == 0): if self.has_stroke():
return self.get_stroke_color() return self.get_stroke_color()
return self.get_fill_color() return self.get_fill_color()
def has_stroke(self): def has_stroke(self):
if len(self.stroke_width) == 1: return any(self.get_stroke_widths()) and any(self.get_stroke_opacities())
if self.stroke_width == 0:
return False
elif not self.stroke_width.any():
return False
alphas = self.stroke_rgbas[:, 3]
if len(alphas) == 1:
return alphas[0] > 0
return alphas.any()
def has_fill(self): def has_fill(self):
alphas = self.fill_rgbas[:, 3] return any(self.get_fill_opacities())
if len(alphas) == 1:
return alphas[0] > 0
return alphas.any()
def get_opacity(self): def get_opacity(self):
if self.has_fill(): if self.has_fill():
@ -382,45 +285,14 @@ class VMobject(Mobject):
return self.get_stroke_opacity() return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke=True, family=True): def set_flat_stroke(self, flat_stroke=True, family=True):
self.flat_stroke = flat_stroke mobs = self.get_family() if family else [self]
if family: for mob in mobs:
for submob in self.submobjects: mob.flat_stroke = flat_stroke
submob.set_flat_stroke(flat_stroke, family)
return self return self
def get_flat_stroke(self): def get_flat_stroke(self):
return self.flat_stroke return self.flat_stroke
# TODO, this currently does nothing
def color_using_background_image(self, background_image_file):
self.background_image_file = background_image_file
self.set_color(WHITE)
for submob in self.submobjects:
submob.color_using_background_image(background_image_file)
return self
def get_background_image_file(self):
return self.background_image_file
def match_background_image_file(self, vmobject):
self.color_using_background_image(vmobject.get_background_image_file())
return self
def stretched_style_array_matching_points(self, array):
new_len = self.get_num_points()
long_arr = stretch_array_to_length_with_interpolation(
array, 1 + 2 * (new_len // 3)
)
shape = array.shape
if len(shape) > 1:
result = np.zeros((new_len, shape[1]))
else:
result = np.zeros(new_len)
result[0::3] = long_arr[0:-1:2]
result[1::3] = long_arr[1::2]
result[2::3] = long_arr[2::2]
return result
# Points # Points
def set_anchors_and_handles(self, anchors1, handles, anchors2): def set_anchors_and_handles(self, anchors1, handles, anchors2):
assert(len(anchors1) == len(handles) == len(anchors2)) assert(len(anchors1) == len(handles) == len(anchors2))
@ -436,7 +308,8 @@ class VMobject(Mobject):
# TODO, check that number new points is a multiple of 4? # TODO, check that number new points is a multiple of 4?
# or else that if self.get_num_points() % 4 == 1, then # or else that if self.get_num_points() % 4 == 1, then
# len(new_points) % 4 == 3? # len(new_points) % 4 == 3?
self.set_points(np.vstack([self.get_points(), new_points])) self.resize_points(self.get_num_points() + len(new_points))
self.data["points"][-len(new_points):] = new_points
return self return self
def start_new_path(self, point): def start_new_path(self, point):
@ -784,7 +657,6 @@ class VMobject(Mobject):
# Alignment # Alignment
def align_points(self, vmobject): def align_points(self, vmobject):
self.align_rgbas(vmobject)
if self.get_num_points() == len(vmobject.get_points()): if self.get_num_points() == len(vmobject.get_points()):
return return
@ -871,39 +743,20 @@ class VMobject(Mobject):
new_points += partial_quadratic_bezier_points(group, a1, a2) new_points += partial_quadratic_bezier_points(group, a1, a2)
return np.vstack(new_points) return np.vstack(new_points)
def align_rgbas(self, vmobject): def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs):
attrs = ["fill_rgbas", "stroke_rgbas"] super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
for attr in attrs: if self.has_fill():
a1 = getattr(self, attr) tri1 = mobject1.get_triangulation()
a2 = getattr(vmobject, attr) tri2 = mobject2.get_triangulation()
if len(a1) > len(a2): if len(tri1) != len(tri1) or not all(tri1 == tri2):
new_a2 = stretch_array_to_length(a2, len(a1)) self.refresh_triangulation()
setattr(vmobject, attr, new_a2)
elif len(a2) > len(a1):
new_a1 = stretch_array_to_length(a1, len(a2))
setattr(self, attr, new_a1)
return self return self
def interpolate_color(self, mobject1, mobject2, alpha):
attrs = [
"fill_rgbas",
"stroke_rgbas",
"stroke_width",
]
for attr in attrs:
set_array_by_interpolation(
getattr(self, attr),
getattr(mobject1, attr),
getattr(mobject2, attr),
alpha
)
def pointwise_become_partial(self, vmobject, a, b): def pointwise_become_partial(self, vmobject, a, b):
assert(isinstance(vmobject, VMobject)) assert(isinstance(vmobject, VMobject))
self.set_points(vmobject.get_points())
if a <= 0 and b >= 1: if a <= 0 and b >= 1:
return self return self
num_curves = self.get_num_curves() num_curves = vmobject.get_num_curves()
nppc = self.n_points_per_curve nppc = self.n_points_per_curve
# Partial curve includes three portions: # Partial curve includes three portions:
@ -918,26 +771,26 @@ class VMobject(Mobject):
i3 = nppc * upper_index i3 = nppc * upper_index
i4 = nppc * (upper_index + 1) i4 = nppc * (upper_index + 1)
points = self.get_points()
vm_points = vmobject.get_points() vm_points = vmobject.get_points()
new_points = vm_points.copy()
if num_curves == 0: if num_curves == 0:
points[:] = 0 new_points[:] = 0
return self return self
if lower_index == upper_index: if lower_index == upper_index:
tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, upper_residue) tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, upper_residue)
points[:i1] = tup[0] new_points[:i1] = tup[0]
points[i1:i4] = tup new_points[i1:i4] = tup
points[i4:] = tup[2] new_points[i4:] = tup[2]
points[nppc:] = points[nppc - 1] new_points[nppc:] = new_points[nppc - 1]
else: else:
low_tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, 1) low_tup = partial_quadratic_bezier_points(vm_points[i1:i2], lower_residue, 1)
high_tup = partial_quadratic_bezier_points(vm_points[i3:i4], 0, upper_residue) high_tup = partial_quadratic_bezier_points(vm_points[i3:i4], 0, upper_residue)
points[0:i1] = low_tup[0] new_points[0:i1] = low_tup[0]
points[i1:i2] = low_tup new_points[i1:i2] = low_tup
# Keep points i2:i3 as they are # Keep new_points i2:i3 as they are
points[i3:i4] = high_tup new_points[i3:i4] = high_tup
points[i4:] = high_tup[2] new_points[i4:] = high_tup[2]
self.set_points(points) self.set_points(new_points)
return self return self
def get_subcurve(self, a, b): def get_subcurve(self, a, b):
@ -945,14 +798,10 @@ class VMobject(Mobject):
vmob.pointwise_become_partial(self, a, b) vmob.pointwise_become_partial(self, a, b)
return vmob return vmob
def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path):
super().interpolate(mobject1, mobject2, alpha, path_func)
if not np.all(mobject1.get_triangulation() == mobject2.get_triangulation()):
self.refresh_triangulation()
return self
# For shaders # For shaders
def init_shader_data(self): def init_shader_data(self):
self.fill_data = np.zeros(0, dtype=self.fill_dtype)
self.stroke_data = np.zeros(0, dtype=self.stroke_dtype)
self.fill_shader_wrapper = ShaderWrapper( self.fill_shader_wrapper = ShaderWrapper(
vert_data=self.fill_data, vert_data=self.fill_data,
vert_indices=np.zeros(0, dtype='i4'), vert_indices=np.zeros(0, dtype='i4'),
@ -1019,39 +868,27 @@ class VMobject(Mobject):
return result return result
def get_stroke_shader_data(self): def get_stroke_shader_data(self):
# TODO, make even simpler after fixing colors points = self.get_points()
rgbas = self.get_stroke_rgbas() if len(self.stroke_data) != len(points):
if len(rgbas) > 1: self.stroke_data = resize_array(self.stroke_data, len(points))
rgbas = self.stretched_style_array_matching_points(rgbas) # TODO, account for when self.data["stroke_width"] and self.data["stroke_rgba"]
# have length greater than 1
stroke_width = self.stroke_width nppc = self.n_points_per_curve
if len(stroke_width) > 1: self.stroke_data["point"] = points
stroke_width = self.stretched_style_array_matching_points(stroke_width) self.stroke_data["prev_point"][:nppc] = points[-nppc:]
self.stroke_data["prev_point"][nppc:] = points[:-nppc]
self.stroke_data["next_point"][:-nppc] = points[nppc:]
self.stroke_data["next_point"][-nppc:] = points[:nppc]
data = self.stroke_data self.stroke_data["unit_normal"] = self.get_unit_normal()
data["stroke_width"][:, 0] = stroke_width self.stroke_data["stroke_width"] = self.data["stroke_width"]
data["color"] = rgbas self.stroke_data["color"] = self.data["stroke_rgba"]
return data return self.stroke_data
def lock_triangulation(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.triangulation_locked = False
mob.saved_triangulation = mob.get_triangulation()
mob.triangulation_locked = True
return self
def unlock_triangulation(self):
for sm in self.get_family():
sm.triangulation_locked = False
return self
def refresh_triangulation(self): def refresh_triangulation(self):
for mob in self.get_family(): for mob in self.get_family():
if mob.triangulation_locked: mob.needs_new_triangulation = True
mob.triangulation_locked = False
mob.saved_triangulation = mob.get_triangulation()
mob.triangulation_locked = True
return self return self
def get_triangulation(self, normal_vector=None): def get_triangulation(self, normal_vector=None):
@ -1061,13 +898,14 @@ class VMobject(Mobject):
if normal_vector is None: if normal_vector is None:
normal_vector = self.get_unit_normal() normal_vector = self.get_unit_normal()
if self.triangulation_locked: if not self.needs_new_triangulation:
return self.saved_triangulation return self.saved_traignulation
points = self.get_points() points = self.get_points()
if len(points) <= 1: if len(points) <= 1:
return np.zeros(0, dtype='i4') self.saved_traignulation == np.zeros(0, dtype='i4')
return self.saved_traignulation
# Rotate points such that unit normal vector is OUT # Rotate points such that unit normal vector is OUT
# TODO, 99% of the time this does nothing. Do a check for that? # TODO, 99% of the time this does nothing. Do a check for that?
@ -1104,14 +942,20 @@ class VMobject(Mobject):
inner_tri_indices = inner_vert_indices[earclip_triangulation(inner_verts, rings)] inner_tri_indices = inner_vert_indices[earclip_triangulation(inner_verts, rings)]
tri_indices = np.hstack([indices, inner_tri_indices]) tri_indices = np.hstack([indices, inner_tri_indices])
self.saved_traignulation = tri_indices
self.needs_new_triangulation = False
return tri_indices return tri_indices
def get_fill_shader_data(self): def get_fill_shader_data(self):
# TODO, make simpler points = self.get_points()
rgbas = self.get_fill_rgbas()[:1] if len(self.fill_data) != len(points):
data = self.fill_data self.fill_data = resize_array(self.fill_data, len(points))
data["color"] = rgbas self.fill_data["vert_index"][:, 0] = range(len(points))
return data
self.fill_data["point"] = points
self.fill_data["unit_normal"] = self.get_unit_normal()
self.fill_data["color"] = self.data["fill_rgba"]
return self.fill_data
def get_fill_shader_vert_indices(self): def get_fill_shader_vert_indices(self):
return self.get_triangulation() return self.get_triangulation()
@ -1121,11 +965,11 @@ class VGroup(VMobject):
def __init__(self, *vmobjects, **kwargs): def __init__(self, *vmobjects, **kwargs):
if not all([isinstance(m, VMobject) for m in vmobjects]): if not all([isinstance(m, VMobject) for m in vmobjects]):
raise Exception("All submobjects must be of type VMobject") raise Exception("All submobjects must be of type VMobject")
VMobject.__init__(self, **kwargs) super().__init__(**kwargs)
self.add(*vmobjects) self.add(*vmobjects)
class VectorizedPoint(VMobject, Point): class VectorizedPoint(Point, VMobject):
CONFIG = { CONFIG = {
"color": BLACK, "color": BLACK,
"fill_opacity": 0, "fill_opacity": 0,
@ -1135,13 +979,13 @@ class VectorizedPoint(VMobject, Point):
} }
def __init__(self, location=ORIGIN, **kwargs): def __init__(self, location=ORIGIN, **kwargs):
VMobject.__init__(self, **kwargs) super().__init__(**kwargs)
self.set_points(np.array([location])) self.set_points(np.array([location]))
class CurvesAsSubmobjects(VGroup): class CurvesAsSubmobjects(VGroup):
def __init__(self, vmobject, **kwargs): def __init__(self, vmobject, **kwargs):
VGroup.__init__(self, **kwargs) super().__init__(**kwargs)
for tup in vmobject.get_bezier_tuples(): for tup in vmobject.get_bezier_tuples():
part = VMobject() part = VMobject()
part.set_points(tup) part.set_points(tup)
@ -1157,7 +1001,7 @@ class DashedVMobject(VMobject):
} }
def __init__(self, vmobject, **kwargs): def __init__(self, vmobject, **kwargs):
VMobject.__init__(self, **kwargs) super().__init__(**kwargs)
num_dashes = self.num_dashes num_dashes = self.num_dashes
ps_ratio = self.positive_space_ratio ps_ratio = self.positive_space_ratio
if num_dashes > 0: if num_dashes > 0:

View File

@ -74,8 +74,8 @@ def interpolate(start, end, alpha):
sys.exit(2) sys.exit(2)
def set_array_by_interpolation(arr, arr1, arr2, alpha): def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate):
arr[:] = interpolate(arr1, arr2, alpha) arr[:] = interp_func(arr1, arr2, alpha)
return arr return arr

View File

@ -80,15 +80,23 @@ def listify(obj):
return [obj] return [obj]
def stretch_array_to_length(nparray, length): def resize_array(nparray, length):
# TODO, rename to "resize"? return np.resize(nparray, (length, *nparray.shape[1:]))
def resize_preserving_order(nparray, length):
if len(nparray) == 0:
return np.zeros((length, *nparray.shape[1:]))
if len(nparray) == length:
return nparray
indices = np.arange(length) * len(nparray) // length indices = np.arange(length) * len(nparray) // length
return nparray[indices] return nparray[indices]
def stretch_array_to_length_with_interpolation(nparray, length): def resize_with_interpolation(nparray, length):
curr_len = len(nparray) if len(nparray) == length:
cont_indices = np.linspace(0, curr_len - 1, length) return nparray
cont_indices = np.linspace(0, len(nparray) - 1, length)
return np.array([ return np.array([
(1 - a) * nparray[lh] + a * nparray[rh] (1 - a) * nparray[lh] + a * nparray[rh]
for ci in cont_indices for ci in cont_indices