diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 82fb1605..6499d0af 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -214,7 +214,7 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.get_submob_groups() + grouped_mobject = string_mobject.submob_groups digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index f824663d..dab88005 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -160,77 +160,67 @@ class TransformMatchingStrings(AnimationGroup): } def __init__(self, - source_mobject: LabelledString, - target_mobject: LabelledString, + source: LabelledString, + target: LabelledString, **kwargs ): digest_config(self, kwargs) - assert isinstance(source_mobject, LabelledString) - assert isinstance(target_mobject, LabelledString) + assert isinstance(source, LabelledString) + assert isinstance(target, LabelledString) anims = [] - rest_source_indices = list(range(len(source_mobject.submobjects))) - rest_target_indices = list(range(len(target_mobject.submobjects))) + source_indices = list(range(len(source.labelled_submobjects))) + target_indices = list(range(len(target.labelled_submobjects))) + + def get_indices_lists(mobject, parts): + return [ + [ + mobject.labelled_submobjects.index(submob) + for submob in part + ] + for part in parts + ] def add_anims_from(anim_class, func, source_args, target_args=None): if target_args is None: target_args = source_args.copy() for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source_mobject, source_arg) - target_parts = func(target_mobject, target_arg) - source_indices_lists = source_mobject.indices_lists_of_parts( - source_parts - ) - target_indices_lists = target_mobject.indices_lists_of_parts( - target_parts - ) - filtered_source_indices_lists = list(filter( + source_parts = func(source, source_arg) + target_parts = func(target, target_arg) + source_indices_lists = list(filter( lambda indices_list: all([ - index in rest_source_indices + index in source_indices for index in indices_list - ]), source_indices_lists + ]), get_indices_lists(source, source_parts) )) - filtered_target_indices_lists = list(filter( + target_indices_lists = list(filter( lambda indices_list: all([ - index in rest_target_indices + index in target_indices for index in indices_list - ]), target_indices_lists + ]), get_indices_lists(target, target_parts) )) - if not all([ - filtered_source_indices_lists, - filtered_target_indices_lists - ]): + if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) - for index in it.chain(*filtered_source_indices_lists): - rest_source_indices.remove(index) - for index in it.chain(*filtered_target_indices_lists): - rest_target_indices.remove(index) + for index in it.chain(*source_indices_lists): + source_indices.remove(index) + for index in it.chain(*target_indices_lists): + target_indices.remove(index) - def get_common_substrs(func): + def get_common_substrs(substrs_from_source, substrs_from_target): return sorted([ - substr for substr in func(source_mobject) - if substr and substr in func(target_mobject) + substr for substr in substrs_from_source + if substr and substr in substrs_from_target ], key=len, reverse=True) def get_parts_from_keys(mobject, keys): - if not isinstance(keys, tuple): - keys = (keys,) - indices = [] + if isinstance(keys, str): + keys = [keys] + result = VGroup() for key in keys: - if isinstance(key, int): - indices.append(key) - elif isinstance(key, range): - indices.extend(key) - elif isinstance(key, str): - all_parts = mobject.get_parts_by_string(key) - indices.extend(it.chain(*[ - mobject.indices_of_part(part) for part in all_parts - ])) - else: + if not isinstance(key, str): raise TypeError(key) - return VGroup(VGroup(*[ - mobject[index] for index in remove_list_redundancies(indices) - ])) + result.add(*mobject.get_parts_by_string(key)) + return result add_anims_from( ReplacementTransform, get_parts_from_keys, @@ -239,38 +229,32 @@ class TransformMatchingStrings(AnimationGroup): add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_string, - get_common_substrs(LabelledString.get_specified_substrs) + get_common_substrs( + source.specified_substrs, + target.specified_substrs + ) ) add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_group_substr, - get_common_substrs(LabelledString.get_group_substrs) + get_common_substrs( + source.group_substrs, + target.group_substrs + ) ) - fade_source = VGroup(*[ - source_mobject[index] - for index in rest_source_indices - ]) - fade_target = VGroup(*[ - target_mobject[index] - for index in rest_target_indices - ]) + rest_source = VGroup(*[source[index] for index in source_indices]) + rest_target = VGroup(*[target[index] for index in target_indices]) if self.transform_mismatches: - anims.append(ReplacementTransform( - fade_source, - fade_target, - **kwargs - )) + anims.append( + ReplacementTransform(rest_source, rest_target, **kwargs) + ) else: - anims.append(FadeOutToPoint( - fade_source, - target_mobject.get_center(), - **kwargs - )) - anims.append(FadeInFromPoint( - fade_target, - source_mobject.get_center(), - **kwargs - )) + anims.append( + FadeOutToPoint(rest_source, target.get_center(), **kwargs) + ) + anims.append( + FadeInFromPoint(rest_target, source.get_center(), **kwargs) + ) super().__init__(*anims) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index a2a9f889..58c47094 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -4,12 +4,14 @@ import re import colour import itertools as it from typing import Iterable, Union, Sequence -from abc import abstractmethod +from abc import ABC, abstractmethod -from manimlib.constants import WHITE +from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb +from manimlib.utils.color import color_to_rgb +from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies @@ -34,35 +36,39 @@ class _StringSVG(SVGMobject): } -class LabelledString(_StringSVG): +class LabelledString(_StringSVG, ABC): """ An abstract base class for `MTex` and `MarkupText` """ CONFIG = { - "base_color": None, + "base_color": WHITE, "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string - reserved_svg_default = kwargs.pop("svg_default", {}) digest_config(self, kwargs) - self.reserved_svg_default = reserved_svg_default - self.base_color = self.base_color \ - or reserved_svg_default.get("color", None) \ - or reserved_svg_default.get("fill_color", None) \ + + # Convert `base_color` to hex code. + self.base_color = rgb_to_hex(color_to_rgb( + self.base_color \ + or self.svg_default.get("color", None) \ + or self.svg_default.get("fill_color", None) \ or WHITE + )) + self.svg_default["fill_color"] = BLACK self.pre_parse() self.parse() - super().__init__(**kwargs) + super().__init__() + self.post_parse() def get_file_path(self) -> str: return self.get_file_path_(use_plain_file=False) def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_decorated_string(use_plain_file=use_plain_file) + content = self.get_content(use_plain_file) return self.get_file_path_by_content(content) @abstractmethod @@ -76,15 +82,11 @@ class LabelledString(_StringSVG): self.color_to_label(submob.get_fill_color()) for submob in self.submobjects ] - if any([ - self.use_plain_file, - self.reserved_svg_default, - self.has_predefined_colors - ]): + if self.use_plain_file or self.has_predefined_local_colors: file_path = self.get_file_path_(use_plain_file=True) plain_svg = _StringSVG( file_path, - svg_default=self.reserved_svg_default, + svg_default=self.svg_default, path_string_config=self.path_string_config ) self.set_submobjects(plain_svg.submobjects) @@ -100,7 +102,9 @@ class LabelledString(_StringSVG): def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() self.command_spans = self.get_command_spans() - self.ignored_spans = self.get_ignored_spans() + self.extra_entity_spans = self.get_extra_entity_spans() + self.entity_spans = self.get_entity_spans() + self.extra_ignored_spans = self.get_extra_ignored_spans() self.skipped_spans = self.get_skipped_spans() self.internal_specified_spans = self.get_internal_specified_spans() self.external_specified_spans = self.get_external_specified_spans() @@ -108,6 +112,20 @@ class LabelledString(_StringSVG): self.label_span_list = self.get_label_span_list() self.check_overlapping() + def post_parse(self) -> None: + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] + self.labelled_submobjects = self.get_labelled_submobjects() + self.specified_substrs = self.get_specified_substrs() + self.group_items = self.get_group_items() + self.group_substrs = self.get_group_substrs() + self.submob_groups = self.get_submob_groups() + + def copy(self): + return self.deepcopy() + # Toolkits def get_substr(self, span: Span) -> str: @@ -118,10 +136,14 @@ class LabelledString(_StringSVG): ) -> Iterable[re.Match]: return re.compile(pattern, flags).finditer(self.string, **kwargs) - def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + def search( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: return re.compile(pattern, flags).search(self.string, **kwargs) - def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + def match( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: return re.compile(pattern, flags).match(self.string, **kwargs) def find_spans(self, pattern: str, **kwargs) -> list[Span]: @@ -197,7 +219,7 @@ class LabelledString(_StringSVG): return sorted_seq[index + index_shift] @staticmethod - def get_span_replacement_dict( + def generate_span_repl_dict( inserted_string_pairs: list[tuple[Span, tuple[str, str]]], other_repl_items: list[tuple[Span, str]] ) -> dict[Span, str]: @@ -271,21 +293,19 @@ class LabelledString(_StringSVG): r, g = divmod(rg, 256) return r, g, b + @staticmethod + def int_to_hex(rgb_int: int) -> str: + return "#{:06x}".format(rgb_int).upper() + + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) + @staticmethod def color_to_label(color: ManimColor) -> int: rgb_tuple = color_to_int_rgb(color) rgb = LabelledString.rgb_to_int(rgb_tuple) - if rgb == 16777215: # white - return -1 - return rgb - - @abstractmethod - def get_begin_color_command_str(int_rgb: int) -> str: - return "" - - @abstractmethod - def get_end_color_command_str() -> str: - return "" + return rgb - 1 # Parsing @@ -296,14 +316,25 @@ class LabelledString(_StringSVG): def get_command_spans(self) -> list[Span]: return [cmd_span for cmd_span, _ in self.command_repl_items] - def get_ignored_spans(self) -> list[int]: + @abstractmethod + def get_extra_entity_spans(self) -> list[Span]: + return [] + + def get_entity_spans(self) -> list[Span]: + return list(it.chain( + self.command_spans, + self.extra_entity_spans + )) + + @abstractmethod + def get_extra_ignored_spans(self) -> list[int]: return [] def get_skipped_spans(self) -> list[Span]: return list(it.chain( self.find_spans(r"\s"), self.command_spans, - self.ignored_spans + self.extra_ignored_spans )) def shrink_span(self, span: Span) -> Span: @@ -321,14 +352,17 @@ class LabelledString(_StringSVG): return [] def get_specified_spans(self) -> list[Span]: - spans = [ - self.full_span, - *self.internal_specified_spans, - *self.external_specified_spans, - *self.find_substrs(self.isolate) - ] + spans = list(it.chain( + self.internal_specified_spans, + self.external_specified_spans, + self.find_substrs(self.isolate) + )) shrinked_spans = list(filter( - lambda span: span[0] < span[1], + lambda span: span[0] < span[1] and not any([ + entity_span[0] < index < entity_span[1] + for index in span + for entity_span in self.entity_spans + ]), [self.shrink_span(span) for span in spans] )) return remove_list_redundancies(shrinked_spans) @@ -347,40 +381,18 @@ class LabelledString(_StringSVG): ) @abstractmethod - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - return [] + def get_content(self, use_plain_file: bool) -> str: + return "" @abstractmethod - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - return [] - - def get_decorated_string(self, use_plain_file: bool) -> str: - span_repl_dict = self.get_span_replacement_dict( - self.get_inserted_string_pairs(use_plain_file), - self.get_other_repl_items(use_plain_file) - ) - result = self.get_replaced_substr(self.full_span, span_repl_dict) - - if not use_plain_file: - return result - return "".join([ - self.get_begin_color_command_str( - self.rgb_to_int(color_to_int_rgb(self.base_color)) - ), - result, - self.get_end_color_command_str() - ]) - - @abstractmethod - def has_predefined_colors(self) -> bool: + def has_predefined_local_colors(self) -> bool: return False # Post-parsing + def get_labelled_submobjects(self) -> list[VMobject]: + return [submob for _, submob in self.labelled_submobject_items] + def get_cleaned_substr(self, span: Span) -> str: span_repl_dict = dict.fromkeys(self.command_spans, "") return self.get_replaced_substr(span, span_repl_dict) @@ -391,17 +403,14 @@ class LabelledString(_StringSVG): for span in self.specified_spans ]) - def get_group_span_items(self) -> tuple[list[int], list[Span]]: - submob_labels = [submob.label for submob in self.submobjects] - if not submob_labels: - return [], [] - return tuple(zip(*self.compress_neighbours(submob_labels))) - - def get_group_substrs(self) -> list[str]: - group_labels, _ = self.get_group_span_items() - if not group_labels: + def get_group_items(self) -> list[tuple[str, VGroup]]: + if not self.labelled_submobject_items: return [] + labels, labelled_submobjects = zip(*self.labelled_submobject_items) + group_labels, labelled_submob_spans = zip( + *self.compress_neighbours(labels) + ) ordered_spans = [ self.label_span_list[label] if label != -1 else self.full_span for label in group_labels @@ -425,16 +434,27 @@ class LabelledString(_StringSVG): interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - return [ + group_substrs = [ self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in shrinked_spans ] + submob_groups = VGroup(*[ + VGroup(*labelled_submobjects[slice(*submob_span)]) + for submob_span in labelled_submob_spans + ]) + return list(zip(group_substrs, submob_groups)) - def get_submob_groups(self) -> VGroup: - _, submob_spans = self.get_group_span_items() + def get_group_substrs(self) -> list[str]: + return [group_substr for group_substr, _ in self.group_items] + + def get_submob_groups(self) -> list[VGroup]: + return [submob_group for _, submob_group in self.group_items] + + def get_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ - VGroup(*self.submobjects[slice(*submob_span)]) - for submob_span in submob_spans + group + for group_substr, group in self.group_items + if group_substr == substr ]) # Selector @@ -477,7 +497,7 @@ class LabelledString(_StringSVG): span_begin = next_begin return result - def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -487,29 +507,23 @@ class LabelledString(_StringSVG): ) ]) ] - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) + return VGroup(*[ + submob for label, submob in self.labelled_submobject_items + if label in labels + ]) def get_parts_by_string( - self, substr: str, case_sensitive: bool = True, **kwargs + self, substr: str, + case_sensitive: bool = True, regex: bool = False, **kwargs ) -> VGroup: flags = 0 if not case_sensitive: flags |= re.I + pattern = substr if regex else re.escape(substr) return VGroup(*[ - self.get_parts_by_custom_span(span, **kwargs) - for span in self.find_substr(substr, flags=flags) - ]) - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group, group_substr in zip( - self.get_submob_groups(), self.get_group_substrs() - ) - if group_substr == substr + self.get_part_by_custom_span(span, **kwargs) + for span in self.find_spans(pattern, flags=flags) + if span[0] < span[1] ]) def get_part_by_string( @@ -528,13 +542,5 @@ class LabelledString(_StringSVG): self.set_color_by_string(substr, color, **kwargs) return self - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - return [self.submobjects.index(submob) for submob in part] - - def indices_lists_of_parts( - self, parts: Iterable[Iterable[VMobject]] - ) -> list[list[int]]: - return [self.indices_of_part(part) for part in parts] - def get_string(self) -> str: return self.string diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 341db072..fb7922e1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools as it import colour from typing import Union, Sequence @@ -32,7 +33,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. if not tex_string: - tex_string = "\\quad" + tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -55,30 +56,14 @@ class MTex(LabelledString): ) def get_file_path_by_content(self, content: str) -> str: - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.tex_string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - tex_config = get_tex_config() - return tex_config["tex_body"].replace( + full_tex = tex_config["tex_body"].replace( tex_config["text_to_replace"], content ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) + with display_during_execution(f"Writing \"{self.tex_string}\""): + file_path = tex_to_svg_file(full_tex) + return file_path def pre_parse(self) -> None: super().pre_parse() @@ -91,29 +76,23 @@ class MTex(LabelledString): # Toolkits @staticmethod - def get_begin_color_command_str(rgb_int: int) -> str: + def get_color_command_str(rgb_int: int) -> str: rgb_tuple = MTex.int_to_rgb(rgb_int) return "".join([ - "{{", "\\color[RGB]", "{", ",".join(map(str, rgb_tuple)), "}" ]) - @staticmethod - def get_end_color_command_str() -> str: - return "}}" - # Pre-parsing def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - span[1] - 1 + # The latter of `\\` doesn't count. + return list(it.chain(*[ + range(span[0], span[1], 2) for span in self.find_spans(r"\\+") - if (span[1] - span[0]) % 2 == 1 - ] + ])) def get_unescaped_char_spans(self, chars: str): return sorted(filter( @@ -209,13 +188,19 @@ class MTex(LabelledString): right_brace_indices, cmd_end, n_braces ) + 1 if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" + repl_str = "\\" + cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - def get_ignored_spans(self) -> list[int]: + def get_extra_entity_spans(self) -> list[Span]: + return [ + self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_extra_ignored_spans(self) -> list[int]: return self.script_char_spans.copy() def get_internal_specified_spans(self) -> list[Span]: @@ -256,35 +241,46 @@ class MTex(LabelledString): result.append(shrinked_span) return result - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: + def get_content(self, use_plain_file: bool) -> str: if use_plain_file: - return [] + span_repl_dict = {} + else: + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] + inserted_string_pairs = [ + (span, ( + "{{" + self.get_color_command_str(label + 1), + "}}" + )) + for label, span in enumerate(extended_label_span_list) + ] + span_repl_dict = self.generate_span_repl_dict( + inserted_string_pairs, + self.command_repl_items + ) + result = self.get_replaced_substr(self.full_span, span_repl_dict) - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] - return [ - (span, ( - self.get_begin_color_command_str(label), - self.get_end_color_command_str() - )) - for label, span in enumerate(extended_label_span_list) - ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: + if self.tex_environment: + result = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + result, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + result = "\n".join([self.alignment, result]) if use_plain_file: - return [] - return self.command_repl_items.copy() + result = "\n".join([ + self.get_color_command_str(self.hex_to_int(self.base_color)), + result + ]) + return result @property - def has_predefined_colors(self) -> bool: + def has_predefined_local_colors(self) -> bool: return bool(self.command_repl_items) # Post-parsing diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 8dbd05cc..c3c3be19 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -254,27 +254,6 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @staticmethod - def get_begin_tag_str(attr_dict: dict[str, str]) -> str: - return f"" - - @staticmethod - def get_end_tag_str() -> str: - return "" - - @staticmethod - def rgb_int_to_hex(rgb_int: int) -> str: - return "#{:06x}".format(rgb_int).upper() - - @staticmethod - def get_begin_color_command_str(rgb_int: int): - color_hex = MarkupText.rgb_int_to_hex(rgb_int) - return MarkupText.get_begin_tag_str({"foreground": color_hex}) - - @staticmethod - def get_end_color_command_str() -> str: - return MarkupText.get_end_tag_str() - @staticmethod def merge_attr_dicts( attr_dict_items: list[Span, str, typing.Any] @@ -452,6 +431,14 @@ class MarkupText(LabelledString): ] return result + def get_extra_entity_spans(self) -> list[Span]: + if not self.is_markup: + return [] + return self.find_spans(r"&.*?;") + + def get_extra_ignored_spans(self) -> list[int]: + return [] + def get_internal_specified_spans(self) -> list[Span]: return [span for span, _ in self.local_dicts_from_markup] @@ -464,13 +451,10 @@ class MarkupText(LabelledString): self.find_spans(r"\b"), self.specified_spans )))) - entity_spans = self.command_spans.copy() - if self.is_markup: - entity_spans += self.find_spans(r"&.*?;") breakup_indices = sorted(filter( lambda index: not any([ span[0] < index < span[1] - for span in entity_spans + for span in self.entity_spans ]), breakup_indices )) @@ -479,40 +463,45 @@ class MarkupText(LabelledString): self.get_neighbouring_pairs(breakup_indices) )) - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - if not use_plain_file: + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: attr_dict_items = [ - (span, { - key: WHITE if key in COLOR_RELATED_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in self.predefined_attr_dicts - ] + [ - (span, {"foreground": self.rgb_int_to_hex(label)}) - for label, span in enumerate(self.label_span_list) + (self.full_span, {"foreground": self.base_color}), + *self.predefined_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] ] else: - attr_dict_items = self.predefined_attr_dicts + [ - (span, {}) - for span in self.label_span_list + attr_dict_items = [ + (self.full_span, {"foreground": BLACK}), + *[ + (span, { + key: BLACK if key in COLOR_RELATED_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.predefined_attr_dicts + ], + *[ + (span, {"foreground": self.int_to_hex(label + 1)}) + for label, span in enumerate(self.label_span_list) + ] ] - return [ + inserted_string_pairs = [ (span, ( - self.get_begin_tag_str(attr_dict), - self.get_end_tag_str() + f"", + "" )) for span, attr_dict in self.merge_attr_dicts(attr_dict_items) ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - return self.command_repl_items.copy() + span_repl_dict = self.generate_span_repl_dict( + inserted_string_pairs, self.command_repl_items + ) + return self.get_replaced_substr(self.full_span, span_repl_dict) @property - def has_predefined_colors(self) -> bool: + def has_predefined_local_colors(self) -> bool: return any([ key in COLOR_RELATED_KEYS for _, attr_dict in self.predefined_attr_dicts