Revert some files

This commit is contained in:
YishiMichael
2022-04-22 15:31:13 +08:00
parent 8852921b3d
commit f8c8a399c9
6 changed files with 799 additions and 683 deletions

View File

@ -1,41 +1,30 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import itertools as it
import re
import colour
import itertools as it
from typing import Iterable, Union, Sequence
from abc import ABC, abstractmethod
from manimlib.constants import WHITE
from manimlib.constants import BLACK, WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from colour import Color
from typing import Iterable, Union
ManimColor = Union[str, Color]
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
class LabelledString(SVGMobject, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
class _StringSVG(SVGMobject):
CONFIG = {
"height": None,
"stroke_width": 0,
@ -44,31 +33,42 @@ class LabelledString(SVGMobject, ABC):
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class LabelledString(_StringSVG, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
CONFIG = {
"base_color": WHITE,
"use_plain_file": False,
"isolate": [],
}
def __init__(self, string: str, **kwargs):
self.string = string
digest_config(self, kwargs)
if self.base_color is None:
self.base_color = WHITE
self.base_color_int = self.color_to_int(self.base_color)
self.string_len = len(self.string)
self.full_span = (0, self.string_len)
# Convert `base_color` to hex code.
self.base_color = rgb_to_hex(color_to_rgb(
self.base_color \
or self.svg_default.get("color", None) \
or self.svg_default.get("fill_color", None) \
or WHITE
))
self.svg_default["fill_color"] = BLACK
self.pre_parse()
self.parse()
super().__init__()
self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
self.post_parse()
def get_file_path(self) -> str:
return self.get_file_path_(is_labelled=False)
return self.get_file_path_(use_plain_file=False)
def get_file_path_(self, is_labelled: bool) -> str:
content = self.get_content(is_labelled)
def get_file_path_(self, use_plain_file: bool) -> str:
content = self.get_content(use_plain_file)
return self.get_file_path_by_content(content)
@abstractmethod
@ -78,113 +78,87 @@ class LabelledString(SVGMobject, ABC):
def generate_mobject(self) -> None:
super().generate_mobject()
num_labels = len(self.label_span_list)
if num_labels:
file_path = self.get_file_path_(is_labelled=True)
labelled_svg = SVGMobject(file_path)
submob_color_ints = [
self.color_to_int(submob.get_fill_color())
for submob in labelled_svg.submobjects
]
submob_labels = [
self.color_to_label(submob.get_fill_color())
for submob in self.submobjects
]
if self.use_plain_file or self.has_predefined_local_colors:
file_path = self.get_file_path_(use_plain_file=True)
plain_svg = _StringSVG(
file_path,
svg_default=self.svg_default,
path_string_config=self.path_string_config
)
self.set_submobjects(plain_svg.submobjects)
else:
submob_color_ints = [0] * len(self.submobjects)
self.set_fill(self.base_color)
for submob, label in zip(self.submobjects, submob_labels):
submob.label = label
if len(self.submobjects) != len(submob_color_ints):
raise ValueError(
"Cannot align submobjects of the labelled svg "
"to the original svg"
)
unrecognized_color_ints = remove_list_redundancies(sorted(filter(
lambda color_int: color_int > num_labels,
submob_color_ints
)))
if unrecognized_color_ints:
raise ValueError(
"Unrecognized color label(s) detected: "
f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}"
)
for submob, color_int in zip(self.submobjects, submob_color_ints):
submob.label = color_int - 1
def pre_parse(self) -> None:
self.string_len = len(self.string)
self.full_span = (0, self.string_len)
def parse(self) -> None:
self.skippable_indices = self.get_skippable_indices()
self.command_repl_items = self.get_command_repl_items()
self.command_spans = self.get_command_spans()
self.extra_entity_spans = self.get_extra_entity_spans()
self.entity_spans = self.get_entity_spans()
self.bracket_spans = self.get_bracket_spans()
self.extra_isolated_items = self.get_extra_isolated_items()
self.specified_items = self.get_specified_items()
self.extra_ignored_spans = self.get_extra_ignored_spans()
self.skipped_spans = self.get_skipped_spans()
self.internal_specified_spans = self.get_internal_specified_spans()
self.external_specified_spans = self.get_external_specified_spans()
self.specified_spans = self.get_specified_spans()
self.check_overlapping()
self.label_span_list = self.get_label_span_list()
if len(self.label_span_list) >= 16777216:
raise ValueError("Cannot handle that many substrings")
self.check_overlapping()
def copy(self):
return self.deepcopy()
def post_parse(self) -> None:
self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
self.labelled_submobjects = self.get_labelled_submobjects()
self.specified_substrs = self.get_specified_substrs()
self.group_items = self.get_group_items()
self.group_substrs = self.get_group_substrs()
self.submob_groups = self.get_submob_groups()
# Toolkits
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None:
if isinstance(pattern, str):
pattern = re.compile(pattern)
return re.compile(pattern).match(self.string, **kwargs)
def finditer(
self, pattern: str, flags: int = 0, **kwargs
) -> Iterable[re.Match]:
return re.compile(pattern, flags).finditer(self.string, **kwargs)
def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]:
if isinstance(pattern, str):
pattern = re.compile(pattern)
def search(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).search(self.string, **kwargs)
def match(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str, **kwargs) -> list[Span]:
return [
match_obj.span()
for match_obj in pattern.finditer(self.string, **kwargs)
for match_obj in self.finditer(pattern, **kwargs)
]
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
return [index for index, _ in self.find_spans(pattern, **kwargs)]
def find_substr(self, substr: str, **kwargs) -> list[Span]:
if not substr:
return []
return self.find_spans(re.escape(substr), **kwargs)
@staticmethod
def is_single_selector(selector: Selector) -> bool:
if isinstance(selector, str):
return True
if isinstance(selector, re.Pattern):
return True
if isinstance(selector, tuple):
if len(selector) == 2 and all([
isinstance(index, int) or index is None
for index in selector
]):
return True
return False
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
if self.is_single_selector(selector):
selector = (selector,)
result = []
for sel in selector:
if not self.is_single_selector(sel):
raise TypeError(f"Invalid selector: '{sel}'")
if isinstance(sel, str):
spans = self.find_spans(re.escape(sel))
elif isinstance(sel, re.Pattern):
spans = self.find_spans(sel)
else:
span = tuple([
(
min(index, self.string_len)
if index >= 0
else max(index + self.string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(sel, self.full_span)
])
spans = [span]
result.extend(spans)
return sorted(filter(
lambda span: span[0] < span[1],
remove_list_redundancies(result)
))
def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
return list(it.chain(*[
self.find_substr(substr, **kwargs)
for substr in remove_list_redundancies(substrs)
]))
@staticmethod
def get_neighbouring_pairs(iterable: list) -> list[tuple]:
@ -223,24 +197,41 @@ class LabelledString(SVGMobject, ABC):
spans = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, spans))
@staticmethod
def find_region_index(seq: list[int], val: int) -> int:
# Returns an integer in `range(-1, len(seq))` satisfying
# `seq[result] <= val < seq[result + 1]`.
# `seq` should be sorted in ascending order.
if not seq or val < seq[0]:
return -1
result = len(seq) - 1
while val < seq[result]:
result -= 1
return result
@staticmethod
def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int:
sorted_seq = sorted(seq)
index = LabelledString.find_region_index(sorted_seq, val)
return sorted_seq[index + index_shift]
@staticmethod
def generate_span_repl_dict(
inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
repl_items: list[tuple[Span, str]]
other_repl_items: list[tuple[Span, str]]
) -> dict[Span, str]:
result = dict(repl_items)
result = dict(other_repl_items)
if not inserted_string_pairs:
return result
indices, _, _, _, inserted_strings = zip(*sorted([
indices, _, _, inserted_strings = zip(*sorted([
(
item[0][flag],
span[flag],
-flag,
-item[0][1 - flag],
(1, -1)[flag] * item_index,
item[1][flag]
-span[1 - flag],
str_pair[flag]
)
for item_index, item in enumerate(inserted_string_pairs)
for span, str_pair in inserted_string_pairs
for flag in range(2)
]))
result.update({
@ -272,74 +263,113 @@ class LabelledString(SVGMobject, ABC):
return "".join(it.chain(*zip(pieces, repl_strs)))
@staticmethod
def color_to_int(color: ManimColor) -> int:
hex_code = rgb_to_hex(color_to_rgb(color))
return int(hex_code[1:], 16)
def rslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted(skipped))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def lslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted([
skipped_span[::-1] for skipped_span in skipped
], reverse=True))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int:
r, g, b = rgb_tuple
rg = r * 256 + g
return rg * 256 + b
@staticmethod
def int_to_rgb(rgb_int: int) -> tuple[int, int, int]:
rg, b = divmod(rgb_int, 256)
r, g = divmod(rg, 256)
return r, g, b
@staticmethod
def int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod
def color_to_label(color: ManimColor) -> int:
rgb_tuple = color_to_int_rgb(color)
rgb = LabelledString.rgb_to_int(rgb_tuple)
return rgb - 1
# Parsing
@abstractmethod
def get_skippable_indices(self) -> list[int]:
def get_command_repl_items(self) -> list[tuple[Span, str]]:
return []
@staticmethod
def shrink_span(span: Span, skippable_indices: list[int]) -> Span:
span_begin, span_end = span
while span_begin in skippable_indices:
span_begin += 1
while span_end - 1 in skippable_indices:
span_end -= 1
return (span_begin, span_end)
def get_command_spans(self) -> list[Span]:
return [cmd_span for cmd_span, _ in self.command_repl_items]
@abstractmethod
def get_extra_entity_spans(self) -> list[Span]:
return []
def get_entity_spans(self) -> list[Span]:
return []
@abstractmethod
def get_bracket_spans(self) -> list[Span]:
return []
@abstractmethod
def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]:
return []
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
span_items = list(it.chain(
self.extra_isolated_items,
[
(span, {})
for span in self.find_spans_by_selector(self.isolate)
]
return list(it.chain(
self.command_spans,
self.extra_entity_spans
))
result = []
for span, attr_dict in span_items:
shrinked_span = self.shrink_span(span, self.skippable_indices)
if shrinked_span[0] >= shrinked_span[1]:
continue
if any([
entity_span[0] < index < entity_span[1]
for index in shrinked_span
for entity_span in self.entity_spans
]):
continue
result.append((shrinked_span, attr_dict))
return result
@abstractmethod
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_skipped_spans(self) -> list[Span]:
return list(it.chain(
self.find_spans(r"\s"),
self.command_spans,
self.extra_ignored_spans
))
def shrink_span(self, span: Span) -> Span:
return (
self.rslide(span[0], self.skipped_spans),
self.lslide(span[1], self.skipped_spans)
)
@abstractmethod
def get_internal_specified_spans(self) -> list[Span]:
return []
@abstractmethod
def get_external_specified_spans(self) -> list[Span]:
return []
def get_specified_spans(self) -> list[Span]:
return remove_list_redundancies([
span for span, _ in self.specified_items
])
spans = list(it.chain(
self.internal_specified_spans,
self.external_specified_spans,
self.find_substrs(self.isolate)
))
shrinked_spans = list(filter(
lambda span: span[0] < span[1] and not any([
entity_span[0] < index < entity_span[1]
for index in span
for entity_span in self.entity_spans
]),
[self.shrink_span(span) for span in spans]
))
return remove_list_redundancies(shrinked_spans)
@abstractmethod
def get_label_span_list(self) -> list[Span]:
return []
def check_overlapping(self) -> None:
spans = remove_list_redundancies(list(it.chain(
self.specified_spans,
self.bracket_spans
)))
for span_0, span_1 in it.product(spans, repeat=2):
for span_0, span_1 in it.product(self.label_span_list, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
@ -348,20 +378,29 @@ class LabelledString(SVGMobject, ABC):
)
@abstractmethod
def get_label_span_list(self) -> list[Span]:
return []
@abstractmethod
def get_content(self, is_labelled: bool) -> str:
def get_content(self, use_plain_file: bool) -> str:
return ""
# Selector
@abstractmethod
def has_predefined_local_colors(self) -> bool:
return False
# Post-parsing
def get_labelled_submobjects(self) -> list[VMobject]:
return [submob for _, submob in self.labelled_submobject_items]
def get_cleaned_substr(self, span: Span) -> str:
return ""
span_repl_dict = dict.fromkeys(self.command_spans, "")
return self.get_replaced_substr(span, span_repl_dict)
def get_group_part_items(self) -> list[tuple[str, VGroup]]:
def get_specified_substrs(self) -> list[str]:
return remove_list_redundancies([
self.get_cleaned_substr(span)
for span in self.specified_spans
])
def get_group_items(self) -> list[tuple[str, VGroup]]:
if not self.labelled_submobject_items:
return []
@ -386,31 +425,41 @@ class LabelledString(SVGMobject, ABC):
ordered_spans
)
]
group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else ""
shrinked_spans = [
self.shrink_span(span)
for span in self.get_complement_spans(
interval_spans, (ordered_spans[0][0], ordered_spans[-1][1])
)
]
group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in shrinked_spans
]
submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)])
for submob_span in labelled_submob_spans
])
return list(zip(group_substrs, submob_groups))
def get_specified_part_items(self) -> list[tuple[str, VGroup]]:
return [
(
self.get_substr(span),
self.select_part_by_span(span, substring=False)
)
for span in self.specified_spans
]
def get_group_substrs(self) -> list[str]:
return [group_substr for group_substr, _ in self.group_items]
def get_submob_groups(self) -> list[VGroup]:
return [submob_group for _, submob_group in self.group_items]
def get_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[
group
for group_substr, group in self.group_items
if group_substr == substr
])
# Selector
def find_span_components(
self, custom_span: Span, substring: bool = True
) -> list[Span]:
shrinked_span = self.shrink_span(custom_span, self.skippable_indices)
shrinked_span = self.shrink_span(custom_span)
if shrinked_span[0] >= shrinked_span[1]:
return []
@ -419,12 +468,12 @@ class LabelledString(SVGMobject, ABC):
self.full_span,
*self.label_span_list
)))
span_begin = max(filter(
lambda index: index <= shrinked_span[0], indices
))
span_end = min(filter(
lambda index: index >= shrinked_span[1], indices
))
span_begin = self.take_nearest_value(
indices, shrinked_span[0], 0
)
span_end = self.take_nearest_value(
indices, shrinked_span[1] - 1, 1
)
else:
span_begin, span_end = shrinked_span
@ -445,7 +494,7 @@ class LabelledString(SVGMobject, ABC):
span_begin = next_begin
return result
def select_part_by_span(self, custom_span: Span, **kwargs) -> VGroup:
def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [
label for label, span in enumerate(self.label_span_list)
if any([
@ -460,31 +509,34 @@ class LabelledString(SVGMobject, ABC):
if label in labels
])
def select_parts(self, selector: Selector, **kwargs) -> VGroup:
return VGroup(*filter(
lambda part: part.submobjects,
[
self.select_part_by_span(span, **kwargs)
for span in self.find_spans_by_selector(selector)
]
))
def select_part(
self, selector: Selector, index: int = 0, **kwargs
def get_parts_by_string(
self, substr: str,
case_sensitive: bool = True, regex: bool = False, **kwargs
) -> VGroup:
return self.select_parts(selector, **kwargs)[index]
flags = 0
if not case_sensitive:
flags |= re.I
pattern = substr if regex else re.escape(substr)
return VGroup(*[
self.get_part_by_custom_span(span, **kwargs)
for span in self.find_spans(pattern, flags=flags)
if span[0] < span[1]
])
def set_parts_color(
self, selector: Selector, color: ManimColor, **kwargs
):
self.select_parts(selector, **kwargs).set_color(color)
def get_part_by_string(
self, substr: str, index: int = 0, **kwargs
) -> VMobject:
return self.get_parts_by_string(substr, **kwargs)[index]
def set_color_by_string(self, substr: str, color: ManimColor, **kwargs):
self.get_parts_by_string(substr, **kwargs).set_color(color)
return self
def set_parts_color_by_dict(
self, color_map: dict[Selector, ManimColor], **kwargs
def set_color_by_string_to_color_map(
self, string_to_color_map: dict[str, ManimColor], **kwargs
):
for selector, color in color_map.items():
self.set_parts_color(selector, color, **kwargs)
for substr, color in string_to_color_map.items():
self.set_color_by_string(substr, color, **kwargs)
return self
def get_string(self) -> str: