Distinguish Vect3 from Vect3Array types

This commit is contained in:
Grant Sanderson
2022-12-17 13:16:48 -08:00
parent 8db20cc460
commit 97f28b34f3
13 changed files with 83 additions and 77 deletions

View File

@ -48,10 +48,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Union, Tuple
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
TimeBasedUpdater = Callable[["Mobject", float], None]
NonTimeUpdater = Callable[["Mobject"], None]
TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
Updater = Union[TimeBasedUpdater, NonTimeUpdater]
@ -233,7 +233,7 @@ class Mobject(object):
self.set_points(mobject.get_points())
return self
def get_points(self) -> Vect3:
def get_points(self) -> Vect3Array:
return self.data["points"]
def clear_points(self) -> None:
@ -242,7 +242,7 @@ class Mobject(object):
def get_num_points(self) -> int:
return len(self.data["points"])
def get_all_points(self) -> Vect3:
def get_all_points(self) -> Vect3Array:
if self.submobjects:
return np.vstack([sm.get_points() for sm in self.get_family()])
else:
@ -251,13 +251,13 @@ class Mobject(object):
def has_points(self) -> bool:
return self.get_num_points() > 0
def get_bounding_box(self) -> Vect3:
def get_bounding_box(self) -> Vect3Array:
if self.needs_new_bounding_box:
self.data["bounding_box"] = self.compute_bounding_box()
self.needs_new_bounding_box = False
return self.data["bounding_box"]
def compute_bounding_box(self) -> Vect3:
def compute_bounding_box(self) -> Vect3Array:
all_points = np.vstack([
self.get_points(),
*(
@ -289,9 +289,9 @@ class Mobject(object):
def are_points_touching(
self,
points: Vect3,
points: Vect3Array,
buff: float = 0
) -> bool:
) -> np.ndarray:
bb = self.get_bounding_box()
mins = (bb[0] - buff)
maxs = (bb[2] + buff)
@ -1871,7 +1871,7 @@ class Mobject(object):
)
return self
def get_resized_shader_data_array(self, length: int) -> Vect3:
def get_resized_shader_data_array(self, length: int) -> np.ndarray:
# If possible, try to populate an existing array, rather
# than recreating it each frame
if len(self.shader_data) != length:
@ -1880,7 +1880,7 @@ class Mobject(object):
def read_data_to_shader(
self,
shader_data: Vect3,
shader_data: np.ndarray,
shader_data_key: str,
data_key: str
):