mirror of
https://github.com/3b1b/manim.git
synced 2025-08-02 19:46:21 +08:00
First pass at changing data to structure numpy array
This doesn't yet tackle Surface
This commit is contained in:
@ -66,7 +66,12 @@ class Mobject(object):
|
||||
# Must match in attributes of vert shader
|
||||
shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
|
||||
('point', np.float32, (3,)),
|
||||
('rgba', np.float32, (4,)),
|
||||
]
|
||||
data_dtype: np.dtype = np.dtype([
|
||||
('points', '<f4', (3,)),
|
||||
('rgbas', '<f4', (4,)),
|
||||
])
|
||||
aligned_data_keys = ['points']
|
||||
|
||||
def __init__(
|
||||
@ -106,10 +111,7 @@ class Mobject(object):
|
||||
self.bounding_box: Vect3Array = np.zeros((3, 3))
|
||||
|
||||
self.init_data()
|
||||
self._data_defaults = {
|
||||
key: np.zeros((1, self.data[key].shape[1]))
|
||||
for key in self.data
|
||||
}
|
||||
self._data_defaults = np.ones(1, dtype=self.data.dtype)
|
||||
self.init_uniforms()
|
||||
self.init_updaters()
|
||||
self.init_event_listners()
|
||||
@ -131,11 +133,8 @@ class Mobject(object):
|
||||
assert(isinstance(other, int))
|
||||
return self.replicate(other)
|
||||
|
||||
def init_data(self):
|
||||
self.data: dict[str, np.ndarray] = {
|
||||
"points": np.zeros((0, 3)),
|
||||
"rgbas": np.zeros((0, 4)),
|
||||
}
|
||||
def init_data(self, length: int = 0):
|
||||
self.data = np.zeros(length, dtype=self.data_dtype)
|
||||
|
||||
def init_uniforms(self):
|
||||
self.uniforms: dict[str, float | np.ndarray] = {
|
||||
@ -152,9 +151,9 @@ class Mobject(object):
|
||||
# Typically implemented in subclass, unlpess purposefully left blank
|
||||
pass
|
||||
|
||||
def set_data(self, data: dict):
|
||||
for key in data:
|
||||
self.data[key] = data[key].copy()
|
||||
def set_data(self, data: np.ndarray):
|
||||
assert(data.dtype == self.data.dtype)
|
||||
self.data = data
|
||||
return self
|
||||
|
||||
def set_uniforms(self, uniforms: dict):
|
||||
@ -177,15 +176,12 @@ class Mobject(object):
|
||||
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
|
||||
):
|
||||
if new_length == 0:
|
||||
for key in self.data:
|
||||
if len(self.data[key]) > 0:
|
||||
self._data_defaults[key][:1] = self.data[key][:1]
|
||||
if len(self.data) > 0:
|
||||
self._data_defaults[:1] = self.data[:1]
|
||||
elif self.get_num_points() == 0:
|
||||
for key in self.data:
|
||||
self.data[key] = self._data_defaults[key].copy()
|
||||
self.data = self._data_defaults.copy()
|
||||
|
||||
for key in self.data:
|
||||
self.data[key] = resize_func(self.data[key], new_length)
|
||||
self.data = resize_func(self.data, new_length)
|
||||
self.refresh_bounding_box()
|
||||
return self
|
||||
|
||||
@ -203,8 +199,7 @@ class Mobject(object):
|
||||
|
||||
def reverse_points(self):
|
||||
for mob in self.get_family():
|
||||
for key in mob.data:
|
||||
mob.data[key] = mob.data[key][::-1]
|
||||
mob.data = mob.data[::-1]
|
||||
return self
|
||||
|
||||
def apply_points_function(
|
||||
@ -584,10 +579,7 @@ class Mobject(object):
|
||||
# The line above is only a shallow copy, so the internal
|
||||
# data which are numpyu arrays or other mobjects still
|
||||
# need to be further copied.
|
||||
result.data = {
|
||||
key: np.array(value)
|
||||
for key, value in self.data.items()
|
||||
}
|
||||
result.data = self.data.copy()
|
||||
result.uniforms = {
|
||||
key: np.array(value)
|
||||
for key, value in self.uniforms.items()
|
||||
@ -678,15 +670,22 @@ class Mobject(object):
|
||||
if len(fam1) != len(fam2):
|
||||
return False
|
||||
for m1, m2 in zip(fam1, fam2):
|
||||
for d1, d2 in [(m1.data, m2.data), (m1.uniforms, m2.uniforms)]:
|
||||
if set(d1).difference(d2):
|
||||
if m1.get_num_points() != m2.get_num_points():
|
||||
return False
|
||||
if not m1.data.dtype == m2.data.dtype:
|
||||
return False
|
||||
for key in m1.data.dtype.names:
|
||||
if not np.isclose(m1.data[key], m2.data[key]).all():
|
||||
return False
|
||||
if set(m1.uniforms).difference(m2.uniforms):
|
||||
return False
|
||||
for key in m1.uniforms:
|
||||
value1 = m1.uniforms[key]
|
||||
value2 = m2.uniforms[key]
|
||||
if isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray) and not value1.size == value2.size:
|
||||
return False
|
||||
if not np.isclose(value1, value2).all():
|
||||
return False
|
||||
for key in d1:
|
||||
if isinstance(d1[key], np.ndarray) and isinstance(d2[key], np.ndarray):
|
||||
if not d1[key].size == d2[key].size:
|
||||
return False
|
||||
if not np.isclose(d1[key], d2[key]).all():
|
||||
return False
|
||||
return True
|
||||
|
||||
def has_same_shape_as(self, mobject: Mobject) -> bool:
|
||||
@ -1604,19 +1603,7 @@ class Mobject(object):
|
||||
# In case any data arrays get resized when aligned to shader data
|
||||
mob1.refresh_shader_data()
|
||||
mob2.refresh_shader_data()
|
||||
|
||||
mob1.align_points(mob2)
|
||||
for key in mob1.data.keys() & mob2.data.keys():
|
||||
if key == "points":
|
||||
# Separate out how points are treated so that subclasses
|
||||
# can handle that case differently if they choose
|
||||
continue
|
||||
arr1 = mob1.data[key]
|
||||
arr2 = mob2.data[key]
|
||||
if len(arr2) > len(arr1):
|
||||
mob1.data[key] = resize_preserving_order(arr1, len(arr2))
|
||||
elif len(arr1) > len(arr2):
|
||||
mob2.data[key] = resize_preserving_order(arr2, len(arr1))
|
||||
|
||||
def align_points(self, mobject: Mobject):
|
||||
max_len = max(self.get_num_points(), mobject.get_num_points())
|
||||
@ -1686,13 +1673,11 @@ class Mobject(object):
|
||||
alpha: float,
|
||||
path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path
|
||||
):
|
||||
for key in self.data:
|
||||
for key in self.data.dtype.names:
|
||||
if key in self.locked_data_keys:
|
||||
continue
|
||||
if len(self.data[key]) == 0:
|
||||
continue
|
||||
if key not in mobject1.data or key not in mobject2.data:
|
||||
continue
|
||||
|
||||
func = path_func if key == "points" else interpolate
|
||||
|
||||
@ -1739,11 +1724,11 @@ class Mobject(object):
|
||||
|
||||
def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject):
|
||||
for sm, sm1, sm2 in zip(self.get_family(), mobject1.get_family(), mobject2.get_family()):
|
||||
keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys()
|
||||
sm.lock_data(list(filter(
|
||||
lambda key: arrays_match(sm1.data[key], sm2.data[key]),
|
||||
keys,
|
||||
)))
|
||||
if not (sm.data.dtype == sm1.data.dtype == sm2.data.dtype):
|
||||
sm.lock_data([
|
||||
key for key in sm.data.dtype.names
|
||||
if arrays_match(sm1.data[key], sm2.data[key])
|
||||
])
|
||||
return self
|
||||
|
||||
def unlock_data(self):
|
||||
|
@ -29,7 +29,11 @@ class DotCloud(PMobject):
|
||||
('radius', np.float32, (1,)),
|
||||
('color', np.float32, (4,)),
|
||||
]
|
||||
|
||||
data_dtype: np.dtype = np.dtype([
|
||||
('points', np.float32, (3,)),
|
||||
('radii', np.float32, (1,)),
|
||||
('rgbas', np.float32, (4,)),
|
||||
])
|
||||
def __init__(
|
||||
self,
|
||||
points: Vect3Array = NULL_POINTS,
|
||||
@ -55,7 +59,6 @@ class DotCloud(PMobject):
|
||||
|
||||
def init_data(self) -> None:
|
||||
super().init_data()
|
||||
self.data["radii"] = np.zeros((1, 1))
|
||||
self.set_radius(self.radius)
|
||||
|
||||
def init_uniforms(self) -> None:
|
||||
|
@ -24,6 +24,11 @@ class ImageMobject(Mobject):
|
||||
('im_coords', np.float32, (2,)),
|
||||
('opacity', np.float32, (1,)),
|
||||
]
|
||||
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
|
||||
('points', np.float32, (3,)),
|
||||
('im_coords', np.float32, (2,)),
|
||||
('opacity', np.float32, (1,)),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -37,11 +42,10 @@ class ImageMobject(Mobject):
|
||||
super().__init__(texture_paths={"Texture": self.image_path}, **kwargs)
|
||||
|
||||
def init_data(self) -> None:
|
||||
self.data = {
|
||||
"points": np.array([UL, DL, UR, DR]),
|
||||
"im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]),
|
||||
"opacity": self.opacity * np.ones((4, 1)),
|
||||
}
|
||||
super().init_data(length=4)
|
||||
self.data["points"][:] = [UL, DL, UR, DR]
|
||||
self.data["im_coords"][:] = [(0, 0), (0, 1), (1, 0), (1, 1)]
|
||||
self.data["opacity"][:] = self.opacity
|
||||
|
||||
def init_points(self) -> None:
|
||||
size = self.image.size
|
||||
@ -49,9 +53,10 @@ class ImageMobject(Mobject):
|
||||
self.set_height(self.height)
|
||||
|
||||
def set_opacity(self, opacity: float, recurse: bool = True):
|
||||
op_arr = np.array([[o] for o in listify(opacity)])
|
||||
for mob in self.get_family(recurse):
|
||||
mob.data["opacity"][:] = resize_with_interpolation(op_arr, mob.get_num_points())
|
||||
self.data["opacity"][:, 0] = resize_with_interpolation(
|
||||
np.array(listify(opacity)),
|
||||
self.get_num_points()
|
||||
)
|
||||
return self
|
||||
|
||||
def set_color(self, color, opacity=None, recurse=None):
|
||||
|
@ -67,9 +67,7 @@ class PMobject(Mobject):
|
||||
|
||||
def filter_out(self, condition: Callable[[np.ndarray], bool]):
|
||||
for mob in self.family_members_with_points():
|
||||
to_keep = ~np.apply_along_axis(condition, 1, mob.get_points())
|
||||
for key in mob.data:
|
||||
mob.data[key] = mob.data[key][to_keep]
|
||||
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
|
||||
return self
|
||||
|
||||
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]):
|
||||
@ -80,16 +78,13 @@ class PMobject(Mobject):
|
||||
indices = np.argsort(
|
||||
np.apply_along_axis(function, 1, mob.get_points())
|
||||
)
|
||||
for key in mob.data:
|
||||
mob.data[key][:] = mob.data[key][indices]
|
||||
mob.data[:] = mob.data[indices]
|
||||
return self
|
||||
|
||||
def ingest_submobjects(self):
|
||||
for key in self.data:
|
||||
self.data[key] = np.vstack([
|
||||
sm.data[key]
|
||||
for sm in self.get_family()
|
||||
])
|
||||
self.data = np.vstack([
|
||||
sm.data for sm in self.get_family()
|
||||
])
|
||||
return self
|
||||
|
||||
def point_from_proportion(self, alpha: float) -> np.ndarray:
|
||||
@ -99,8 +94,7 @@ class PMobject(Mobject):
|
||||
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
|
||||
lower_index = int(a * pmobject.get_num_points())
|
||||
upper_index = int(b * pmobject.get_num_points())
|
||||
for key in self.data:
|
||||
self.data[key] = pmobject.data[key][lower_index:upper_index].copy()
|
||||
self.data = pmobject.data[lower_index:upper_index].copy()
|
||||
return self
|
||||
|
||||
|
||||
|
@ -66,9 +66,16 @@ class VMobject(Mobject):
|
||||
("stroke_width", np.float32, (1,)),
|
||||
("color", np.float32, (4,)),
|
||||
]
|
||||
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
|
||||
("points", np.float32, (3,)),
|
||||
('fill_rgba', np.float32, (4,)),
|
||||
("stroke_rgba", np.float32, (4,)),
|
||||
("joint_angle", np.float32, (1,)),
|
||||
("stroke_width", np.float32, (1,)),
|
||||
('orientation', np.float32, (1,)),
|
||||
]
|
||||
fill_render_primitive: int = moderngl.TRIANGLES
|
||||
stroke_render_primitive: int = moderngl.TRIANGLE_STRIP
|
||||
aligned_data_keys = ["points", "orientation", "joint_angle"]
|
||||
|
||||
pre_function_handle_to_anchor_scale_factor: float = 0.01
|
||||
make_smooth_after_applying_functions: bool = False
|
||||
@ -117,17 +124,6 @@ class VMobject(Mobject):
|
||||
def get_group_class(self):
|
||||
return VGroup
|
||||
|
||||
def init_data(self):
|
||||
super().init_data()
|
||||
self.data.pop("rgbas")
|
||||
self.data.update({
|
||||
"fill_rgba": np.zeros((1, 4)),
|
||||
"stroke_rgba": np.zeros((1, 4)),
|
||||
"stroke_width": np.zeros((1, 1)),
|
||||
"orientation": np.ones((1, 1)),
|
||||
"joint_angle": np.zeros((0, 1)),
|
||||
})
|
||||
|
||||
def init_uniforms(self):
|
||||
super().init_uniforms()
|
||||
self.uniforms["anti_alias_width"] = self.anti_alias_width
|
||||
@ -371,23 +367,28 @@ class VMobject(Mobject):
|
||||
If there are multiple colors (for gradient)
|
||||
this returns the first one
|
||||
"""
|
||||
return self.get_fill_colors()[0]
|
||||
data = self.data if self.has_points() else self._data_defaults
|
||||
return rgb_to_hex(data["fill_rgba"][0, :3])
|
||||
|
||||
def get_fill_opacity(self) -> float:
|
||||
"""
|
||||
If there are multiple opacities, this returns the
|
||||
first
|
||||
"""
|
||||
return self.get_fill_opacities()[0]
|
||||
data = self.data if self.has_points() else self._data_defaults
|
||||
return data["fill_rgba"][0, 3]
|
||||
|
||||
def get_stroke_color(self) -> str:
|
||||
return self.get_stroke_colors()[0]
|
||||
data = self.data if self.has_points() else self._data_defaults
|
||||
return rgb_to_hex(data["stroke_rgba"][0, :3])
|
||||
|
||||
def get_stroke_width(self) -> float | np.ndarray:
|
||||
return self.get_stroke_widths()[0]
|
||||
data = self.data if self.has_points() else self._data_defaults
|
||||
return data["stroke_width"][0, 0]
|
||||
|
||||
def get_stroke_opacity(self) -> float:
|
||||
return self.get_stroke_opacities()[0]
|
||||
data = self.data if self.has_points() else self._data_defaults
|
||||
return data["stroke_rgba"][0, 3]
|
||||
|
||||
def get_color(self) -> str:
|
||||
if self.has_fill():
|
||||
@ -1134,7 +1135,7 @@ class VMobject(Mobject):
|
||||
return self
|
||||
|
||||
@triggers_refreshed_triangulation
|
||||
def set_data(self, data: dict):
|
||||
def set_data(self, data: np.ndarray):
|
||||
super().set_data(data)
|
||||
return self
|
||||
|
||||
|
Reference in New Issue
Block a user