diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 26000972..51b124cf 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -15,9 +15,7 @@ if TYPE_CHECKING: from typing import Callable, Sequence, TypeVar from manimlib.typing import VectN, FloatArray - T = TypeVar("T") - - Scalable = TypeVar("Scalable", float, VectN) + Scalable = TypeVar("Scalable", float, FloatArray) CLOSED_THRESHOLD = 0.001 diff --git a/manimlib/utils/simple_functions.py b/manimlib/utils/simple_functions.py index 143bf350..1236b44a 100644 --- a/manimlib/utils/simple_functions.py +++ b/manimlib/utils/simple_functions.py @@ -5,26 +5,34 @@ import math import numpy as np +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Callable, TypeVar + from manimlib.typing import FloatArray -def sigmoid(x): + Scalable = TypeVar("Scalable", float, FloatArray) + + + +def sigmoid(x: float | FloatArray): return 1.0 / (1 + np.exp(-x)) @lru_cache(maxsize=10) -def choose(n, k): +def choose(n: int, k: int) -> int: return math.comb(n, k) -def gen_choose(n, r): - return np.prod(np.arange(n, n - r, -1)) / math.factorial(r) +def gen_choose(n: int, r: int) -> int: + return int(np.prod(range(n, n - r, -1)) / math.factorial(r)) -def get_num_args(function): +def get_num_args(function: Callable) -> int: return len(get_parameters(function)) -def get_parameters(function): - return inspect.signature(function).parameters +def get_parameters(function: Callable) -> list: + return list(inspect.signature(function).parameters.keys()) # Just to have a less heavyweight name for this extremely common operation # @@ -33,7 +41,7 @@ def get_parameters(function): # but for now, we just allow the option to handle indeterminate 0/0. -def clip(a, min_a, max_a): +def clip(a: float, min_a: float, max_a: float) -> float: if a < min_a: return min_a elif a > max_a: @@ -41,7 +49,7 @@ def clip(a, min_a, max_a): return a -def fdiv(a, b, zero_over_zero_value=None): +def fdiv(a: Scalable, b: Scalable, zero_over_zero_value: Scalable | None = None) -> Scalable: if zero_over_zero_value is not None: out = np.full_like(a, zero_over_zero_value) where = np.logical_or(a != 0, b != 0) @@ -52,15 +60,15 @@ def fdiv(a, b, zero_over_zero_value=None): return np.true_divide(a, b, out=out, where=where) -def binary_search(function, - target, - lower_bound, - upper_bound, - tolerance=1e-4): +def binary_search(function: Callable[[float], float], + target: float, + lower_bound: float, + upper_bound: float, + tolerance:float = 1e-4) -> float | None: lh = lower_bound rh = upper_bound + mh = (lh + rh) / 2 while abs(rh - lh) > tolerance: - mh = np.mean([lh, rh]) lx, mx, rx = [function(h) for h in (lh, mh, rh)] if lx == target: return lx @@ -76,10 +84,11 @@ def binary_search(function, lh, rh = rh, lh else: return None + mh = (lh + rh) / 2 return mh -def hash_string(string): +def hash_string(string: str) -> str: # Truncating at 16 bytes for cleanliness hasher = hashlib.sha256(string.encode()) return hasher.hexdigest()[:16]