mirror of
https://github.com/3b1b/manim.git
synced 2025-08-02 11:03:03 +08:00
Add better types + Small refactors on space_ops
This commit is contained in:
@ -18,29 +18,31 @@ from manimlib.utils.simple_functions import clip
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, Iterable, Sequence
|
||||
|
||||
import numpy.typing as npt
|
||||
from typing import Callable, Sequence, List, Tuple
|
||||
from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3
|
||||
|
||||
|
||||
def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]:
|
||||
return [
|
||||
def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3:
|
||||
return np.array([
|
||||
v1[1] * v2[2] - v1[2] * v2[1],
|
||||
v1[2] * v2[0] - v1[0] * v2[2],
|
||||
v1[0] * v2[1] - v1[1] * v2[0]
|
||||
]
|
||||
])
|
||||
|
||||
|
||||
def get_norm(vect: Iterable) -> float:
|
||||
def get_norm(vect: VectN | List[float]) -> float:
|
||||
return sum((x**2 for x in vect))**0.5
|
||||
|
||||
|
||||
def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarray:
|
||||
def normalize(
|
||||
vect: VectN | List[float],
|
||||
fall_back: VectN | List[float] | None = None
|
||||
) -> VectN:
|
||||
norm = get_norm(vect)
|
||||
if norm > 0:
|
||||
return np.array(vect) / norm
|
||||
elif fall_back is not None:
|
||||
return fall_back
|
||||
return np.array(fall_back)
|
||||
else:
|
||||
return np.zeros(len(vect))
|
||||
|
||||
@ -48,15 +50,18 @@ def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarr
|
||||
# Operations related to rotation
|
||||
|
||||
|
||||
def quaternion_mult(*quats: Sequence[float]) -> list[float]:
|
||||
# Real part is last entry, which is bizzare, but fits scipy Rotation convention
|
||||
def quaternion_mult(*quats: Vect4) -> Vect4:
|
||||
"""
|
||||
Inputs are treated as quaternions, where the real part is the
|
||||
last entry, so as to follow the scipy Rotation conventions.
|
||||
"""
|
||||
if len(quats) == 0:
|
||||
return [0, 0, 0, 1]
|
||||
result = quats[0]
|
||||
return np.array([0, 0, 0, 1])
|
||||
result = np.array(quats[0])
|
||||
for next_quat in quats[1:]:
|
||||
x1, y1, z1, w1 = result
|
||||
x2, y2, z2, w2 = next_quat
|
||||
result = [
|
||||
result[:] = [
|
||||
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
|
||||
w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2,
|
||||
w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2,
|
||||
@ -67,67 +72,68 @@ def quaternion_mult(*quats: Sequence[float]) -> list[float]:
|
||||
|
||||
def quaternion_from_angle_axis(
|
||||
angle: float,
|
||||
axis: np.ndarray,
|
||||
) -> list[float]:
|
||||
axis: Vect3,
|
||||
) -> Vect4:
|
||||
return Rotation.from_rotvec(angle * normalize(axis)).as_quat()
|
||||
|
||||
|
||||
def angle_axis_from_quaternion(quat: Sequence[float]) -> tuple[float, np.ndarray]:
|
||||
def angle_axis_from_quaternion(quat: Vect4) -> Tuple[float, Vect3]:
|
||||
rot_vec = Rotation.from_quat(quat).as_rotvec()
|
||||
norm = get_norm(rot_vec)
|
||||
return norm, rot_vec / norm
|
||||
|
||||
|
||||
def quaternion_conjugate(quaternion: Iterable) -> list:
|
||||
result = list(quaternion)
|
||||
for i in range(3):
|
||||
result[i] *= -1
|
||||
def quaternion_conjugate(quaternion: Vect4) -> Vect4:
|
||||
result = np.array(quaternion)
|
||||
result[:3] *= -1
|
||||
return result
|
||||
|
||||
|
||||
def rotate_vector(
|
||||
vector: Iterable,
|
||||
vector: Vect3,
|
||||
angle: float,
|
||||
axis: np.ndarray = OUT
|
||||
) -> np.ndarray | list[float]:
|
||||
axis: Vect3 = OUT
|
||||
) -> Vect3:
|
||||
rot = Rotation.from_rotvec(angle * normalize(axis))
|
||||
return np.dot(vector, rot.as_matrix().T)
|
||||
|
||||
|
||||
def rotate_vector_2d(vector: Iterable, angle: float):
|
||||
def rotate_vector_2d(vector: Vect2, angle: float) -> Vect2:
|
||||
# Use complex numbers...because why not
|
||||
z = complex(*vector) * np.exp(complex(0, angle))
|
||||
return np.array([z.real, z.imag])
|
||||
|
||||
|
||||
def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> np.ndarray:
|
||||
def rotation_matrix_transpose_from_quaternion(quat: Vect4) -> Matrix3x3:
|
||||
return Rotation.from_quat(quat).as_matrix()
|
||||
|
||||
|
||||
def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray:
|
||||
def rotation_matrix_from_quaternion(quat: Vect4) -> Matrix3x3:
|
||||
return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
|
||||
|
||||
|
||||
def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray:
|
||||
def rotation_matrix(angle: float, axis: Vect3) -> Matrix3x3:
|
||||
"""
|
||||
Rotation in R^3 about a specified axis of rotation.
|
||||
"""
|
||||
return Rotation.from_rotvec(angle * normalize(axis)).as_matrix()
|
||||
|
||||
|
||||
def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> np.ndarray:
|
||||
def rotation_matrix_transpose(angle: float, axis: Vect3) -> Matrix3x3:
|
||||
return rotation_matrix(angle, axis).T
|
||||
|
||||
|
||||
def rotation_about_z(angle: float) -> list[list[float]]:
|
||||
return [
|
||||
[math.cos(angle), -math.sin(angle), 0],
|
||||
[math.sin(angle), math.cos(angle), 0],
|
||||
def rotation_about_z(angle: float) -> Matrix3x3:
|
||||
cos_a = math.cos(angle)
|
||||
sin_a = math.sin(angle)
|
||||
return np.array([
|
||||
[cos_a, -sin_a, 0],
|
||||
[sin_a, cos_a, 0],
|
||||
[0, 0, 1]
|
||||
]
|
||||
])
|
||||
|
||||
|
||||
def rotation_between_vectors(v1, v2) -> np.ndarray:
|
||||
def rotation_between_vectors(v1: Vect3, v2: Vect3) -> Matrix3x3:
|
||||
if np.all(np.isclose(v1, v2)):
|
||||
return np.identity(3)
|
||||
return rotation_matrix(
|
||||
@ -136,18 +142,18 @@ def rotation_between_vectors(v1, v2) -> np.ndarray:
|
||||
)
|
||||
|
||||
|
||||
def z_to_vector(vector: np.ndarray) -> np.ndarray:
|
||||
def z_to_vector(vector: Vect3) -> Matrix3x3:
|
||||
return rotation_between_vectors(OUT, vector)
|
||||
|
||||
|
||||
def angle_of_vector(vector: Sequence[float]) -> float:
|
||||
def angle_of_vector(vector: Vect2 | Vect3) -> float:
|
||||
"""
|
||||
Returns polar coordinate theta when vector is project on xy plane
|
||||
"""
|
||||
return np.angle(complex(*vector[:2]))
|
||||
|
||||
|
||||
def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
|
||||
def angle_between_vectors(v1: VectN, v2: VectN) -> float:
|
||||
"""
|
||||
Returns the angle between two 3D vectors.
|
||||
This angle will always be btw 0 and pi
|
||||
@ -160,7 +166,7 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
|
||||
return math.acos(clip(cos_angle, -1, 1))
|
||||
|
||||
|
||||
def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray:
|
||||
def project_along_vector(point: Vect3, vector: Vect3) -> Vect3:
|
||||
matrix = np.identity(3) - np.outer(vector, vector)
|
||||
return np.dot(point, matrix.T)
|
||||
|
||||
@ -177,10 +183,10 @@ def normalize_along_axis(
|
||||
|
||||
|
||||
def get_unit_normal(
|
||||
v1: np.ndarray,
|
||||
v2: np.ndarray,
|
||||
v1: Vect3,
|
||||
v2: Vect3,
|
||||
tol: float = 1e-6
|
||||
) -> np.ndarray:
|
||||
) -> Vect3:
|
||||
v1 = normalize(v1)
|
||||
v2 = normalize(v2)
|
||||
cp = cross(v1, v2)
|
||||
@ -204,7 +210,7 @@ def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
|
||||
return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
|
||||
|
||||
|
||||
def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray:
|
||||
def compass_directions(n: int = 4, start_vect: Vect3 = RIGHT) -> Vect3:
|
||||
angle = TAU / n
|
||||
return np.array([
|
||||
rotate_vector(start_vect, k * angle)
|
||||
@ -212,36 +218,32 @@ def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray
|
||||
])
|
||||
|
||||
|
||||
def complex_to_R3(complex_num: complex) -> np.ndarray:
|
||||
def complex_to_R3(complex_num: complex) -> Vect3:
|
||||
return np.array((complex_num.real, complex_num.imag, 0))
|
||||
|
||||
|
||||
def R3_to_complex(point: Sequence[float]) -> complex:
|
||||
def R3_to_complex(point: Vect3) -> complex:
|
||||
return complex(*point[:2])
|
||||
|
||||
|
||||
def complex_func_to_R3_func(
|
||||
complex_func: Callable[[complex], complex]
|
||||
) -> Callable[[np.ndarray], np.ndarray]:
|
||||
return lambda p: complex_to_R3(complex_func(R3_to_complex(p)))
|
||||
def complex_func_to_R3_func(complex_func: Callable[[complex], complex]) -> Callable[[Vect3], Vect3]:
|
||||
def result(p: Vect3):
|
||||
return complex_to_R3(complex_func(R3_to_complex(p)))
|
||||
return result
|
||||
|
||||
|
||||
def center_of_mass(points: Iterable[npt.ArrayLike]) -> np.ndarray:
|
||||
points = [np.array(point).astype("float") for point in points]
|
||||
return sum(points) / len(points)
|
||||
def center_of_mass(points: Sequence[Vect3]) -> Vect3:
|
||||
return np.array(points).sum(0) / len(points)
|
||||
|
||||
|
||||
def midpoint(
|
||||
point1: Sequence[float],
|
||||
point2: Sequence[float]
|
||||
) -> np.ndarray:
|
||||
def midpoint(point1: VectN, point2: VectN) -> VectN:
|
||||
return center_of_mass([point1, point2])
|
||||
|
||||
|
||||
def line_intersection(
|
||||
line1: Sequence[Sequence[float]],
|
||||
line2: Sequence[Sequence[float]]
|
||||
) -> np.ndarray:
|
||||
line1: Tuple[Vect3, Vect3],
|
||||
line2: Tuple[Vect3, Vect3]
|
||||
) -> Vect3:
|
||||
"""
|
||||
return intersection point of two lines,
|
||||
each defined with a pair of vectors determining
|
||||
@ -263,12 +265,12 @@ def line_intersection(
|
||||
|
||||
|
||||
def find_intersection(
|
||||
p0: npt.ArrayLike,
|
||||
v0: npt.ArrayLike,
|
||||
p1: npt.ArrayLike,
|
||||
v1: npt.ArrayLike,
|
||||
p0: Vect3,
|
||||
v0: Vect3,
|
||||
p1: Vect3,
|
||||
v1: Vect3,
|
||||
threshold: float = 1e-5
|
||||
) -> np.ndarray:
|
||||
) -> Vect3:
|
||||
"""
|
||||
Return the intersection of a line passing through p0 in direction v0
|
||||
with one passing through p1 in direction v1. (Or array of intersections
|
||||
@ -300,11 +302,7 @@ def find_intersection(
|
||||
return result
|
||||
|
||||
|
||||
def get_closest_point_on_line(
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
p: np.ndarray
|
||||
) -> np.ndarray:
|
||||
def get_closest_point_on_line(a: VectN, b: VectN, p: VectN) -> VectN:
|
||||
"""
|
||||
It returns point x such that
|
||||
x is on line ab and xp is perpendicular to ab.
|
||||
@ -319,7 +317,7 @@ def get_closest_point_on_line(
|
||||
return ((t * a) + ((1 - t) * b))
|
||||
|
||||
|
||||
def get_winding_number(points: Iterable[float]) -> float:
|
||||
def get_winding_number(points: Sequence[Vect2 | Vect3]) -> float:
|
||||
total_angle = 0
|
||||
for p1, p2 in adjacent_pairs(points):
|
||||
d_angle = angle_of_vector(p2) - angle_of_vector(p1)
|
||||
@ -330,7 +328,7 @@ def get_winding_number(points: Iterable[float]) -> float:
|
||||
|
||||
##
|
||||
|
||||
def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
def cross2d(a: Vect2, b: Vect2) -> Vect2:
|
||||
if len(a.shape) == 2:
|
||||
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
|
||||
else:
|
||||
@ -338,9 +336,9 @@ def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def tri_area(
|
||||
a: Sequence[float],
|
||||
b: Sequence[float],
|
||||
c: Sequence[float]
|
||||
a: Vect2,
|
||||
b: Vect2,
|
||||
c: Vect2
|
||||
) -> float:
|
||||
return 0.5 * abs(
|
||||
a[0] * (b[1] - c[1]) +
|
||||
@ -350,10 +348,10 @@ def tri_area(
|
||||
|
||||
|
||||
def is_inside_triangle(
|
||||
p: np.ndarray,
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
c: np.ndarray
|
||||
p: Vect2,
|
||||
a: Vect2,
|
||||
b: Vect2,
|
||||
c: Vect2
|
||||
) -> bool:
|
||||
"""
|
||||
Test if point p is inside triangle abc
|
||||
@ -363,15 +361,15 @@ def is_inside_triangle(
|
||||
cross2d(p - b, c - p),
|
||||
cross2d(p - c, a - p),
|
||||
])
|
||||
return np.all(crosses > 0) or np.all(crosses < 0)
|
||||
return bool(np.all(crosses > 0) or np.all(crosses < 0))
|
||||
|
||||
|
||||
def norm_squared(v: Sequence[float]) -> float:
|
||||
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
|
||||
def norm_squared(v: VectN | List[float]) -> float:
|
||||
return sum(x * x for x in v)
|
||||
|
||||
|
||||
# TODO, fails for polygons drawn over themselves
|
||||
def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list:
|
||||
def earclip_triangulation(verts: Vect2 | Vect3, ring_ends: list[int]) -> list[int]:
|
||||
"""
|
||||
Returns a list of indices giving a triangulation
|
||||
of a polygon, potentially with holes
|
||||
|
Reference in New Issue
Block a user