mirror of
https://github.com/3b1b/manim.git
synced 2025-08-02 19:46:21 +08:00
Small refactors on StringMobject and relevant classes
This commit is contained in:
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from colour import Color
|
||||
from typing import Iterable, Sequence, TypeVar, Union
|
||||
from typing import Callable, Iterable, TypeVar, Union
|
||||
|
||||
ManimColor = Union[str, Color]
|
||||
Span = tuple[int, int]
|
||||
@ -66,6 +66,11 @@ class StringMobject(SVGMobject, ABC):
|
||||
"isolate": (),
|
||||
}
|
||||
|
||||
CMD_PATTERN: str | None = None
|
||||
FLAG_DICT: dict[str, int] = {}
|
||||
CONTENT_REPL: dict[str, str | Callable[[re.Match], str]] = {}
|
||||
MATCH_REPL: dict[str, str | Callable[[re.Match], str]] = {}
|
||||
|
||||
def __init__(self, string: str, **kwargs):
|
||||
self.string = string
|
||||
digest_config(self, kwargs)
|
||||
@ -153,21 +158,18 @@ class StringMobject(SVGMobject, ABC):
|
||||
|
||||
# Toolkits
|
||||
|
||||
def get_substr(self, span: Span) -> str:
|
||||
return self.string[slice(*span)]
|
||||
|
||||
def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
|
||||
return [
|
||||
match_obj.span()
|
||||
for match_obj in re.finditer(pattern, self.string)
|
||||
]
|
||||
|
||||
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
|
||||
def find_spans_by_single_selector(sel):
|
||||
if isinstance(sel, str):
|
||||
return self.find_spans(re.escape(sel))
|
||||
return [
|
||||
match_obj.span()
|
||||
for match_obj in re.finditer(re.escape(sel), self.string)
|
||||
]
|
||||
if isinstance(sel, re.Pattern):
|
||||
return self.find_spans(sel)
|
||||
return [
|
||||
match_obj.span()
|
||||
for match_obj in sel.finditer(self.string)
|
||||
]
|
||||
if isinstance(sel, tuple) and len(sel) == 2 and all(
|
||||
isinstance(index, int) or index is None
|
||||
for index in sel
|
||||
@ -191,24 +193,59 @@ class StringMobject(SVGMobject, ABC):
|
||||
result.extend(spans)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]:
|
||||
return list(zip(vals[:-1], vals[1:]))
|
||||
def get_substr(self, span: Span) -> str:
|
||||
return self.string[slice(*span)]
|
||||
|
||||
@staticmethod
|
||||
def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]:
|
||||
def get_substr_matched_obj(
|
||||
substr: str, match_dict: dict[str, T]
|
||||
) -> tuple[re.Match, T] | None:
|
||||
for pattern, val in match_dict.items():
|
||||
match_obj = re.fullmatch(pattern, substr, re.S)
|
||||
if match_obj is None:
|
||||
continue
|
||||
return match_obj, val
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_substr_matched_val(
|
||||
substr: str, match_dict: dict[str, T], default: T
|
||||
) -> T:
|
||||
obj = StringMobject.get_substr_matched_obj(substr, match_dict)
|
||||
if obj is None:
|
||||
return default
|
||||
_, val = obj
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def get_substr_matched_str(
|
||||
substr: str, match_dict: dict[str, str | Callable[[re.Match], str]]
|
||||
) -> str:
|
||||
obj = StringMobject.get_substr_matched_obj(substr, match_dict)
|
||||
if obj is None:
|
||||
return substr
|
||||
match_obj, val = obj
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
return val(match_obj)
|
||||
|
||||
@staticmethod
|
||||
def get_neighbouring_pairs(vals: Iterable[T]) -> list[tuple[T, T]]:
|
||||
val_list = list(vals)
|
||||
return list(zip(val_list[:-1], val_list[1:]))
|
||||
|
||||
@staticmethod
|
||||
def group_neighbours(vals: Iterable[T]) -> list[tuple[T, Span]]:
|
||||
if not vals:
|
||||
return []
|
||||
|
||||
unique_vals = [vals[0]]
|
||||
indices = [0]
|
||||
for index, val in enumerate(vals):
|
||||
if val == unique_vals[-1]:
|
||||
continue
|
||||
unique_vals.append(val)
|
||||
indices.append(index)
|
||||
indices.append(len(vals))
|
||||
val_ranges = StringMobject.get_neighbouring_pairs(indices)
|
||||
unique_vals, range_lens = zip(*(
|
||||
(val, len(list(grouper)))
|
||||
for val, grouper in it.groupby(vals)
|
||||
))
|
||||
val_ranges = StringMobject.get_neighbouring_pairs(
|
||||
[0, *it.accumulate(range_lens)]
|
||||
)
|
||||
return list(zip(unique_vals, val_ranges))
|
||||
|
||||
@staticmethod
|
||||
@ -228,18 +265,6 @@ class StringMobject(SVGMobject, ABC):
|
||||
(*span_ends, universal_span[1])
|
||||
))
|
||||
|
||||
def replace_substr(self, span: Span, repl_items: list[Span, str]):
|
||||
if not repl_items:
|
||||
return self.get_substr(span)
|
||||
|
||||
repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
|
||||
pieces = [
|
||||
self.get_substr(piece_span)
|
||||
for piece_span in self.get_complement_spans(span, repl_spans)
|
||||
]
|
||||
repl_strs = [*repl_strs, ""]
|
||||
return "".join(it.chain(*zip(pieces, repl_strs)))
|
||||
|
||||
@staticmethod
|
||||
def color_to_hex(color: ManimColor) -> str:
|
||||
return rgb_to_hex(color_to_rgb(color))
|
||||
@ -255,12 +280,26 @@ class StringMobject(SVGMobject, ABC):
|
||||
# Parsing
|
||||
|
||||
def parse(self) -> None:
|
||||
cmd_spans = self.get_cmd_spans()
|
||||
pattern = self.CMD_PATTERN
|
||||
cmd_spans = [] if pattern is None else [
|
||||
match_obj.span()
|
||||
for match_obj in re.finditer(pattern, self.string, re.S)
|
||||
]
|
||||
cmd_substrs = [self.get_substr(span) for span in cmd_spans]
|
||||
flags = [self.get_substr_flag(substr) for substr in cmd_substrs]
|
||||
specified_items = self.get_specified_items(
|
||||
self.get_cmd_span_pairs(cmd_spans, flags)
|
||||
)
|
||||
flags = [
|
||||
self.get_substr_matched_val(substr, self.FLAG_DICT, 0)
|
||||
for substr in cmd_substrs
|
||||
]
|
||||
specified_items = [
|
||||
*self.get_internal_specified_items(
|
||||
self.get_cmd_span_pairs(cmd_spans, flags)
|
||||
),
|
||||
*self.get_external_specified_items(),
|
||||
*[
|
||||
(span, {})
|
||||
for span in self.find_spans_by_selector(self.isolate)
|
||||
]
|
||||
]
|
||||
split_items = [
|
||||
(span, attr_dict)
|
||||
for specified_span, attr_dict in specified_items
|
||||
@ -273,31 +312,15 @@ class StringMobject(SVGMobject, ABC):
|
||||
self.split_items = split_items
|
||||
self.labelled_spans = [span for span, _ in split_items]
|
||||
self.cmd_repl_items_for_content = [
|
||||
(span, self.get_repl_substr_for_content(substr))
|
||||
(span, self.get_substr_matched_str(substr, self.CONTENT_REPL))
|
||||
for span, substr in zip(cmd_spans, cmd_substrs)
|
||||
]
|
||||
self.cmd_repl_items_for_matching = [
|
||||
(span, self.get_repl_substr_for_matching(substr))
|
||||
(span, self.get_substr_matched_str(substr, self.MATCH_REPL))
|
||||
for span, substr in zip(cmd_spans, cmd_substrs)
|
||||
]
|
||||
self.check_overlapping()
|
||||
|
||||
@abstractmethod
|
||||
def get_cmd_spans(self) -> list[Span]:
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_substr_flag(self, substr: str) -> int:
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def get_repl_substr_for_content(self, substr: str) -> str:
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def get_repl_substr_for_matching(self, substr: str) -> str:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_cmd_span_pairs(
|
||||
cmd_spans: list[Span], flags: list[int]
|
||||
@ -317,11 +340,17 @@ class StringMobject(SVGMobject, ABC):
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def get_specified_items(
|
||||
def get_internal_specified_items(
|
||||
self, cmd_span_pairs: list[tuple[Span, Span]]
|
||||
) -> list[tuple[Span, dict[str, str]]]:
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_external_specified_items(
|
||||
self
|
||||
) -> list[tuple[Span, dict[str, str]]]:
|
||||
return []
|
||||
|
||||
def split_span_by_levels(
|
||||
self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int]
|
||||
) -> list[Span]:
|
||||
@ -387,6 +416,18 @@ class StringMobject(SVGMobject, ABC):
|
||||
) -> tuple[str, str]:
|
||||
return "", ""
|
||||
|
||||
def replace_substr(self, span: Span, repl_items: list[Span, str]):
|
||||
if not repl_items:
|
||||
return self.get_substr(span)
|
||||
|
||||
repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
|
||||
pieces = [
|
||||
self.get_substr(piece_span)
|
||||
for piece_span in self.get_complement_spans(span, repl_spans)
|
||||
]
|
||||
repl_strs = [*repl_strs, ""]
|
||||
return "".join(it.chain(*zip(pieces, repl_strs)))
|
||||
|
||||
def get_content(self, is_labelled: bool) -> str:
|
||||
inserted_str_pairs = [
|
||||
(span, self.get_cmd_str_pair(
|
||||
@ -446,7 +487,7 @@ class StringMobject(SVGMobject, ABC):
|
||||
return []
|
||||
|
||||
group_labels, labelled_submob_ranges = zip(
|
||||
*self.compress_neighbours(self.labels)
|
||||
*self.group_neighbours(self.labels)
|
||||
)
|
||||
ordered_spans = [
|
||||
self.labelled_spans[label] if label != -1 else self.full_span
|
||||
|
Reference in New Issue
Block a user