Refactor LabelledString

This commit is contained in:
YishiMichael
2022-04-15 22:54:06 +08:00
parent 09952756ce
commit 0a810bb4f1
4 changed files with 128 additions and 136 deletions

View File

@ -22,6 +22,11 @@ if TYPE_CHECKING:
ManimColor = Union[str, Color]
Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]
class LabelledString(SVGMobject, ABC):
@ -52,10 +57,10 @@ class LabelledString(SVGMobject, ABC):
self.post_parse()
def get_file_path(self) -> str:
return self.get_file_path_(use_plain_file=True)
return self.get_file_path_(is_labelled=False)
def get_file_path_(self, use_plain_file: bool) -> str:
content = self.get_content(use_plain_file)
def get_file_path_(self, is_labelled: bool) -> str:
content = self.get_content(is_labelled)
return self.get_file_path_by_content(content)
@abstractmethod
@ -67,7 +72,7 @@ class LabelledString(SVGMobject, ABC):
num_labels = len(self.label_span_list)
if num_labels:
file_path = self.get_file_path_(use_plain_file=False)
file_path = self.get_file_path_(is_labelled=True)
labelled_svg = SVGMobject(file_path)
submob_color_ints = [
self.color_to_int(submob.get_fill_color())
@ -132,37 +137,31 @@ class LabelledString(SVGMobject, ABC):
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def finditer(
self, pattern: str, flags: int = 0, **kwargs
) -> Iterable[re.Match]:
return re.compile(pattern, flags).finditer(self.string, **kwargs)
def search(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).search(self.string, **kwargs)
def match(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str, **kwargs) -> list[Span]:
def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]:
if isinstance(pattern, str):
pattern = re.compile(pattern)
return [
match_obj.span()
for match_obj in self.finditer(pattern, **kwargs)
for match_obj in pattern.finditer(self.string, **kwargs)
]
def find_substr(self, substr: str, **kwargs) -> list[Span]:
if not substr:
return []
return self.find_spans(re.escape(substr), **kwargs)
def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
return list(it.chain(*[
self.find_substr(substr, **kwargs)
for substr in remove_list_redundancies(substrs)
]))
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
if isinstance(selector, str):
result = self.find_spans(re.escape(selector))
elif isinstance(selector, re.Pattern):
result = self.find_spans(selector)
else:
span = tuple([
(
min(index, self.string_len)
if index >= 0
else max(index + self.string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(selector, self.full_span)
])
result = [span]
return list(filter(lambda span: span[0] < span[1], result))
@staticmethod
def get_neighbouring_pairs(iterable: list) -> list[tuple]:
@ -345,7 +344,10 @@ class LabelledString(SVGMobject, ABC):
spans = list(it.chain(
self.internal_specified_spans,
self.external_specified_spans,
self.find_substrs(self.isolate)
*[
self.find_spans_by_selector(selector)
for selector in self.isolate
]
))
filtered_spans = list(filter(
lambda span: all([
@ -376,7 +378,7 @@ class LabelledString(SVGMobject, ABC):
)
@abstractmethod
def get_content(self, use_plain_file: bool) -> str:
def get_content(self, is_labelled: bool) -> str:
return ""
# Post-parsing
@ -441,7 +443,7 @@ class LabelledString(SVGMobject, ABC):
def get_submob_groups(self) -> list[VGroup]:
return [submob_group for _, submob_group in self.group_items]
def get_parts_by_group_substr(self, substr: str) -> VGroup:
def select_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[
group
for group_substr, group in self.group_items
@ -488,7 +490,7 @@ class LabelledString(SVGMobject, ABC):
span_begin = next_begin
return result
def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
def select_part_by_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [
label for label, span in enumerate(self.label_span_list)
if any([
@ -503,34 +505,28 @@ class LabelledString(SVGMobject, ABC):
if label in labels
])
def get_parts_by_string(
self, substr: str,
case_sensitive: bool = True, regex: bool = False, **kwargs
) -> VGroup:
flags = 0
if not case_sensitive:
flags |= re.I
pattern = substr if regex else re.escape(substr)
def select_parts(self, selector: Selector, **kwargs) -> VGroup:
return VGroup(*[
self.get_part_by_custom_span(span, **kwargs)
for span in self.find_spans(pattern, flags=flags)
if span[0] < span[1]
self.select_part_by_span(span, **kwargs)
for span in self.find_spans_by_selector(selector)
])
def get_part_by_string(
self, substr: str, index: int = 0, **kwargs
def select_part(
self, selector: Selector, index: int = 0, **kwargs
) -> VMobject:
return self.get_parts_by_string(substr, **kwargs)[index]
return self.select_parts(selector, **kwargs)[index]
def set_color_by_string(self, substr: str, color: ManimColor, **kwargs):
self.get_parts_by_string(substr, **kwargs).set_color(color)
def set_parts_color(
self, selector: Selector, color: ManimColor, **kwargs
):
self.select_parts(selector, **kwargs).set_color(color)
return self
def set_color_by_string_to_color_map(
self, string_to_color_map: dict[str, ManimColor], **kwargs
def set_parts_color_by_dict(
self, color_map: dict[Selector, ManimColor], **kwargs
):
for substr, color in string_to_color_map.items():
self.set_color_by_string(substr, color, **kwargs)
for selector, color in color_map.items():
self.set_parts_color(selector, color, **kwargs)
return self
def get_string(self) -> str: