Small refactors on StringMobject and relevant classes

This commit is contained in:
YishiMichael
2022-05-28 21:43:37 +08:00
parent 59eba943e5
commit f0447d7739
9 changed files with 669 additions and 737 deletions

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from colour import Color
from typing import Iterable, Sequence, TypeVar, Union
from typing import Callable, Iterable, TypeVar, Union
ManimColor = Union[str, Color]
Span = tuple[int, int]
@ -66,6 +66,11 @@ class StringMobject(SVGMobject, ABC):
"isolate": (),
}
CMD_PATTERN: str | None = None
FLAG_DICT: dict[str, int] = {}
CONTENT_REPL: dict[str, str | Callable[[re.Match], str]] = {}
MATCH_REPL: dict[str, str | Callable[[re.Match], str]] = {}
def __init__(self, string: str, **kwargs):
self.string = string
digest_config(self, kwargs)
@ -153,21 +158,18 @@ class StringMobject(SVGMobject, ABC):
# Toolkits
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
return [
match_obj.span()
for match_obj in re.finditer(pattern, self.string)
]
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
def find_spans_by_single_selector(sel):
if isinstance(sel, str):
return self.find_spans(re.escape(sel))
return [
match_obj.span()
for match_obj in re.finditer(re.escape(sel), self.string)
]
if isinstance(sel, re.Pattern):
return self.find_spans(sel)
return [
match_obj.span()
for match_obj in sel.finditer(self.string)
]
if isinstance(sel, tuple) and len(sel) == 2 and all(
isinstance(index, int) or index is None
for index in sel
@ -191,24 +193,59 @@ class StringMobject(SVGMobject, ABC):
result.extend(spans)
return result
@staticmethod
def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]:
return list(zip(vals[:-1], vals[1:]))
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
@staticmethod
def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]:
def get_substr_matched_obj(
substr: str, match_dict: dict[str, T]
) -> tuple[re.Match, T] | None:
for pattern, val in match_dict.items():
match_obj = re.fullmatch(pattern, substr, re.S)
if match_obj is None:
continue
return match_obj, val
return None
@staticmethod
def get_substr_matched_val(
substr: str, match_dict: dict[str, T], default: T
) -> T:
obj = StringMobject.get_substr_matched_obj(substr, match_dict)
if obj is None:
return default
_, val = obj
return val
@staticmethod
def get_substr_matched_str(
substr: str, match_dict: dict[str, str | Callable[[re.Match], str]]
) -> str:
obj = StringMobject.get_substr_matched_obj(substr, match_dict)
if obj is None:
return substr
match_obj, val = obj
if isinstance(val, str):
return val
return val(match_obj)
@staticmethod
def get_neighbouring_pairs(vals: Iterable[T]) -> list[tuple[T, T]]:
val_list = list(vals)
return list(zip(val_list[:-1], val_list[1:]))
@staticmethod
def group_neighbours(vals: Iterable[T]) -> list[tuple[T, Span]]:
if not vals:
return []
unique_vals = [vals[0]]
indices = [0]
for index, val in enumerate(vals):
if val == unique_vals[-1]:
continue
unique_vals.append(val)
indices.append(index)
indices.append(len(vals))
val_ranges = StringMobject.get_neighbouring_pairs(indices)
unique_vals, range_lens = zip(*(
(val, len(list(grouper)))
for val, grouper in it.groupby(vals)
))
val_ranges = StringMobject.get_neighbouring_pairs(
[0, *it.accumulate(range_lens)]
)
return list(zip(unique_vals, val_ranges))
@staticmethod
@ -228,18 +265,6 @@ class StringMobject(SVGMobject, ABC):
(*span_ends, universal_span[1])
))
def replace_substr(self, span: Span, repl_items: list[Span, str]):
if not repl_items:
return self.get_substr(span)
repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
pieces = [
self.get_substr(piece_span)
for piece_span in self.get_complement_spans(span, repl_spans)
]
repl_strs = [*repl_strs, ""]
return "".join(it.chain(*zip(pieces, repl_strs)))
@staticmethod
def color_to_hex(color: ManimColor) -> str:
return rgb_to_hex(color_to_rgb(color))
@ -255,12 +280,26 @@ class StringMobject(SVGMobject, ABC):
# Parsing
def parse(self) -> None:
cmd_spans = self.get_cmd_spans()
pattern = self.CMD_PATTERN
cmd_spans = [] if pattern is None else [
match_obj.span()
for match_obj in re.finditer(pattern, self.string, re.S)
]
cmd_substrs = [self.get_substr(span) for span in cmd_spans]
flags = [self.get_substr_flag(substr) for substr in cmd_substrs]
specified_items = self.get_specified_items(
self.get_cmd_span_pairs(cmd_spans, flags)
)
flags = [
self.get_substr_matched_val(substr, self.FLAG_DICT, 0)
for substr in cmd_substrs
]
specified_items = [
*self.get_internal_specified_items(
self.get_cmd_span_pairs(cmd_spans, flags)
),
*self.get_external_specified_items(),
*[
(span, {})
for span in self.find_spans_by_selector(self.isolate)
]
]
split_items = [
(span, attr_dict)
for specified_span, attr_dict in specified_items
@ -273,31 +312,15 @@ class StringMobject(SVGMobject, ABC):
self.split_items = split_items
self.labelled_spans = [span for span, _ in split_items]
self.cmd_repl_items_for_content = [
(span, self.get_repl_substr_for_content(substr))
(span, self.get_substr_matched_str(substr, self.CONTENT_REPL))
for span, substr in zip(cmd_spans, cmd_substrs)
]
self.cmd_repl_items_for_matching = [
(span, self.get_repl_substr_for_matching(substr))
(span, self.get_substr_matched_str(substr, self.MATCH_REPL))
for span, substr in zip(cmd_spans, cmd_substrs)
]
self.check_overlapping()
@abstractmethod
def get_cmd_spans(self) -> list[Span]:
return []
@abstractmethod
def get_substr_flag(self, substr: str) -> int:
return 0
@abstractmethod
def get_repl_substr_for_content(self, substr: str) -> str:
return ""
@abstractmethod
def get_repl_substr_for_matching(self, substr: str) -> str:
return ""
@staticmethod
def get_cmd_span_pairs(
cmd_spans: list[Span], flags: list[int]
@ -317,11 +340,17 @@ class StringMobject(SVGMobject, ABC):
return result
@abstractmethod
def get_specified_items(
def get_internal_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]:
return []
@abstractmethod
def get_external_specified_items(
self
) -> list[tuple[Span, dict[str, str]]]:
return []
def split_span_by_levels(
self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int]
) -> list[Span]:
@ -387,6 +416,18 @@ class StringMobject(SVGMobject, ABC):
) -> tuple[str, str]:
return "", ""
def replace_substr(self, span: Span, repl_items: list[Span, str]):
if not repl_items:
return self.get_substr(span)
repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
pieces = [
self.get_substr(piece_span)
for piece_span in self.get_complement_spans(span, repl_spans)
]
repl_strs = [*repl_strs, ""]
return "".join(it.chain(*zip(pieces, repl_strs)))
def get_content(self, is_labelled: bool) -> str:
inserted_str_pairs = [
(span, self.get_cmd_str_pair(
@ -446,7 +487,7 @@ class StringMobject(SVGMobject, ABC):
return []
group_labels, labelled_submob_ranges = zip(
*self.compress_neighbours(self.labels)
*self.group_neighbours(self.labels)
)
ordered_spans = [
self.labelled_spans[label] if label != -1 else self.full_span