Add better types + Small refactors on space_ops

This commit is contained in:
Grant Sanderson
2022-12-16 20:35:45 -08:00
parent dec11a4b17
commit cef6506920

View File

@ -18,29 +18,31 @@ from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence
import numpy.typing as npt
from typing import Callable, Sequence, List, Tuple
from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3
def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]:
return [
def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3:
return np.array([
v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2],
v1[0] * v2[1] - v1[1] * v2[0]
]
])
def get_norm(vect: Iterable) -> float:
def get_norm(vect: VectN | List[float]) -> float:
return sum((x**2 for x in vect))**0.5
def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarray:
def normalize(
vect: VectN | List[float],
fall_back: VectN | List[float] | None = None
) -> VectN:
norm = get_norm(vect)
if norm > 0:
return np.array(vect) / norm
elif fall_back is not None:
return fall_back
return np.array(fall_back)
else:
return np.zeros(len(vect))
@ -48,15 +50,18 @@ def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarr
# Operations related to rotation
def quaternion_mult(*quats: Sequence[float]) -> list[float]:
# Real part is last entry, which is bizzare, but fits scipy Rotation convention
def quaternion_mult(*quats: Vect4) -> Vect4:
"""
Inputs are treated as quaternions, where the real part is the
last entry, so as to follow the scipy Rotation conventions.
"""
if len(quats) == 0:
return [0, 0, 0, 1]
result = quats[0]
return np.array([0, 0, 0, 1])
result = np.array(quats[0])
for next_quat in quats[1:]:
x1, y1, z1, w1 = result
x2, y2, z2, w2 = next_quat
result = [
result[:] = [
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2,
w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2,
@ -67,67 +72,68 @@ def quaternion_mult(*quats: Sequence[float]) -> list[float]:
def quaternion_from_angle_axis(
angle: float,
axis: np.ndarray,
) -> list[float]:
axis: Vect3,
) -> Vect4:
return Rotation.from_rotvec(angle * normalize(axis)).as_quat()
def angle_axis_from_quaternion(quat: Sequence[float]) -> tuple[float, np.ndarray]:
def angle_axis_from_quaternion(quat: Vect4) -> Tuple[float, Vect3]:
rot_vec = Rotation.from_quat(quat).as_rotvec()
norm = get_norm(rot_vec)
return norm, rot_vec / norm
def quaternion_conjugate(quaternion: Iterable) -> list:
result = list(quaternion)
for i in range(3):
result[i] *= -1
def quaternion_conjugate(quaternion: Vect4) -> Vect4:
result = np.array(quaternion)
result[:3] *= -1
return result
def rotate_vector(
vector: Iterable,
vector: Vect3,
angle: float,
axis: np.ndarray = OUT
) -> np.ndarray | list[float]:
axis: Vect3 = OUT
) -> Vect3:
rot = Rotation.from_rotvec(angle * normalize(axis))
return np.dot(vector, rot.as_matrix().T)
def rotate_vector_2d(vector: Iterable, angle: float):
def rotate_vector_2d(vector: Vect2, angle: float) -> Vect2:
# Use complex numbers...because why not
z = complex(*vector) * np.exp(complex(0, angle))
return np.array([z.real, z.imag])
def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> np.ndarray:
def rotation_matrix_transpose_from_quaternion(quat: Vect4) -> Matrix3x3:
return Rotation.from_quat(quat).as_matrix()
def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray:
def rotation_matrix_from_quaternion(quat: Vect4) -> Matrix3x3:
return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray:
def rotation_matrix(angle: float, axis: Vect3) -> Matrix3x3:
"""
Rotation in R^3 about a specified axis of rotation.
"""
return Rotation.from_rotvec(angle * normalize(axis)).as_matrix()
def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> np.ndarray:
def rotation_matrix_transpose(angle: float, axis: Vect3) -> Matrix3x3:
return rotation_matrix(angle, axis).T
def rotation_about_z(angle: float) -> list[list[float]]:
return [
[math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0],
def rotation_about_z(angle: float) -> Matrix3x3:
cos_a = math.cos(angle)
sin_a = math.sin(angle)
return np.array([
[cos_a, -sin_a, 0],
[sin_a, cos_a, 0],
[0, 0, 1]
]
])
def rotation_between_vectors(v1, v2) -> np.ndarray:
def rotation_between_vectors(v1: Vect3, v2: Vect3) -> Matrix3x3:
if np.all(np.isclose(v1, v2)):
return np.identity(3)
return rotation_matrix(
@ -136,18 +142,18 @@ def rotation_between_vectors(v1, v2) -> np.ndarray:
)
def z_to_vector(vector: np.ndarray) -> np.ndarray:
def z_to_vector(vector: Vect3) -> Matrix3x3:
return rotation_between_vectors(OUT, vector)
def angle_of_vector(vector: Sequence[float]) -> float:
def angle_of_vector(vector: Vect2 | Vect3) -> float:
"""
Returns polar coordinate theta when vector is project on xy plane
"""
return np.angle(complex(*vector[:2]))
def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
def angle_between_vectors(v1: VectN, v2: VectN) -> float:
"""
Returns the angle between two 3D vectors.
This angle will always be btw 0 and pi
@ -160,7 +166,7 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
return math.acos(clip(cos_angle, -1, 1))
def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray:
def project_along_vector(point: Vect3, vector: Vect3) -> Vect3:
matrix = np.identity(3) - np.outer(vector, vector)
return np.dot(point, matrix.T)
@ -177,10 +183,10 @@ def normalize_along_axis(
def get_unit_normal(
v1: np.ndarray,
v2: np.ndarray,
v1: Vect3,
v2: Vect3,
tol: float = 1e-6
) -> np.ndarray:
) -> Vect3:
v1 = normalize(v1)
v2 = normalize(v2)
cp = cross(v1, v2)
@ -204,7 +210,7 @@ def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray:
def compass_directions(n: int = 4, start_vect: Vect3 = RIGHT) -> Vect3:
angle = TAU / n
return np.array([
rotate_vector(start_vect, k * angle)
@ -212,36 +218,32 @@ def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray
])
def complex_to_R3(complex_num: complex) -> np.ndarray:
def complex_to_R3(complex_num: complex) -> Vect3:
return np.array((complex_num.real, complex_num.imag, 0))
def R3_to_complex(point: Sequence[float]) -> complex:
def R3_to_complex(point: Vect3) -> complex:
return complex(*point[:2])
def complex_func_to_R3_func(
complex_func: Callable[[complex], complex]
) -> Callable[[np.ndarray], np.ndarray]:
return lambda p: complex_to_R3(complex_func(R3_to_complex(p)))
def complex_func_to_R3_func(complex_func: Callable[[complex], complex]) -> Callable[[Vect3], Vect3]:
def result(p: Vect3):
return complex_to_R3(complex_func(R3_to_complex(p)))
return result
def center_of_mass(points: Iterable[npt.ArrayLike]) -> np.ndarray:
points = [np.array(point).astype("float") for point in points]
return sum(points) / len(points)
def center_of_mass(points: Sequence[Vect3]) -> Vect3:
return np.array(points).sum(0) / len(points)
def midpoint(
point1: Sequence[float],
point2: Sequence[float]
) -> np.ndarray:
def midpoint(point1: VectN, point2: VectN) -> VectN:
return center_of_mass([point1, point2])
def line_intersection(
line1: Sequence[Sequence[float]],
line2: Sequence[Sequence[float]]
) -> np.ndarray:
line1: Tuple[Vect3, Vect3],
line2: Tuple[Vect3, Vect3]
) -> Vect3:
"""
return intersection point of two lines,
each defined with a pair of vectors determining
@ -263,12 +265,12 @@ def line_intersection(
def find_intersection(
p0: npt.ArrayLike,
v0: npt.ArrayLike,
p1: npt.ArrayLike,
v1: npt.ArrayLike,
p0: Vect3,
v0: Vect3,
p1: Vect3,
v1: Vect3,
threshold: float = 1e-5
) -> np.ndarray:
) -> Vect3:
"""
Return the intersection of a line passing through p0 in direction v0
with one passing through p1 in direction v1. (Or array of intersections
@ -300,11 +302,7 @@ def find_intersection(
return result
def get_closest_point_on_line(
a: np.ndarray,
b: np.ndarray,
p: np.ndarray
) -> np.ndarray:
def get_closest_point_on_line(a: VectN, b: VectN, p: VectN) -> VectN:
"""
It returns point x such that
x is on line ab and xp is perpendicular to ab.
@ -319,7 +317,7 @@ def get_closest_point_on_line(
return ((t * a) + ((1 - t) * b))
def get_winding_number(points: Iterable[float]) -> float:
def get_winding_number(points: Sequence[Vect2 | Vect3]) -> float:
total_angle = 0
for p1, p2 in adjacent_pairs(points):
d_angle = angle_of_vector(p2) - angle_of_vector(p1)
@ -330,7 +328,7 @@ def get_winding_number(points: Iterable[float]) -> float:
##
def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
def cross2d(a: Vect2, b: Vect2) -> Vect2:
if len(a.shape) == 2:
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
else:
@ -338,9 +336,9 @@ def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
def tri_area(
a: Sequence[float],
b: Sequence[float],
c: Sequence[float]
a: Vect2,
b: Vect2,
c: Vect2
) -> float:
return 0.5 * abs(
a[0] * (b[1] - c[1]) +
@ -350,10 +348,10 @@ def tri_area(
def is_inside_triangle(
p: np.ndarray,
a: np.ndarray,
b: np.ndarray,
c: np.ndarray
p: Vect2,
a: Vect2,
b: Vect2,
c: Vect2
) -> bool:
"""
Test if point p is inside triangle abc
@ -363,15 +361,15 @@ def is_inside_triangle(
cross2d(p - b, c - p),
cross2d(p - c, a - p),
])
return np.all(crosses > 0) or np.all(crosses < 0)
return bool(np.all(crosses > 0) or np.all(crosses < 0))
def norm_squared(v: Sequence[float]) -> float:
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
def norm_squared(v: VectN | List[float]) -> float:
return sum(x * x for x in v)
# TODO, fails for polygons drawn over themselves
def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list:
def earclip_triangulation(verts: Vect2 | Vect3, ring_ends: list[int]) -> list[int]:
"""
Returns a list of indices giving a triangulation
of a polygon, potentially with holes