First pass at changing data to structure numpy array

This doesn't yet tackle Surface
This commit is contained in:
Grant Sanderson
2023-01-15 16:05:18 -08:00
parent 286b8fb6c3
commit 2815f60616
5 changed files with 81 additions and 93 deletions

View File

@ -66,7 +66,12 @@ class Mobject(object):
# Must match in attributes of vert shader
shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('point', np.float32, (3,)),
('rgba', np.float32, (4,)),
]
data_dtype: np.dtype = np.dtype([
('points', '<f4', (3,)),
('rgbas', '<f4', (4,)),
])
aligned_data_keys = ['points']
def __init__(
@ -106,10 +111,7 @@ class Mobject(object):
self.bounding_box: Vect3Array = np.zeros((3, 3))
self.init_data()
self._data_defaults = {
key: np.zeros((1, self.data[key].shape[1]))
for key in self.data
}
self._data_defaults = np.ones(1, dtype=self.data.dtype)
self.init_uniforms()
self.init_updaters()
self.init_event_listners()
@ -131,11 +133,8 @@ class Mobject(object):
assert(isinstance(other, int))
return self.replicate(other)
def init_data(self):
self.data: dict[str, np.ndarray] = {
"points": np.zeros((0, 3)),
"rgbas": np.zeros((0, 4)),
}
def init_data(self, length: int = 0):
self.data = np.zeros(length, dtype=self.data_dtype)
def init_uniforms(self):
self.uniforms: dict[str, float | np.ndarray] = {
@ -152,9 +151,9 @@ class Mobject(object):
# Typically implemented in subclass, unlpess purposefully left blank
pass
def set_data(self, data: dict):
for key in data:
self.data[key] = data[key].copy()
def set_data(self, data: np.ndarray):
assert(data.dtype == self.data.dtype)
self.data = data
return self
def set_uniforms(self, uniforms: dict):
@ -177,15 +176,12 @@ class Mobject(object):
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
):
if new_length == 0:
for key in self.data:
if len(self.data[key]) > 0:
self._data_defaults[key][:1] = self.data[key][:1]
if len(self.data) > 0:
self._data_defaults[:1] = self.data[:1]
elif self.get_num_points() == 0:
for key in self.data:
self.data[key] = self._data_defaults[key].copy()
self.data = self._data_defaults.copy()
for key in self.data:
self.data[key] = resize_func(self.data[key], new_length)
self.data = resize_func(self.data, new_length)
self.refresh_bounding_box()
return self
@ -203,8 +199,7 @@ class Mobject(object):
def reverse_points(self):
for mob in self.get_family():
for key in mob.data:
mob.data[key] = mob.data[key][::-1]
mob.data = mob.data[::-1]
return self
def apply_points_function(
@ -584,10 +579,7 @@ class Mobject(object):
# The line above is only a shallow copy, so the internal
# data which are numpyu arrays or other mobjects still
# need to be further copied.
result.data = {
key: np.array(value)
for key, value in self.data.items()
}
result.data = self.data.copy()
result.uniforms = {
key: np.array(value)
for key, value in self.uniforms.items()
@ -678,15 +670,22 @@ class Mobject(object):
if len(fam1) != len(fam2):
return False
for m1, m2 in zip(fam1, fam2):
for d1, d2 in [(m1.data, m2.data), (m1.uniforms, m2.uniforms)]:
if set(d1).difference(d2):
if m1.get_num_points() != m2.get_num_points():
return False
if not m1.data.dtype == m2.data.dtype:
return False
for key in m1.data.dtype.names:
if not np.isclose(m1.data[key], m2.data[key]).all():
return False
if set(m1.uniforms).difference(m2.uniforms):
return False
for key in m1.uniforms:
value1 = m1.uniforms[key]
value2 = m2.uniforms[key]
if isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray) and not value1.size == value2.size:
return False
if not np.isclose(value1, value2).all():
return False
for key in d1:
if isinstance(d1[key], np.ndarray) and isinstance(d2[key], np.ndarray):
if not d1[key].size == d2[key].size:
return False
if not np.isclose(d1[key], d2[key]).all():
return False
return True
def has_same_shape_as(self, mobject: Mobject) -> bool:
@ -1604,19 +1603,7 @@ class Mobject(object):
# In case any data arrays get resized when aligned to shader data
mob1.refresh_shader_data()
mob2.refresh_shader_data()
mob1.align_points(mob2)
for key in mob1.data.keys() & mob2.data.keys():
if key == "points":
# Separate out how points are treated so that subclasses
# can handle that case differently if they choose
continue
arr1 = mob1.data[key]
arr2 = mob2.data[key]
if len(arr2) > len(arr1):
mob1.data[key] = resize_preserving_order(arr1, len(arr2))
elif len(arr1) > len(arr2):
mob2.data[key] = resize_preserving_order(arr2, len(arr1))
def align_points(self, mobject: Mobject):
max_len = max(self.get_num_points(), mobject.get_num_points())
@ -1686,13 +1673,11 @@ class Mobject(object):
alpha: float,
path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path
):
for key in self.data:
for key in self.data.dtype.names:
if key in self.locked_data_keys:
continue
if len(self.data[key]) == 0:
continue
if key not in mobject1.data or key not in mobject2.data:
continue
func = path_func if key == "points" else interpolate
@ -1739,11 +1724,11 @@ class Mobject(object):
def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject):
for sm, sm1, sm2 in zip(self.get_family(), mobject1.get_family(), mobject2.get_family()):
keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys()
sm.lock_data(list(filter(
lambda key: arrays_match(sm1.data[key], sm2.data[key]),
keys,
)))
if not (sm.data.dtype == sm1.data.dtype == sm2.data.dtype):
sm.lock_data([
key for key in sm.data.dtype.names
if arrays_match(sm1.data[key], sm2.data[key])
])
return self
def unlock_data(self):

