diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 6ad6a9bd..27460899 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,12 @@ from __future__ import annotations -from abc import ABC, abstractmethod +import itertools as it +from abc import abstractmethod import numpy as np from manimlib.animation.animation import Animation from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Mobject + from manimlib.mobject.mobject import Group -class ShowPartial(Animation, ABC): +class ShowPartial(Animation): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Mobject, **kwargs): + def __init__(self, group: Group, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) @@ -213,9 +213,7 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = VGroup(*[ - part for _, part in string_mobject.get_group_part_items() - ]) + grouped_mobject = string_mobject.submob_groups digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index e84f1d9d..dab88005 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,9 +5,9 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject @@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config +from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.svg.tex_mobject import SingleStringTex - from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene + from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): @@ -168,36 +168,36 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] + source_indices = list(range(len(source.labelled_submobjects))) + target_indices = list(range(len(target.labelled_submobjects))) - source_submobs = [ - submob for _, submob in source.labelled_submobject_items - ] - target_submobs = [ - submob for _, submob in target.labelled_submobject_items - ] - source_indices = list(range(len(source_submobs))) - target_indices = list(range(len(target_submobs))) - - def get_filtered_indices_lists(parts, submobs, rest_indices): - return list(filter( - lambda indices_list: all([ - index in rest_indices - for index in indices_list - ]), + def get_indices_lists(mobject, parts): + return [ [ - [submobs.index(submob) for submob in part] - for part in parts + mobject.labelled_submobjects.index(submob) + for submob in part ] - )) + for part in parts + ] - def add_anims(anim_class, parts_pairs): - for source_parts, target_parts in parts_pairs: - source_indices_lists = get_filtered_indices_lists( - source_parts, source_submobs, source_indices - ) - target_indices_lists = get_filtered_indices_lists( - target_parts, target_submobs, target_indices - ) + def add_anims_from(anim_class, func, source_args, target_args=None): + if target_args is None: + target_args = source_args.copy() + for source_arg, target_arg in zip(source_args, target_args): + source_parts = func(source, source_arg) + target_parts = func(target, target_arg) + source_indices_lists = list(filter( + lambda indices_list: all([ + index in source_indices + for index in indices_list + ]), get_indices_lists(source, source_parts) + )) + target_indices_lists = list(filter( + lambda indices_list: all([ + index in target_indices + for index in indices_list + ]), get_indices_lists(target, target_parts) + )) if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) @@ -206,45 +206,41 @@ class TransformMatchingStrings(AnimationGroup): for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_substr_to_parts_map(part_items): - result = {} - for substr, part in part_items: - if substr not in result: - result[substr] = [] - result[substr].append(part) + def get_common_substrs(substrs_from_source, substrs_from_target): + return sorted([ + substr for substr in substrs_from_source + if substr and substr in substrs_from_target + ], key=len, reverse=True) + + def get_parts_from_keys(mobject, keys): + if isinstance(keys, str): + keys = [keys] + result = VGroup() + for key in keys: + if not isinstance(key, str): + raise TypeError(key) + result.add(*mobject.get_parts_by_string(key)) return result - def add_anims_from(anim_class, func): - source_substr_to_parts_map = get_substr_to_parts_map(func(source)) - target_substr_to_parts_map = get_substr_to_parts_map(func(target)) - add_anims( - anim_class, - [ - ( - VGroup(*source_substr_to_parts_map[substr]), - VGroup(*target_substr_to_parts_map[substr]) - ) - for substr in sorted([ - s for s in source_substr_to_parts_map.keys() - if s and s in target_substr_to_parts_map.keys() - ], key=len, reverse=True) - ] + add_anims_from( + ReplacementTransform, get_parts_from_keys, + self.key_map.keys(), self.key_map.values() + ) + add_anims_from( + FadeTransformPieces, + LabelledString.get_parts_by_string, + get_common_substrs( + source.specified_substrs, + target.specified_substrs ) - - add_anims( - ReplacementTransform, - [ - (source.select_parts(k), target.select_parts(v)) - for k, v in self.key_map.items() - ] ) add_anims_from( FadeTransformPieces, - LabelledString.get_specified_part_items - ) - add_anims_from( - FadeTransformPieces, - LabelledString.get_group_part_items + LabelledString.get_parts_by_group_substr, + get_common_substrs( + source.group_substrs, + target.group_substrs + ) ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 55b8fca6..f1354f0c 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -1,41 +1,30 @@ from __future__ import annotations -from abc import ABC, abstractmethod -import itertools as it import re +import colour +import itertools as it +from typing import Iterable, Union, Sequence +from abc import ABC, abstractmethod -from manimlib.constants import WHITE +from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies + from typing import TYPE_CHECKING if TYPE_CHECKING: - from colour import Color - from typing import Iterable, Union - - ManimColor = Union[str, Color] + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] - Selector = Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]], - Iterable[Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]] - ]] - ] -class LabelledString(SVGMobject, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ +class _StringSVG(SVGMobject): CONFIG = { "height": None, "stroke_width": 0, @@ -44,31 +33,42 @@ class LabelledString(SVGMobject, ABC): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, + } + + +class LabelledString(_StringSVG, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ + CONFIG = { "base_color": WHITE, + "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string digest_config(self, kwargs) - if self.base_color is None: - self.base_color = WHITE - self.base_color_int = self.color_to_int(self.base_color) - self.string_len = len(self.string) - self.full_span = (0, self.string_len) + # Convert `base_color` to hex code. + self.base_color = rgb_to_hex(color_to_rgb( + self.base_color \ + or self.svg_default.get("color", None) \ + or self.svg_default.get("fill_color", None) \ + or WHITE + )) + self.svg_default["fill_color"] = BLACK + + self.pre_parse() self.parse() super().__init__() - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] + self.post_parse() def get_file_path(self) -> str: - return self.get_file_path_(is_labelled=False) + return self.get_file_path_(use_plain_file=False) - def get_file_path_(self, is_labelled: bool) -> str: - content = self.get_content(is_labelled) + def get_file_path_(self, use_plain_file: bool) -> str: + content = self.get_content(use_plain_file) return self.get_file_path_by_content(content) @abstractmethod @@ -78,113 +78,87 @@ class LabelledString(SVGMobject, ABC): def generate_mobject(self) -> None: super().generate_mobject() - num_labels = len(self.label_span_list) - if num_labels: - file_path = self.get_file_path_(is_labelled=True) - labelled_svg = SVGMobject(file_path) - submob_color_ints = [ - self.color_to_int(submob.get_fill_color()) - for submob in labelled_svg.submobjects - ] + submob_labels = [ + self.color_to_label(submob.get_fill_color()) + for submob in self.submobjects + ] + if self.use_plain_file or self.has_predefined_local_colors: + file_path = self.get_file_path_(use_plain_file=True) + plain_svg = _StringSVG( + file_path, + svg_default=self.svg_default, + path_string_config=self.path_string_config + ) + self.set_submobjects(plain_svg.submobjects) else: - submob_color_ints = [0] * len(self.submobjects) + self.set_fill(self.base_color) + for submob, label in zip(self.submobjects, submob_labels): + submob.label = label - if len(self.submobjects) != len(submob_color_ints): - raise ValueError( - "Cannot align submobjects of the labelled svg " - "to the original svg" - ) - - unrecognized_color_ints = remove_list_redundancies(sorted(filter( - lambda color_int: color_int > num_labels, - submob_color_ints - ))) - if unrecognized_color_ints: - raise ValueError( - "Unrecognized color label(s) detected: " - f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" - ) - - for submob, color_int in zip(self.submobjects, submob_color_ints): - submob.label = color_int - 1 + def pre_parse(self) -> None: + self.string_len = len(self.string) + self.full_span = (0, self.string_len) def parse(self) -> None: - self.skippable_indices = self.get_skippable_indices() + self.command_repl_items = self.get_command_repl_items() + self.command_spans = self.get_command_spans() + self.extra_entity_spans = self.get_extra_entity_spans() self.entity_spans = self.get_entity_spans() - self.bracket_spans = self.get_bracket_spans() - self.extra_isolated_items = self.get_extra_isolated_items() - self.specified_items = self.get_specified_items() + self.extra_ignored_spans = self.get_extra_ignored_spans() + self.skipped_spans = self.get_skipped_spans() + self.internal_specified_spans = self.get_internal_specified_spans() + self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() - self.check_overlapping() self.label_span_list = self.get_label_span_list() - if len(self.label_span_list) >= 16777216: - raise ValueError("Cannot handle that many substrings") + self.check_overlapping() - def copy(self): - return self.deepcopy() + def post_parse(self) -> None: + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] + self.labelled_submobjects = self.get_labelled_submobjects() + self.specified_substrs = self.get_specified_substrs() + self.group_items = self.get_group_items() + self.group_substrs = self.get_group_substrs() + self.submob_groups = self.get_submob_groups() # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: - if isinstance(pattern, str): - pattern = re.compile(pattern) - return re.compile(pattern).match(self.string, **kwargs) + def finditer( + self, pattern: str, flags: int = 0, **kwargs + ) -> Iterable[re.Match]: + return re.compile(pattern, flags).finditer(self.string, **kwargs) - def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: - if isinstance(pattern, str): - pattern = re.compile(pattern) + def search( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: + return re.compile(pattern, flags).search(self.string, **kwargs) + + def match( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: + return re.compile(pattern, flags).match(self.string, **kwargs) + + def find_spans(self, pattern: str, **kwargs) -> list[Span]: return [ match_obj.span() - for match_obj in pattern.finditer(self.string, **kwargs) + for match_obj in self.finditer(pattern, **kwargs) ] - def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: - return [index for index, _ in self.find_spans(pattern, **kwargs)] + def find_substr(self, substr: str, **kwargs) -> list[Span]: + if not substr: + return [] + return self.find_spans(re.escape(substr), **kwargs) - @staticmethod - def is_single_selector(selector: Selector) -> bool: - if isinstance(selector, str): - return True - if isinstance(selector, re.Pattern): - return True - if isinstance(selector, tuple): - if len(selector) == 2 and all([ - isinstance(index, int) or index is None - for index in selector - ]): - return True - return False - - def find_spans_by_selector(self, selector: Selector) -> list[Span]: - if self.is_single_selector(selector): - selector = (selector,) - result = [] - for sel in selector: - if not self.is_single_selector(sel): - raise TypeError(f"Invalid selector: '{sel}'") - if isinstance(sel, str): - spans = self.find_spans(re.escape(sel)) - elif isinstance(sel, re.Pattern): - spans = self.find_spans(sel) - else: - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(sel, self.full_span) - ]) - spans = [span] - result.extend(spans) - return sorted(filter( - lambda span: span[0] < span[1], - remove_list_redundancies(result) - )) + def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: + return list(it.chain(*[ + self.find_substr(substr, **kwargs) + for substr in remove_list_redundancies(substrs) + ])) @staticmethod def get_neighbouring_pairs(iterable: list) -> list[tuple]: @@ -223,24 +197,41 @@ class LabelledString(SVGMobject, ABC): spans = LabelledString.get_neighbouring_pairs(indices) return list(zip(unique_vals, spans)) + @staticmethod + def find_region_index(seq: list[int], val: int) -> int: + # Returns an integer in `range(-1, len(seq))` satisfying + # `seq[result] <= val < seq[result + 1]`. + # `seq` should be sorted in ascending order. + if not seq or val < seq[0]: + return -1 + result = len(seq) - 1 + while val < seq[result]: + result -= 1 + return result + + @staticmethod + def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: + sorted_seq = sorted(seq) + index = LabelledString.find_region_index(sorted_seq, val) + return sorted_seq[index + index_shift] + @staticmethod def generate_span_repl_dict( inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - repl_items: list[tuple[Span, str]] + other_repl_items: list[tuple[Span, str]] ) -> dict[Span, str]: - result = dict(repl_items) + result = dict(other_repl_items) if not inserted_string_pairs: return result - indices, _, _, _, inserted_strings = zip(*sorted([ + indices, _, _, inserted_strings = zip(*sorted([ ( - item[0][flag], + span[flag], -flag, - -item[0][1 - flag], - (1, -1)[flag] * item_index, - item[1][flag] + -span[1 - flag], + str_pair[flag] ) - for item_index, item in enumerate(inserted_string_pairs) + for span, str_pair in inserted_string_pairs for flag in range(2) ])) result.update({ @@ -272,74 +263,113 @@ class LabelledString(SVGMobject, ABC): return "".join(it.chain(*zip(pieces, repl_strs))) @staticmethod - def color_to_int(color: ManimColor) -> int: - hex_code = rgb_to_hex(color_to_rgb(color)) - return int(hex_code[1:], 16) + def rslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted(skipped)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def lslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted([ + skipped_span[::-1] for skipped_span in skipped + ], reverse=True)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: + r, g, b = rgb_tuple + rg = r * 256 + g + return rg * 256 + b + + @staticmethod + def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return r, g, b @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) + + @staticmethod + def color_to_label(color: ManimColor) -> int: + rgb_tuple = color_to_int_rgb(color) + rgb = LabelledString.rgb_to_int(rgb_tuple) + return rgb - 1 + # Parsing @abstractmethod - def get_skippable_indices(self) -> list[int]: + def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] - @staticmethod - def shrink_span(span: Span, skippable_indices: list[int]) -> Span: - span_begin, span_end = span - while span_begin in skippable_indices: - span_begin += 1 - while span_end - 1 in skippable_indices: - span_end -= 1 - return (span_begin, span_end) + def get_command_spans(self) -> list[Span]: + return [cmd_span for cmd_span, _ in self.command_repl_items] @abstractmethod + def get_extra_entity_spans(self) -> list[Span]: + return [] + def get_entity_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_bracket_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - return [] - - def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: - span_items = list(it.chain( - self.extra_isolated_items, - [ - (span, {}) - for span in self.find_spans_by_selector(self.isolate) - ] + return list(it.chain( + self.command_spans, + self.extra_entity_spans )) - result = [] - for span, attr_dict in span_items: - shrinked_span = self.shrink_span(span, self.skippable_indices) - if shrinked_span[0] >= shrinked_span[1]: - continue - if any([ - entity_span[0] < index < entity_span[1] - for index in shrinked_span - for entity_span in self.entity_spans - ]): - continue - result.append((shrinked_span, attr_dict)) - return result + + @abstractmethod + def get_extra_ignored_spans(self) -> list[int]: + return [] + + def get_skipped_spans(self) -> list[Span]: + return list(it.chain( + self.find_spans(r"\s"), + self.command_spans, + self.extra_ignored_spans + )) + + def shrink_span(self, span: Span) -> Span: + return ( + self.rslide(span[0], self.skipped_spans), + self.lslide(span[1], self.skipped_spans) + ) + + @abstractmethod + def get_internal_specified_spans(self) -> list[Span]: + return [] + + @abstractmethod + def get_external_specified_spans(self) -> list[Span]: + return [] def get_specified_spans(self) -> list[Span]: - return remove_list_redundancies([ - span for span, _ in self.specified_items - ]) + spans = list(it.chain( + self.internal_specified_spans, + self.external_specified_spans, + self.find_substrs(self.isolate) + )) + shrinked_spans = list(filter( + lambda span: span[0] < span[1] and not any([ + entity_span[0] < index < entity_span[1] + for index in span + for entity_span in self.entity_spans + ]), + [self.shrink_span(span) for span in spans] + )) + return remove_list_redundancies(shrinked_spans) + + @abstractmethod + def get_label_span_list(self) -> list[Span]: + return [] def check_overlapping(self) -> None: - spans = remove_list_redundancies(list(it.chain( - self.specified_spans, - self.bracket_spans - ))) - for span_0, span_1 in it.product(spans, repeat=2): + for span_0, span_1 in it.product(self.label_span_list, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -348,20 +378,29 @@ class LabelledString(SVGMobject, ABC): ) @abstractmethod - def get_label_span_list(self) -> list[Span]: - return [] - - @abstractmethod - def get_content(self, is_labelled: bool) -> str: + def get_content(self, use_plain_file: bool) -> str: return "" - # Selector - @abstractmethod + def has_predefined_local_colors(self) -> bool: + return False + + # Post-parsing + + def get_labelled_submobjects(self) -> list[VMobject]: + return [submob for _, submob in self.labelled_submobject_items] + def get_cleaned_substr(self, span: Span) -> str: - return "" + span_repl_dict = dict.fromkeys(self.command_spans, "") + return self.get_replaced_substr(span, span_repl_dict) - def get_group_part_items(self) -> list[tuple[str, VGroup]]: + def get_specified_substrs(self) -> list[str]: + return remove_list_redundancies([ + self.get_cleaned_substr(span) + for span in self.specified_spans + ]) + + def get_group_items(self) -> list[tuple[str, VGroup]]: if not self.labelled_submobject_items: return [] @@ -386,31 +425,41 @@ class LabelledString(SVGMobject, ABC): ordered_spans ) ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" + shrinked_spans = [ + self.shrink_span(span) for span in self.get_complement_spans( interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" + for span in shrinked_spans + ] submob_groups = VGroup(*[ VGroup(*labelled_submobjects[slice(*submob_span)]) for submob_span in labelled_submob_spans ]) return list(zip(group_substrs, submob_groups)) - def get_specified_part_items(self) -> list[tuple[str, VGroup]]: - return [ - ( - self.get_substr(span), - self.select_part_by_span(span, substring=False) - ) - for span in self.specified_spans - ] + def get_group_substrs(self) -> list[str]: + return [group_substr for group_substr, _ in self.group_items] + + def get_submob_groups(self) -> list[VGroup]: + return [submob_group for _, submob_group in self.group_items] + + def get_parts_by_group_substr(self, substr: str) -> VGroup: + return VGroup(*[ + group + for group_substr, group in self.group_items + if group_substr == substr + ]) + + # Selector def find_span_components( self, custom_span: Span, substring: bool = True ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span, self.skippable_indices) + shrinked_span = self.shrink_span(custom_span) if shrinked_span[0] >= shrinked_span[1]: return [] @@ -419,12 +468,12 @@ class LabelledString(SVGMobject, ABC): self.full_span, *self.label_span_list ))) - span_begin = max(filter( - lambda index: index <= shrinked_span[0], indices - )) - span_end = min(filter( - lambda index: index >= shrinked_span[1], indices - )) + span_begin = self.take_nearest_value( + indices, shrinked_span[0], 0 + ) + span_end = self.take_nearest_value( + indices, shrinked_span[1] - 1, 1 + ) else: span_begin, span_end = shrinked_span @@ -445,7 +494,7 @@ class LabelledString(SVGMobject, ABC): span_begin = next_begin return result - def select_part_by_span(self, custom_span: Span, **kwargs) -> VGroup: + def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -460,31 +509,34 @@ class LabelledString(SVGMobject, ABC): if label in labels ]) - def select_parts(self, selector: Selector, **kwargs) -> VGroup: - return VGroup(*filter( - lambda part: part.submobjects, - [ - self.select_part_by_span(span, **kwargs) - for span in self.find_spans_by_selector(selector) - ] - )) - - def select_part( - self, selector: Selector, index: int = 0, **kwargs + def get_parts_by_string( + self, substr: str, + case_sensitive: bool = True, regex: bool = False, **kwargs ) -> VGroup: - return self.select_parts(selector, **kwargs)[index] + flags = 0 + if not case_sensitive: + flags |= re.I + pattern = substr if regex else re.escape(substr) + return VGroup(*[ + self.get_part_by_custom_span(span, **kwargs) + for span in self.find_spans(pattern, flags=flags) + if span[0] < span[1] + ]) - def set_parts_color( - self, selector: Selector, color: ManimColor, **kwargs - ): - self.select_parts(selector, **kwargs).set_color(color) + def get_part_by_string( + self, substr: str, index: int = 0, **kwargs + ) -> VMobject: + return self.get_parts_by_string(substr, **kwargs)[index] + + def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): + self.get_parts_by_string(substr, **kwargs).set_color(color) return self - def set_parts_color_by_dict( - self, color_map: dict[Selector, ManimColor], **kwargs + def set_color_by_string_to_color_map( + self, string_to_color_map: dict[str, ManimColor], **kwargs ): - for selector, color in color_map.items(): - self.set_parts_color(selector, color, **kwargs) + for substr, color in string_to_color_map.items(): + self.set_color_by_string(substr, color, **kwargs) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index d4f502f4..fb7922e1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,47 +1,27 @@ from __future__ import annotations import itertools as it -import re +import colour +from typing import Union, Sequence from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import display_during_execution -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import tex_to_svg_file +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import display_during_execution + from typing import TYPE_CHECKING if TYPE_CHECKING: - from colour import Color - from typing import Iterable, Union - + from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup - - ManimColor = Union[str, Color] + ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] - Selector = Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]], - Iterable[Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]] - ]] - ] SCALE_FACTOR_PER_FONT_POINT = 0.001 -TEX_COLOR_COMMANDS_DICT = { - "\\color": (1, False), - "\\textcolor": (1, False), - "\\pagecolor": (1, True), - "\\colorbox": (1, True), - "\\fcolorbox": (2, True), -} - - class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -52,7 +32,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string.strip(): + if not tex_string: tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -67,6 +47,7 @@ class MTex(LabelledString): self.svg_default, self.path_string_config, self.base_color, + self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -80,87 +61,85 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.string}\""): + with display_during_execution(f"Writing \"{self.tex_string}\""): file_path = tex_to_svg_file(full_tex) return file_path - def parse(self) -> None: + def pre_parse(self) -> None: + super().pre_parse() self.backslash_indices = self.get_backslash_indices() - self.command_spans = self.get_command_spans() - self.brace_spans = self.get_brace_spans() - self.script_char_indices = self.get_script_char_indices() + self.brace_index_pairs = self.get_brace_index_pairs() + self.script_char_spans = self.get_script_char_spans() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() - self.command_repl_items = self.get_command_repl_items() - super().parse() # Toolkits @staticmethod def get_color_command_str(rgb_int: int) -> str: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return f"\\color[RGB]{{{r}, {g}, {b}}}" + rgb_tuple = MTex.int_to_rgb(rgb_int) + return "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, rgb_tuple)), + "}" + ]) - # Parsing + # Pre-parsing def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. - return self.find_indices(r"\\.") + return list(it.chain(*[ + range(span[0], span[1], 2) + for span in self.find_spans(r"\\+") + ])) - def get_command_spans(self) -> list[Span]: - return [ - self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices - ] - - def get_unescaped_char_indices(self, char: str) -> list[int]: - return list(filter( - lambda index: index - 1 not in self.backslash_indices, - self.find_indices(re.escape(char)) + def get_unescaped_char_spans(self, chars: str): + return sorted(filter( + lambda span: span[0] - 1 not in self.backslash_indices, + self.find_substrs(list(chars)) )) - def get_brace_spans(self) -> list[Span]: - span_begins = [] - span_ends = [] - span_begins_stack = [] - char_items = sorted([ - (index, char) - for char in "{}" - for index in self.get_unescaped_char_indices(char) - ]) - for index, char in char_items: - if char == "{": - span_begins_stack.append(index) + def get_brace_index_pairs(self) -> list[Span]: + left_brace_indices = [] + right_brace_indices = [] + left_brace_indices_stack = [] + for span in self.get_unescaped_char_spans("{}"): + index = span[0] + if self.get_substr(span) == "{": + left_brace_indices_stack.append(index) else: - if not span_begins_stack: + if not left_brace_indices_stack: raise ValueError("Missing '{' inserted") - span_begins.append(span_begins_stack.pop()) - span_ends.append(index + 1) - if span_begins_stack: + left_brace_index = left_brace_indices_stack.pop() + left_brace_indices.append(left_brace_index) + right_brace_indices.append(index) + if left_brace_indices_stack: raise ValueError("Missing '}' inserted") - return list(zip(span_begins, span_ends)) + return list(zip(left_brace_indices, right_brace_indices)) - def get_script_char_indices(self) -> list[int]: - return list(it.chain(*[ - self.get_unescaped_char_indices(char) - for char in "_^" - ])) + def get_script_char_spans(self) -> list[int]: + return self.get_unescaped_char_spans("_^") def get_script_content_spans(self) -> list[Span]: result = [] - script_entity_dict = dict(it.chain( - self.brace_spans, - self.command_spans - )) - for index in self.script_char_indices: - span_begin = self.match(r"\s*", pos=index + 1).end() - if span_begin in script_entity_dict.keys(): - span_end = script_entity_dict[span_begin] + brace_indices_dict = dict(self.brace_index_pairs) + script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" + for script_char_span in self.script_char_spans: + span_begin = self.match(r"\s*", pos=script_char_span[1]).end() + if span_begin in brace_indices_dict.keys(): + span_end = brace_indices_dict[span_begin] + 1 else: - match_obj = self.match(r".", pos=span_begin) - if match_obj is None: - continue + match_obj = self.match(script_pattern, pos=span_begin) + if not match_obj: + script_name = { + "_": "subscript", + "^": "superscript" + }[script_char] + raise ValueError( + f"Unclear {script_name} detected while parsing. " + "Please use braces to clarify" + ) span_end = match_obj.end() result.append((span_begin, span_end)) return result @@ -168,100 +147,110 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), + self.search(r"\s*$", endpos=script_char_span[0]).start(), script_content_span[1] ) - for index, script_content_span in zip( - self.script_char_indices, self.script_content_spans + for script_char_span, script_content_span in zip( + self.script_char_spans, self.script_content_spans ) ] + # Parsing + def get_command_repl_items(self) -> list[tuple[Span, str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } result = [] - brace_spans_dict = dict(self.brace_spans) - brace_begins = list(brace_spans_dict.keys()) - for cmd_span in self.command_spans: - cmd_name = self.get_substr(cmd_span) - if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): + backslash_indices = self.backslash_indices + right_brace_indices = [ + right_index + for left_index, right_index in self.brace_index_pairs + ] + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in self.finditer(pattern): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: continue - n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] - span_begin, span_end = cmd_span - for _ in range(n_braces): - span_end = brace_spans_dict[min(filter( - lambda index: index >= span_end, - brace_begins - ))] + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = self.take_nearest_value( + right_brace_indices, cmd_end, n_braces + ) + 1 if substitute_cmd: - repl_str = cmd_name + n_braces * "{black}" + repl_str = "\\" + cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - def get_skippable_indices(self) -> list[int]: - return list(it.chain( - self.find_indices(r"\s"), - self.script_char_indices - )) + def get_extra_entity_spans(self) -> list[Span]: + return [ + self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] - def get_entity_spans(self) -> list[Span]: - return self.command_spans.copy() - - def get_bracket_spans(self) -> list[Span]: - return self.brace_spans.copy() - - def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - result = [] + def get_extra_ignored_spans(self) -> list[int]: + return self.script_char_spans.copy() + def get_internal_specified_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). - sorted_brace_spans = sorted( - self.brace_spans, key=lambda span: span[1] - ) + result = [] + reversed_brace_indices_dict = dict([ + pair[::-1] for pair in self.brace_index_pairs + ]) skip = False - for prev_span, span in self.get_neighbouring_pairs( - sorted_brace_spans + for prev_right_index, right_index in self.get_neighbouring_pairs( + list(reversed_brace_indices_dict.keys()) ): if skip: skip = False continue - if span[0] != prev_span[0] - 1 or span[1] != prev_span[1] + 1: + if right_index != prev_right_index + 1: continue - result.append(span) + left_index = reversed_brace_indices_dict[right_index] + prev_left_index = reversed_brace_indices_dict[prev_right_index] + if left_index != prev_left_index - 1: + continue + result.append((left_index, right_index + 1)) skip = True + return result - result.extend(it.chain(*[ - self.find_spans_by_selector(selector) - for selector in self.tex_to_color_map.keys() - ])) - return [(span, {}) for span in result] + def get_external_specified_spans(self) -> list[Span]: + return self.find_substrs(list(self.tex_to_color_map.keys())) def get_label_span_list(self) -> list[Span]: result = self.script_content_spans.copy() - reversed_script_spans_dict = dict([ - script_span[::-1] for script_span in self.script_spans - ]) for span_begin, span_end in self.specified_spans: - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - if span_begin >= span_end: + shrinked_end = self.lslide(span_end, self.script_spans) + if span_begin >= shrinked_end: continue - shrinked_span = (span_begin, span_end) + shrinked_span = (span_begin, shrinked_end) if shrinked_span in result: continue result.append(shrinked_span) return result - def get_content(self, is_labelled: bool) -> str: - if is_labelled: - extended_label_span_list = [] - script_spans_dict = dict(self.script_spans) - for span in self.label_span_list: - if span not in self.script_content_spans: - span_begin, span_end = span - while span_end in script_spans_dict.keys(): - span_end = script_spans_dict[span_end] - span = (span_begin, span_end) - extended_label_span_list.append(span) + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: + span_repl_dict = {} + else: + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] inserted_string_pairs = [ (span, ( "{{" + self.get_color_command_str(label + 1), @@ -270,51 +259,42 @@ class MTex(LabelledString): for label, span in enumerate(extended_label_span_list) ] span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, self.command_repl_items + inserted_string_pairs, + self.command_repl_items ) - else: - span_repl_dict = {} result = self.get_replaced_substr(self.full_span, span_repl_dict) if self.tex_environment: - if isinstance(self.tex_environment, str): - prefix = f"\\begin{{{self.tex_environment}}}" - suffix = f"\\end{{{self.tex_environment}}}" - else: - prefix, suffix = self.tex_environment - result = "\n".join([prefix, result, suffix]) + result = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + result, + f"\\end{{{self.tex_environment}}}" + ]) if self.alignment: result = "\n".join([self.alignment, result]) - if not is_labelled: + if use_plain_file: result = "\n".join([ - self.get_color_command_str(self.base_color_int), + self.get_color_command_str(self.hex_to_int(self.base_color)), result ]) return result - # Selector + @property + def has_predefined_local_colors(self) -> bool: + return bool(self.command_repl_items) + + # Post-parsing def get_cleaned_substr(self, span: Span) -> str: - if not self.brace_spans: - brace_begins, brace_ends = [], [] - else: - brace_begins, brace_ends = zip(*self.brace_spans) - left_brace_indices = list(brace_begins) - right_brace_indices = [index - 1 for index in brace_ends] - skippable_indices = list(it.chain( - self.skippable_indices, - left_brace_indices, - right_brace_indices - )) - shrinked_span = self.shrink_span(span, skippable_indices) - - if shrinked_span[0] >= shrinked_span[1]: - return "" + substr = super().get_cleaned_substr(span) + if not self.brace_index_pairs: + return substr # Balance braces. + left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*shrinked_span): + for index in range(*span): if index in left_brace_indices: unclosed_left_braces += 1 elif index in right_brace_indices: @@ -324,27 +304,27 @@ class MTex(LabelledString): unclosed_left_braces -= 1 return "".join([ unclosed_right_braces * "{", - self.get_substr(shrinked_span), + substr, unclosed_left_braces * "}" ]) # Method alias - def get_parts_by_tex(self, selector: Selector, **kwargs) -> VGroup: - return self.select_parts(selector, **kwargs) + def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: + return self.get_parts_by_string(tex, **kwargs) - def get_part_by_tex(self, selector: Selector, **kwargs) -> VGroup: - return self.select_part(selector, **kwargs) + def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: + return self.get_part_by_string(tex, **kwargs) - def set_color_by_tex( - self, selector: Selector, color: ManimColor, **kwargs - ): - return self.set_parts_color(selector, color, **kwargs) + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): + return self.set_color_by_string(tex, color, **kwargs) def set_color_by_tex_to_color_map( - self, color_map: dict[Selector, ManimColor], **kwargs + self, tex_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_parts_color_by_dict(color_map, **kwargs) + return self.set_color_by_string_to_color_map( + tex_to_color_map, **kwargs + ) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 13f5b9c0..c3c3be19 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,64 +1,93 @@ from __future__ import annotations -from contextlib import contextmanager -import itertools as it import os -from pathlib import Path import re +import itertools as it +from pathlib import Path +from contextlib import contextmanager +import typing +from typing import Iterable, Sequence, Union -import manimpango import pygments import pygments.formatters import pygments.lexers -from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH -from manimlib.constants import NORMAL +from manimpango import MarkupUtils + from manimlib.logger import log +from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.config_ops import digest_config from manimlib.utils.customization import get_customization +from manimlib.utils.tex_file_writing import tex_hash +from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.tex_file_writing import tex_hash +from manimlib.utils.iterables import remove_list_redundancies + from typing import TYPE_CHECKING if TYPE_CHECKING: - from colour import Color - from typing import Iterable, Union - + from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup - - ManimColor = Union[str, Color] + ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] - Selector = Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]], - Iterable[Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]] - ]] - ] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 -# Ensure the canvas is large enough to hold all glyphs. -DEFAULT_CANVAS_WIDTH = 16384 -DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -MARKUP_COLOR_KEYS = ( - "foreground", "fgcolor", "color", - "background", "bgcolor", - "underline_color", - "overline_color", - "strikethrough_color" +# A tag containing two aliases will cause warning, +# so only use the first key of each group of aliases. +SPAN_ATTR_KEY_ALIAS_LIST = ( + ("font", "font_desc"), + ("font_family", "face"), + ("font_size", "size"), + ("font_style", "style"), + ("font_weight", "weight"), + ("font_variant", "variant"), + ("font_stretch", "stretch"), + ("font_features",), + ("foreground", "fgcolor", "color"), + ("background", "bgcolor"), + ("alpha", "fgalpha"), + ("background_alpha", "bgalpha"), + ("underline",), + ("underline_color",), + ("overline",), + ("overline_color",), + ("rise",), + ("baseline_shift",), + ("font_scale",), + ("strikethrough",), + ("strikethrough_color",), + ("fallback",), + ("lang",), + ("letter_spacing",), + ("gravity",), + ("gravity_hint",), + ("show",), + ("insert_hyphens",), + ("allow_breaks",), + ("line_height",), + ("text_transform",), + ("segment",), ) -MARKUP_TAG_CONVERSION_DICT = { +COLOR_RELATED_KEYS = ( + "foreground", + "background", + "underline_color", + "overline_color", + "strikethrough_color" +) +SPAN_ATTR_KEY_CONVERSION = { + key: key_alias_list[0] + for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST + for key in key_alias_list +} +TAG_TO_ATTR_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -91,7 +120,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width": None, + "line_width_factor": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -112,7 +141,9 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - self.validate_markup_string(text) + validate_error = MarkupUtils.validate(text) + if validate_error: + raise ValueError(validate_error) self.text = text super().__init__(text, **kwargs) @@ -134,6 +165,7 @@ class MarkupText(LabelledString): self.svg_default, self.path_string_config, self.base_color, + self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -142,7 +174,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width, + self.line_width_factor, self.font, self.slant, self.weight, @@ -169,32 +201,23 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: - hash_content = str(( - content, - self.justify, - self.indent, - self.alignment, - self.line_width - )) svg_file = os.path.join( - get_text_dir(), tex_hash(hash_content) + ".svg" + get_text_dir(), tex_hash(content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: - self.validate_markup_string(markup_str) - # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width is None: + if self.line_width_factor is None: pango_width = -1 else: - pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH - return manimpango.MarkupUtils.text2svg( + return MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -205,8 +228,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_CANVAS_WIDTH, - height=DEFAULT_CANVAS_HEIGHT, + width=DEFAULT_PIXEL_WIDTH, + height=DEFAULT_PIXEL_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -214,22 +237,13 @@ class MarkupText(LabelledString): pango_width=pango_width ) - @staticmethod - def validate_markup_string(markup_str: str) -> None: - validate_error = manimpango.MarkupUtils.validate(markup_str) - if not validate_error: - return - raise ValueError( - f"Invalid markup string \"{markup_str}\"\n" - f"{validate_error}" - ) - - def parse(self) -> None: - self.global_attr_dict = self.get_global_attr_dict() - self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() - self.tag_spans = self.get_tag_spans() - self.items_from_markup = self.get_items_from_markup() - super().parse() + def pre_parse(self) -> None: + super().pre_parse() + self.tag_items_from_markup = self.get_tag_items_from_markup() + self.global_dict_from_config = self.get_global_dict_from_config() + self.local_dicts_from_markup = self.get_local_dicts_from_markup() + self.local_dicts_from_config = self.get_local_dicts_from_config() + self.predefined_attr_dicts = self.get_predefined_attr_dicts() # Toolkits @@ -240,46 +254,87 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - # Parsing + @staticmethod + def merge_attr_dicts( + attr_dict_items: list[Span, str, typing.Any] + ) -> list[tuple[Span, dict[str, str]]]: + index_seq = [0] + attr_dict_list = [{}] + for span, attr_dict in attr_dict_items: + if span[0] >= span[1]: + continue + region_indices = [ + MarkupText.find_region_index(index_seq, index) + for index in span + ] + for flag in (1, 0): + if index_seq[region_indices[flag]] == span[flag]: + continue + region_index = region_indices[flag] + index_seq.insert(region_index + 1, span[flag]) + attr_dict_list.insert( + region_index + 1, attr_dict_list[region_index].copy() + ) + region_indices[flag] += 1 + if flag == 0: + region_indices[1] += 1 + for key, val in attr_dict.items(): + if not key: + continue + for mid_dict in attr_dict_list[slice(*region_indices)]: + mid_dict[key] = val + return list(zip( + MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] + )) - def get_global_attr_dict(self) -> dict[str, str]: - result = { - "font_size": str(self.font_size * 1024), - "foreground": self.int_to_hex(self.base_color_int), - "font_family": self.font, - "font_style": self.slant, - "font_weight": self.weight, - } - # `line_height` attribute is supported since Pango 1.50. - if tuple(map(int, manimpango.pango_version().split("."))) >= (1, 50): - result.update({ - "line_height": str(( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6), - }) - return result + def find_substr_or_span( + self, substr_or_span: str | tuple[int | None, int | None] + ) -> list[Span]: + if isinstance(substr_or_span, str): + return self.find_substr(substr_or_span) - def get_tag_pairs_from_markup( + span = tuple([ + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(substr_or_span, self.full_span) + ]) + if span[0] >= span[1]: + return [] + return [span] + + # Pre-parsing + + def get_tag_items_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" + tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in re.finditer(tag_pattern, self.string): + for match_obj in self.finditer(tag_pattern): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: match_obj_pairs.append( (begin_match_obj_stack.pop(), match_obj) ) + if begin_match_obj_stack: + raise ValueError("Unclosed tag(s) detected") result = [] for begin_match_obj, end_match_obj in match_obj_pairs: tag_name = begin_match_obj.group(2) + if tag_name != end_match_obj.group(2): + raise ValueError("Unmatched tag names") + if end_match_obj.group(3): + raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { match.group(1): match.group(3) @@ -287,157 +342,189 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } + elif tag_name in TAG_TO_ATTR_DICT.keys(): + if begin_match_obj.group(3): + raise ValueError( + f"Attributes shan't exist in tag '{tag_name}'" + ) + attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() else: - attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) + raise ValueError(f"Unknown tag: '{tag_name}'") result.append( (begin_match_obj.span(), end_match_obj.span(), attr_dict) ) return result - def get_tag_spans(self) -> list[Span]: - return [ - tag_span - for begin_tag, end_tag, _ in self.tag_pairs_from_markup - for tag_span in (begin_tag, end_tag) - ] - - def get_items_from_markup(self) -> list[Span]: - return [ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) - for begin_tag_span, end_tag_span, attr_dict - in self.tag_pairs_from_markup - ] - - def get_skippable_indices(self) -> list[int]: - return self.find_indices(r"\s") - - def get_entity_spans(self) -> list[Span]: - result = self.tag_spans.copy() - if self.is_markup: - result.extend(self.find_spans(r"&[\s\S]*?;")) + def get_global_dict_from_config(self) -> dict[str, typing.Any]: + result = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font, + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + result.update(self.global_config) return result - def get_bracket_spans(self) -> list[Span]: - return [span for span, _ in self.items_from_markup] + def get_local_dicts_from_markup( + self + ) -> list[Span, dict[str, str]]: + return sorted([ + ((begin_tag_span[0], end_tag_span[1]), attr_dict) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_items_from_markup + ]) - def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - return list(it.chain( - self.items_from_markup, - [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for selector, val in t2x_dict.items() - for span in self.find_spans_by_selector(selector) - ], - [ - (span, local_config) - for selector, local_config in self.local_configs.items() - for span in self.find_spans_by_selector(selector) - ] - )) - - def get_label_span_list(self) -> list[Span]: - interval_spans = sorted(it.chain( - self.tag_spans, - [ - (index, index) - for span in self.specified_spans - for index in span - ] - )) - text_spans = self.get_complement_spans(interval_spans, self.full_span) - if self.is_markup: - pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" - else: - pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" - return list(it.chain(*[ - self.find_spans(pattern, pos=span_begin, endpos=span_end) - for span_begin, span_end in text_spans - ])) - - def get_content(self, is_labelled: bool) -> str: - predefined_items = [ - (self.full_span, self.global_attr_dict), - (self.full_span, self.global_config), - *self.specified_items + def get_local_dicts_from_config( + self + ) -> list[Span, dict[str, typing.Any]]: + return [ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for substr_or_span, val in t2x_dict.items() + for span in self.find_substr_or_span(substr_or_span) + ] + [ + (span, local_config) + for substr_or_span, local_config in self.local_configs.items() + for span in self.find_substr_or_span(substr_or_span) ] - if is_labelled: - attr_dict_items = list(it.chain( - [ - (span, { - key: - "black" if key.lower() in MARKUP_COLOR_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in predefined_items - ], - [ - (span, {"foreground": self.int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) - ] - )) - else: - attr_dict_items = list(it.chain( - predefined_items, - [ - (span, {}) - for span in self.label_span_list - ] - )) - inserted_string_pairs = [ - (span, ( - f"", - "" - )) - for span, attr_dict in attr_dict_items if attr_dict + + def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: + attr_dict_items = [ + (self.full_span, self.global_dict_from_config), + *self.local_dicts_from_markup, + *self.local_dicts_from_config ] - repl_items = [ - (tag_span, "") for tag_span in self.tag_spans + return [ + (span, { + SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) + for key, val in attr_dict.items() + }) + for span, attr_dict in attr_dict_items + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + result = [ + (tag_span, "") + for begin_tag, end_tag, _ in self.tag_items_from_markup + for tag_span in (begin_tag, end_tag) ] if not self.is_markup: - repl_items.extend([ + result += [ (span, escaped) for char, escaped in ( ("&", "&"), (">", ">"), ("<", "<") ) - for span in self.find_spans(re.escape(char)) - ]) + for span in self.find_substr(char) + ] + return result + + def get_extra_entity_spans(self) -> list[Span]: + if not self.is_markup: + return [] + return self.find_spans(r"&.*?;") + + def get_extra_ignored_spans(self) -> list[int]: + return [] + + def get_internal_specified_spans(self) -> list[Span]: + return [span for span, _ in self.local_dicts_from_markup] + + def get_external_specified_spans(self) -> list[Span]: + return [span for span, _ in self.local_dicts_from_config] + + def get_label_span_list(self) -> list[Span]: + breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( + self.find_spans(r"\s+"), + self.find_spans(r"\b"), + self.specified_spans + )))) + breakup_indices = sorted(filter( + lambda index: not any([ + span[0] < index < span[1] + for span in self.entity_spans + ]), + breakup_indices + )) + return list(filter( + lambda span: self.get_substr(span).strip(), + self.get_neighbouring_pairs(breakup_indices) + )) + + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: + attr_dict_items = [ + (self.full_span, {"foreground": self.base_color}), + *self.predefined_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] + ] + else: + attr_dict_items = [ + (self.full_span, {"foreground": BLACK}), + *[ + (span, { + key: BLACK if key in COLOR_RELATED_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.predefined_attr_dicts + ], + *[ + (span, {"foreground": self.int_to_hex(label + 1)}) + for label, span in enumerate(self.label_span_list) + ] + ] + inserted_string_pairs = [ + (span, ( + f"", + "" + )) + for span, attr_dict in self.merge_attr_dicts(attr_dict_items) + ] span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, repl_items + inserted_string_pairs, self.command_repl_items ) return self.get_replaced_substr(self.full_span, span_repl_dict) - # Selector - - def get_cleaned_substr(self, span: Span) -> str: - repl_dict = dict.fromkeys(self.tag_spans, "") - return self.get_replaced_substr(span, repl_dict).strip() + @property + def has_predefined_local_colors(self) -> bool: + return any([ + key in COLOR_RELATED_KEYS + for _, attr_dict in self.predefined_attr_dicts + for key in attr_dict.keys() + ]) # Method alias - def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: - return self.select_parts(selector, **kwargs) + def get_parts_by_text(self, text: str, **kwargs) -> VGroup: + return self.get_parts_by_string(text, **kwargs) - def get_part_by_text(self, selector: Selector, **kwargs) -> VGroup: - return self.select_part(selector, **kwargs) + def get_part_by_text(self, text: str, **kwargs) -> VMobject: + return self.get_part_by_string(text, **kwargs) - def set_color_by_text( - self, selector: Selector, color: ManimColor, **kwargs - ): - return self.set_parts_color(selector, color, **kwargs) + def set_color_by_text(self, text: str, color: ManimColor, **kwargs): + return self.set_color_by_string(text, color, **kwargs) def set_color_by_text_to_color_map( - self, color_map: dict[Selector, ManimColor], **kwargs + self, text_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_parts_color_by_dict(color_map, **kwargs) + return self.set_color_by_string_to_color_map( + text_to_color_map, **kwargs + ) def get_text(self) -> str: return self.get_string() diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index 0f9c4d0d..dd993319 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -48,6 +48,9 @@ class ImageMobject(Mobject): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self + def set_color(self, color, opacity=None, recurse=None): + return self + def point_to_rgb(self, point: np.ndarray) -> np.ndarray: x0, y0 = self.get_corner(UL)[:2] x1, y1 = self.get_corner(DR)[:2]