Distinguish Vect3 from Vect3Array types

This commit is contained in:
Grant Sanderson
2022-12-17 13:16:48 -08:00
parent 8db20cc460
commit 97f28b34f3
13 changed files with 83 additions and 77 deletions

View File

@ -23,7 +23,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
from manimlib.mobject.coordinate_systems import CoordinateSystem
from manimlib.mobject.mobject import Mobject
@ -35,7 +35,7 @@ def get_vectorized_rgb_gradient_function(
min_value: T,
max_value: T,
color_map: str
) -> Callable[[npt.ArrayLike], Vect3]:
) -> Callable[[VectN], Vect3Array]:
rgbs = np.array(get_colormap_list(color_map))
def func(values):
@ -57,9 +57,9 @@ def get_rgb_gradient_function(
min_value: T,
max_value: T,
color_map: str
) -> Callable[[T], Vect3]:
) -> Callable[[float], Vect3]:
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
return lambda value: vectorized_func([value])[0]
return lambda value: vectorized_func(np.array([value]))[0]
def move_along_vector_field(
@ -254,7 +254,7 @@ class StreamLines(VGroup):
lines.append(line)
self.set_submobjects(lines)
def get_start_points(self) -> Vect3:
def get_start_points(self) -> Vect3Array:
cs = self.coordinate_system
sample_coords = get_sample_points_from_coordinate_system(
cs, self.step_multiple,