View File

@ -29,7 +29,11 @@ class DotCloud(PMobject):
('radius', np.float32, (1,)),
('color', np.float32, (4,)),
]
data_dtype: np.dtype = np.dtype([
('points', np.float32, (3,)),
('radii', np.float32, (1,)),
('rgbas', np.float32, (4,)),
])
def __init__(
self,
points: Vect3Array = NULL_POINTS,
@ -55,7 +59,6 @@ class DotCloud(PMobject):
def init_data(self) -> None:
super().init_data()
self.data["radii"] = np.zeros((1, 1))
self.set_radius(self.radius)
def init_uniforms(self) -> None:

View File

@ -24,6 +24,11 @@ class ImageMobject(Mobject):
('im_coords', np.float32, (2,)),
('opacity', np.float32, (1,)),
]
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('points', np.float32, (3,)),
('im_coords', np.float32, (2,)),
('opacity', np.float32, (1,)),
]
def __init__(
self,
@ -37,11 +42,10 @@ class ImageMobject(Mobject):
super().__init__(texture_paths={"Texture": self.image_path}, **kwargs)
def init_data(self) -> None:
self.data = {
"points": np.array([UL, DL, UR, DR]),
"im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]),
"opacity": self.opacity * np.ones((4, 1)),
}
super().init_data(length=4)
self.data["points"][:] = [UL, DL, UR, DR]
self.data["im_coords"][:] = [(0, 0), (0, 1), (1, 0), (1, 1)]
self.data["opacity"][:] = self.opacity
def init_points(self) -> None:
size = self.image.size
@ -49,9 +53,10 @@ class ImageMobject(Mobject):
self.set_height(self.height)
def set_opacity(self, opacity: float, recurse: bool = True):
op_arr = np.array([[o] for o in listify(opacity)])
for mob in self.get_family(recurse):
mob.data["opacity"][:] = resize_with_interpolation(op_arr, mob.get_num_points())
self.data["opacity"][:, 0] = resize_with_interpolation(
np.array(listify(opacity)),
self.get_num_points()
)
return self
def set_color(self, color, opacity=None, recurse=None):

View File

@ -67,9 +67,7 @@ class PMobject(Mobject):
def filter_out(self, condition: Callable[[np.ndarray], bool]):
for mob in self.family_members_with_points():
to_keep = ~np.apply_along_axis(condition, 1, mob.get_points())
for key in mob.data:
mob.data[key] = mob.data[key][to_keep]
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
return self
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]):
@ -80,16 +78,13 @@ class PMobject(Mobject):
indices = np.argsort(
np.apply_along_axis(function, 1, mob.get_points())
)
for key in mob.data:
mob.data[key][:] = mob.data[key][indices]
mob.data[:] = mob.data[indices]
return self
def ingest_submobjects(self):
for key in self.data:
self.data[key] = np.vstack([
sm.data[key]
for sm in self.get_family()
])
self.data = np.vstack([
sm.data for sm in self.get_family()
])
return self
def point_from_proportion(self, alpha: float) -> np.ndarray:
@ -99,8 +94,7 @@ class PMobject(Mobject):
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points())
for key in self.data:
self.data[key] = pmobject.data[key][lower_index:upper_index].copy()
self.data = pmobject.data[lower_index:upper_index].copy()
return self

View File

@ -66,9 +66,16 @@ class VMobject(Mobject):
("stroke_width", np.float32, (1,)),
("color", np.float32, (4,)),
]
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
("points", np.float32, (3,)),
('fill_rgba', np.float32, (4,)),
("stroke_rgba", np.float32, (4,)),
("joint_angle", np.float32, (1,)),
("stroke_width", np.float32, (1,)),
('orientation', np.float32, (1,)),
]
fill_render_primitive: int = moderngl.TRIANGLES
stroke_render_primitive: int = moderngl.TRIANGLE_STRIP
aligned_data_keys = ["points", "orientation", "joint_angle"]
pre_function_handle_to_anchor_scale_factor: float = 0.01
make_smooth_after_applying_functions: bool = False
@ -117,17 +124,6 @@ class VMobject(Mobject):
def get_group_class(self):
return VGroup
def init_data(self):
super().init_data()
self.data.pop("rgbas")
self.data.update({
"fill_rgba": np.zeros((1, 4)),
"stroke_rgba": np.zeros((1, 4)),
"stroke_width": np.zeros((1, 1)),
"orientation": np.ones((1, 1)),
"joint_angle": np.zeros((0, 1)),
})
def init_uniforms(self):
super().init_uniforms()
self.uniforms["anti_alias_width"] = self.anti_alias_width
@ -371,23 +367,28 @@ class VMobject(Mobject):
If there are multiple colors (for gradient)
this returns the first one
"""
return self.get_fill_colors()[0]
data = self.data if self.has_points() else self._data_defaults
return rgb_to_hex(data["fill_rgba"][0, :3])
def get_fill_opacity(self) -> float:
"""
If there are multiple opacities, this returns the
first
"""
return self.get_fill_opacities()[0]
data = self.data if self.has_points() else self._data_defaults
return data["fill_rgba"][0, 3]
def get_stroke_color(self) -> str:
return self.get_stroke_colors()[0]
data = self.data if self.has_points() else self._data_defaults
return rgb_to_hex(data["stroke_rgba"][0, :3])
def get_stroke_width(self) -> float | np.ndarray:
return self.get_stroke_widths()[0]
data = self.data if self.has_points() else self._data_defaults
return data["stroke_width"][0, 0]
def get_stroke_opacity(self) -> float:
return self.get_stroke_opacities()[0]
data = self.data if self.has_points() else self._data_defaults
return data["stroke_rgba"][0, 3]
def get_color(self) -> str:
if self.has_fill():
@ -1134,7 +1135,7 @@ class VMobject(Mobject):
return self
@triggers_refreshed_triangulation
def set_data(self, data: dict):
def set_data(self, data: np.ndarray):
super().set_data(data)
return self