Merge pull request #1785 from YishiMichael/master

Fix bug when handling multi-line tex
This commit is contained in:
Grant Sanderson
2022-04-11 09:52:51 -07:00
committed by GitHub
5 changed files with 263 additions and 288 deletions

View File

@ -214,7 +214,7 @@ class AddTextWordByWord(ShowIncreasingSubsets):
def __init__(self, string_mobject, **kwargs): def __init__(self, string_mobject, **kwargs):
assert isinstance(string_mobject, LabelledString) assert isinstance(string_mobject, LabelledString)
grouped_mobject = string_mobject.get_submob_groups() grouped_mobject = string_mobject.submob_groups
digest_config(self, kwargs) digest_config(self, kwargs)
if self.run_time is None: if self.run_time is None:
self.run_time = self.time_per_word * len(grouped_mobject) self.run_time = self.time_per_word * len(grouped_mobject)

View File

@ -160,77 +160,67 @@ class TransformMatchingStrings(AnimationGroup):
} }
def __init__(self, def __init__(self,
source_mobject: LabelledString, source: LabelledString,
target_mobject: LabelledString, target: LabelledString,
**kwargs **kwargs
): ):
digest_config(self, kwargs) digest_config(self, kwargs)
assert isinstance(source_mobject, LabelledString) assert isinstance(source, LabelledString)
assert isinstance(target_mobject, LabelledString) assert isinstance(target, LabelledString)
anims = [] anims = []
rest_source_indices = list(range(len(source_mobject.submobjects))) source_indices = list(range(len(source.labelled_submobjects)))
rest_target_indices = list(range(len(target_mobject.submobjects))) target_indices = list(range(len(target.labelled_submobjects)))
def get_indices_lists(mobject, parts):
return [
[
mobject.labelled_submobjects.index(submob)
for submob in part
]
for part in parts
]
def add_anims_from(anim_class, func, source_args, target_args=None): def add_anims_from(anim_class, func, source_args, target_args=None):
if target_args is None: if target_args is None:
target_args = source_args.copy() target_args = source_args.copy()
for source_arg, target_arg in zip(source_args, target_args): for source_arg, target_arg in zip(source_args, target_args):
source_parts = func(source_mobject, source_arg) source_parts = func(source, source_arg)
target_parts = func(target_mobject, target_arg) target_parts = func(target, target_arg)
source_indices_lists = source_mobject.indices_lists_of_parts( source_indices_lists = list(filter(
source_parts
)
target_indices_lists = target_mobject.indices_lists_of_parts(
target_parts
)
filtered_source_indices_lists = list(filter(
lambda indices_list: all([ lambda indices_list: all([
index in rest_source_indices index in source_indices
for index in indices_list for index in indices_list
]), source_indices_lists ]), get_indices_lists(source, source_parts)
)) ))
filtered_target_indices_lists = list(filter( target_indices_lists = list(filter(
lambda indices_list: all([ lambda indices_list: all([
index in rest_target_indices index in target_indices
for index in indices_list for index in indices_list
]), target_indices_lists ]), get_indices_lists(target, target_parts)
)) ))
if not all([ if not source_indices_lists or not target_indices_lists:
filtered_source_indices_lists,
filtered_target_indices_lists
]):
continue continue
anims.append(anim_class(source_parts, target_parts, **kwargs)) anims.append(anim_class(source_parts, target_parts, **kwargs))
for index in it.chain(*filtered_source_indices_lists): for index in it.chain(*source_indices_lists):
rest_source_indices.remove(index) source_indices.remove(index)
for index in it.chain(*filtered_target_indices_lists): for index in it.chain(*target_indices_lists):
rest_target_indices.remove(index) target_indices.remove(index)
def get_common_substrs(func): def get_common_substrs(substrs_from_source, substrs_from_target):
return sorted([ return sorted([
substr for substr in func(source_mobject) substr for substr in substrs_from_source
if substr and substr in func(target_mobject) if substr and substr in substrs_from_target
], key=len, reverse=True) ], key=len, reverse=True)
def get_parts_from_keys(mobject, keys): def get_parts_from_keys(mobject, keys):
if not isinstance(keys, tuple): if isinstance(keys, str):
keys = (keys,) keys = [keys]
indices = [] result = VGroup()
for key in keys: for key in keys:
if isinstance(key, int): if not isinstance(key, str):
indices.append(key)
elif isinstance(key, range):
indices.extend(key)
elif isinstance(key, str):
all_parts = mobject.get_parts_by_string(key)
indices.extend(it.chain(*[
mobject.indices_of_part(part) for part in all_parts
]))
else:
raise TypeError(key) raise TypeError(key)
return VGroup(VGroup(*[ result.add(*mobject.get_parts_by_string(key))
mobject[index] for index in remove_list_redundancies(indices) return result
]))
add_anims_from( add_anims_from(
ReplacementTransform, get_parts_from_keys, ReplacementTransform, get_parts_from_keys,
@ -239,38 +229,32 @@ class TransformMatchingStrings(AnimationGroup):
add_anims_from( add_anims_from(
FadeTransformPieces, FadeTransformPieces,
LabelledString.get_parts_by_string, LabelledString.get_parts_by_string,
get_common_substrs(LabelledString.get_specified_substrs) get_common_substrs(
source.specified_substrs,
target.specified_substrs
)
) )
add_anims_from( add_anims_from(
FadeTransformPieces, FadeTransformPieces,
LabelledString.get_parts_by_group_substr, LabelledString.get_parts_by_group_substr,
get_common_substrs(LabelledString.get_group_substrs) get_common_substrs(
source.group_substrs,
target.group_substrs
)
) )
fade_source = VGroup(*[ rest_source = VGroup(*[source[index] for index in source_indices])
source_mobject[index] rest_target = VGroup(*[target[index] for index in target_indices])
for index in rest_source_indices
])
fade_target = VGroup(*[
target_mobject[index]
for index in rest_target_indices
])
if self.transform_mismatches: if self.transform_mismatches:
anims.append(ReplacementTransform( anims.append(
fade_source, ReplacementTransform(rest_source, rest_target, **kwargs)
fade_target, )
**kwargs
))
else: else:
anims.append(FadeOutToPoint( anims.append(
fade_source, FadeOutToPoint(rest_source, target.get_center(), **kwargs)
target_mobject.get_center(), )
**kwargs anims.append(
)) FadeInFromPoint(rest_target, source.get_center(), **kwargs)
anims.append(FadeInFromPoint( )
fade_target,
source_mobject.get_center(),
**kwargs
))
super().__init__(*anims) super().__init__(*anims)

View File

@ -4,12 +4,14 @@ import re
import colour import colour
import itertools as it import itertools as it
from typing import Iterable, Union, Sequence from typing import Iterable, Union, Sequence
from abc import abstractmethod from abc import ABC, abstractmethod
from manimlib.constants import WHITE from manimlib.constants import BLACK, WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.iterables import remove_list_redundancies
@ -34,35 +36,39 @@ class _StringSVG(SVGMobject):
} }
class LabelledString(_StringSVG): class LabelledString(_StringSVG, ABC):
""" """
An abstract base class for `MTex` and `MarkupText` An abstract base class for `MTex` and `MarkupText`
""" """
CONFIG = { CONFIG = {
"base_color": None, "base_color": WHITE,
"use_plain_file": False, "use_plain_file": False,
"isolate": [], "isolate": [],
} }
def __init__(self, string: str, **kwargs): def __init__(self, string: str, **kwargs):
self.string = string self.string = string
reserved_svg_default = kwargs.pop("svg_default", {})
digest_config(self, kwargs) digest_config(self, kwargs)
self.reserved_svg_default = reserved_svg_default
self.base_color = self.base_color \ # Convert `base_color` to hex code.
or reserved_svg_default.get("color", None) \ self.base_color = rgb_to_hex(color_to_rgb(
or reserved_svg_default.get("fill_color", None) \ self.base_color \
or self.svg_default.get("color", None) \
or self.svg_default.get("fill_color", None) \
or WHITE or WHITE
))
self.svg_default["fill_color"] = BLACK
self.pre_parse() self.pre_parse()
self.parse() self.parse()
super().__init__(**kwargs) super().__init__()
self.post_parse()
def get_file_path(self) -> str: def get_file_path(self) -> str:
return self.get_file_path_(use_plain_file=False) return self.get_file_path_(use_plain_file=False)
def get_file_path_(self, use_plain_file: bool) -> str: def get_file_path_(self, use_plain_file: bool) -> str:
content = self.get_decorated_string(use_plain_file=use_plain_file) content = self.get_content(use_plain_file)
return self.get_file_path_by_content(content) return self.get_file_path_by_content(content)
@abstractmethod @abstractmethod
@ -76,15 +82,11 @@ class LabelledString(_StringSVG):
self.color_to_label(submob.get_fill_color()) self.color_to_label(submob.get_fill_color())
for submob in self.submobjects for submob in self.submobjects
] ]
if any([ if self.use_plain_file or self.has_predefined_local_colors:
self.use_plain_file,
self.reserved_svg_default,
self.has_predefined_colors
]):
file_path = self.get_file_path_(use_plain_file=True) file_path = self.get_file_path_(use_plain_file=True)
plain_svg = _StringSVG( plain_svg = _StringSVG(
file_path, file_path,
svg_default=self.reserved_svg_default, svg_default=self.svg_default,
path_string_config=self.path_string_config path_string_config=self.path_string_config
) )
self.set_submobjects(plain_svg.submobjects) self.set_submobjects(plain_svg.submobjects)
@ -100,7 +102,9 @@ class LabelledString(_StringSVG):
def parse(self) -> None: def parse(self) -> None:
self.command_repl_items = self.get_command_repl_items() self.command_repl_items = self.get_command_repl_items()
self.command_spans = self.get_command_spans() self.command_spans = self.get_command_spans()
self.ignored_spans = self.get_ignored_spans() self.extra_entity_spans = self.get_extra_entity_spans()
self.entity_spans = self.get_entity_spans()
self.extra_ignored_spans = self.get_extra_ignored_spans()
self.skipped_spans = self.get_skipped_spans() self.skipped_spans = self.get_skipped_spans()
self.internal_specified_spans = self.get_internal_specified_spans() self.internal_specified_spans = self.get_internal_specified_spans()
self.external_specified_spans = self.get_external_specified_spans() self.external_specified_spans = self.get_external_specified_spans()
@ -108,6 +112,20 @@ class LabelledString(_StringSVG):
self.label_span_list = self.get_label_span_list() self.label_span_list = self.get_label_span_list()
self.check_overlapping() self.check_overlapping()
def post_parse(self) -> None:
self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
self.labelled_submobjects = self.get_labelled_submobjects()
self.specified_substrs = self.get_specified_substrs()
self.group_items = self.get_group_items()
self.group_substrs = self.get_group_substrs()
self.submob_groups = self.get_submob_groups()
def copy(self):
return self.deepcopy()
# Toolkits # Toolkits
def get_substr(self, span: Span) -> str: def get_substr(self, span: Span) -> str:
@ -118,10 +136,14 @@ class LabelledString(_StringSVG):
) -> Iterable[re.Match]: ) -> Iterable[re.Match]:
return re.compile(pattern, flags).finditer(self.string, **kwargs) return re.compile(pattern, flags).finditer(self.string, **kwargs)
def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: def search(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).search(self.string, **kwargs) return re.compile(pattern, flags).search(self.string, **kwargs)
def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: def match(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).match(self.string, **kwargs) return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str, **kwargs) -> list[Span]: def find_spans(self, pattern: str, **kwargs) -> list[Span]:
@ -197,7 +219,7 @@ class LabelledString(_StringSVG):
return sorted_seq[index + index_shift] return sorted_seq[index + index_shift]
@staticmethod @staticmethod
def get_span_replacement_dict( def generate_span_repl_dict(
inserted_string_pairs: list[tuple[Span, tuple[str, str]]], inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
other_repl_items: list[tuple[Span, str]] other_repl_items: list[tuple[Span, str]]
) -> dict[Span, str]: ) -> dict[Span, str]:
@ -271,21 +293,19 @@ class LabelledString(_StringSVG):
r, g = divmod(rg, 256) r, g = divmod(rg, 256)
return r, g, b return r, g, b
@staticmethod
def int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod @staticmethod
def color_to_label(color: ManimColor) -> int: def color_to_label(color: ManimColor) -> int:
rgb_tuple = color_to_int_rgb(color) rgb_tuple = color_to_int_rgb(color)
rgb = LabelledString.rgb_to_int(rgb_tuple) rgb = LabelledString.rgb_to_int(rgb_tuple)
if rgb == 16777215: # white return rgb - 1
return -1
return rgb
@abstractmethod
def get_begin_color_command_str(int_rgb: int) -> str:
return ""
@abstractmethod
def get_end_color_command_str() -> str:
return ""
# Parsing # Parsing
@ -296,14 +316,25 @@ class LabelledString(_StringSVG):
def get_command_spans(self) -> list[Span]: def get_command_spans(self) -> list[Span]:
return [cmd_span for cmd_span, _ in self.command_repl_items] return [cmd_span for cmd_span, _ in self.command_repl_items]
def get_ignored_spans(self) -> list[int]: @abstractmethod
def get_extra_entity_spans(self) -> list[Span]:
return []
def get_entity_spans(self) -> list[Span]:
return list(it.chain(
self.command_spans,
self.extra_entity_spans
))
@abstractmethod
def get_extra_ignored_spans(self) -> list[int]:
return [] return []
def get_skipped_spans(self) -> list[Span]: def get_skipped_spans(self) -> list[Span]:
return list(it.chain( return list(it.chain(
self.find_spans(r"\s"), self.find_spans(r"\s"),
self.command_spans, self.command_spans,
self.ignored_spans self.extra_ignored_spans
)) ))
def shrink_span(self, span: Span) -> Span: def shrink_span(self, span: Span) -> Span:
@ -321,14 +352,17 @@ class LabelledString(_StringSVG):
return [] return []
def get_specified_spans(self) -> list[Span]: def get_specified_spans(self) -> list[Span]:
spans = [ spans = list(it.chain(
self.full_span, self.internal_specified_spans,
*self.internal_specified_spans, self.external_specified_spans,
*self.external_specified_spans, self.find_substrs(self.isolate)
*self.find_substrs(self.isolate) ))
]
shrinked_spans = list(filter( shrinked_spans = list(filter(
lambda span: span[0] < span[1], lambda span: span[0] < span[1] and not any([
entity_span[0] < index < entity_span[1]
for index in span
for entity_span in self.entity_spans
]),
[self.shrink_span(span) for span in spans] [self.shrink_span(span) for span in spans]
)) ))
return remove_list_redundancies(shrinked_spans) return remove_list_redundancies(shrinked_spans)
@ -347,40 +381,18 @@ class LabelledString(_StringSVG):
) )
@abstractmethod @abstractmethod
def get_inserted_string_pairs( def get_content(self, use_plain_file: bool) -> str:
self, use_plain_file: bool return ""
) -> list[tuple[Span, tuple[str, str]]]:
return []
@abstractmethod @abstractmethod
def get_other_repl_items( def has_predefined_local_colors(self) -> bool:
self, use_plain_file: bool
) -> list[tuple[Span, str]]:
return []
def get_decorated_string(self, use_plain_file: bool) -> str:
span_repl_dict = self.get_span_replacement_dict(
self.get_inserted_string_pairs(use_plain_file),
self.get_other_repl_items(use_plain_file)
)
result = self.get_replaced_substr(self.full_span, span_repl_dict)
if not use_plain_file:
return result
return "".join([
self.get_begin_color_command_str(
self.rgb_to_int(color_to_int_rgb(self.base_color))
),
result,
self.get_end_color_command_str()
])
@abstractmethod
def has_predefined_colors(self) -> bool:
return False return False
# Post-parsing # Post-parsing
def get_labelled_submobjects(self) -> list[VMobject]:
return [submob for _, submob in self.labelled_submobject_items]
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
span_repl_dict = dict.fromkeys(self.command_spans, "") span_repl_dict = dict.fromkeys(self.command_spans, "")
return self.get_replaced_substr(span, span_repl_dict) return self.get_replaced_substr(span, span_repl_dict)
@ -391,17 +403,14 @@ class LabelledString(_StringSVG):
for span in self.specified_spans for span in self.specified_spans
]) ])
def get_group_span_items(self) -> tuple[list[int], list[Span]]: def get_group_items(self) -> list[tuple[str, VGroup]]:
submob_labels = [submob.label for submob in self.submobjects] if not self.labelled_submobject_items:
if not submob_labels:
return [], []
return tuple(zip(*self.compress_neighbours(submob_labels)))
def get_group_substrs(self) -> list[str]:
group_labels, _ = self.get_group_span_items()
if not group_labels:
return [] return []
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
group_labels, labelled_submob_spans = zip(
*self.compress_neighbours(labels)
)
ordered_spans = [ ordered_spans = [
self.label_span_list[label] if label != -1 else self.full_span self.label_span_list[label] if label != -1 else self.full_span
for label in group_labels for label in group_labels
@ -425,16 +434,27 @@ class LabelledString(_StringSVG):
interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) interval_spans, (ordered_spans[0][0], ordered_spans[-1][1])
) )
] ]
return [ group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else "" self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in shrinked_spans for span in shrinked_spans
] ]
submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)])
for submob_span in labelled_submob_spans
])
return list(zip(group_substrs, submob_groups))
def get_submob_groups(self) -> VGroup: def get_group_substrs(self) -> list[str]:
_, submob_spans = self.get_group_span_items() return [group_substr for group_substr, _ in self.group_items]
def get_submob_groups(self) -> list[VGroup]:
return [submob_group for _, submob_group in self.group_items]
def get_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[ return VGroup(*[
VGroup(*self.submobjects[slice(*submob_span)]) group
for submob_span in submob_spans for group_substr, group in self.group_items
if group_substr == substr
]) ])
# Selector # Selector
@ -477,7 +497,7 @@ class LabelledString(_StringSVG):
span_begin = next_begin span_begin = next_begin
return result return result
def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [ labels = [
label for label, span in enumerate(self.label_span_list) label for label, span in enumerate(self.label_span_list)
if any([ if any([
@ -487,29 +507,23 @@ class LabelledString(_StringSVG):
) )
]) ])
] ]
return VGroup(*filter( return VGroup(*[
lambda submob: submob.label in labels, submob for label, submob in self.labelled_submobject_items
self.submobjects if label in labels
)) ])
def get_parts_by_string( def get_parts_by_string(
self, substr: str, case_sensitive: bool = True, **kwargs self, substr: str,
case_sensitive: bool = True, regex: bool = False, **kwargs
) -> VGroup: ) -> VGroup:
flags = 0 flags = 0
if not case_sensitive: if not case_sensitive:
flags |= re.I flags |= re.I
pattern = substr if regex else re.escape(substr)
return VGroup(*[ return VGroup(*[
self.get_parts_by_custom_span(span, **kwargs) self.get_part_by_custom_span(span, **kwargs)
for span in self.find_substr(substr, flags=flags) for span in self.find_spans(pattern, flags=flags)
]) if span[0] < span[1]
def get_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[
group
for group, group_substr in zip(
self.get_submob_groups(), self.get_group_substrs()
)
if group_substr == substr
]) ])
def get_part_by_string( def get_part_by_string(
@ -528,13 +542,5 @@ class LabelledString(_StringSVG):
self.set_color_by_string(substr, color, **kwargs) self.set_color_by_string(substr, color, **kwargs)
return self return self
def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:
return [self.submobjects.index(submob) for submob in part]
def indices_lists_of_parts(
self, parts: Iterable[Iterable[VMobject]]
) -> list[list[int]]:
return [self.indices_of_part(part) for part in parts]
def get_string(self) -> str: def get_string(self) -> str:
return self.string return self.string

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import itertools as it
import colour import colour
from typing import Union, Sequence from typing import Union, Sequence
@ -32,7 +33,7 @@ class MTex(LabelledString):
def __init__(self, tex_string: str, **kwargs): def __init__(self, tex_string: str, **kwargs):
# Prevent from passing an empty string. # Prevent from passing an empty string.
if not tex_string: if not tex_string:
tex_string = "\\quad" tex_string = "\\\\"
self.tex_string = tex_string self.tex_string = tex_string
super().__init__(tex_string, **kwargs) super().__init__(tex_string, **kwargs)
@ -55,30 +56,14 @@ class MTex(LabelledString):
) )
def get_file_path_by_content(self, content: str) -> str: def get_file_path_by_content(self, content: str) -> str:
full_tex = self.get_tex_file_body(content)
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = self.tex_to_svg_file_path(full_tex)
return file_path
def get_tex_file_body(self, content: str) -> str:
if self.tex_environment:
content = "\n".join([
f"\\begin{{{self.tex_environment}}}",
content,
f"\\end{{{self.tex_environment}}}"
])
if self.alignment:
content = "\n".join([self.alignment, content])
tex_config = get_tex_config() tex_config = get_tex_config()
return tex_config["tex_body"].replace( full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"], tex_config["text_to_replace"],
content content
) )
with display_during_execution(f"Writing \"{self.tex_string}\""):
@staticmethod file_path = tex_to_svg_file(full_tex)
def tex_to_svg_file_path(tex_file_content: str) -> str: return file_path
return tex_to_svg_file(tex_file_content)
def pre_parse(self) -> None: def pre_parse(self) -> None:
super().pre_parse() super().pre_parse()
@ -91,29 +76,23 @@ class MTex(LabelledString):
# Toolkits # Toolkits
@staticmethod @staticmethod
def get_begin_color_command_str(rgb_int: int) -> str: def get_color_command_str(rgb_int: int) -> str:
rgb_tuple = MTex.int_to_rgb(rgb_int) rgb_tuple = MTex.int_to_rgb(rgb_int)
return "".join([ return "".join([
"{{",
"\\color[RGB]", "\\color[RGB]",
"{", "{",
",".join(map(str, rgb_tuple)), ",".join(map(str, rgb_tuple)),
"}" "}"
]) ])
@staticmethod
def get_end_color_command_str() -> str:
return "}}"
# Pre-parsing # Pre-parsing
def get_backslash_indices(self) -> list[int]: def get_backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count. # The latter of `\\` doesn't count.
return [ return list(it.chain(*[
span[1] - 1 range(span[0], span[1], 2)
for span in self.find_spans(r"\\+") for span in self.find_spans(r"\\+")
if (span[1] - span[0]) % 2 == 1 ]))
]
def get_unescaped_char_spans(self, chars: str): def get_unescaped_char_spans(self, chars: str):
return sorted(filter( return sorted(filter(
@ -209,13 +188,19 @@ class MTex(LabelledString):
right_brace_indices, cmd_end, n_braces right_brace_indices, cmd_end, n_braces
) + 1 ) + 1
if substitute_cmd: if substitute_cmd:
repl_str = "\\" + cmd_name + n_braces * "{white}" repl_str = "\\" + cmd_name + n_braces * "{black}"
else: else:
repl_str = "" repl_str = ""
result.append(((span_begin, span_end), repl_str)) result.append(((span_begin, span_end), repl_str))
return result return result
def get_ignored_spans(self) -> list[int]: def get_extra_entity_spans(self) -> list[Span]:
return [
self.match(r"\\([a-zA-Z]+|.)", pos=index).span()
for index in self.backslash_indices
]
def get_extra_ignored_spans(self) -> list[int]:
return self.script_char_spans.copy() return self.script_char_spans.copy()
def get_internal_specified_spans(self) -> list[Span]: def get_internal_specified_spans(self) -> list[Span]:
@ -256,35 +241,46 @@ class MTex(LabelledString):
result.append(shrinked_span) result.append(shrinked_span)
return result return result
def get_inserted_string_pairs( def get_content(self, use_plain_file: bool) -> str:
self, use_plain_file: bool
) -> list[tuple[Span, tuple[str, str]]]:
if use_plain_file: if use_plain_file:
return [] span_repl_dict = {}
else:
extended_label_span_list = [
span
if span in self.script_content_spans
else (span[0], self.rslide(span[1], self.script_spans))
for span in self.label_span_list
]
inserted_string_pairs = [
(span, (
"{{" + self.get_color_command_str(label + 1),
"}}"
))
for label, span in enumerate(extended_label_span_list)
]
span_repl_dict = self.generate_span_repl_dict(
inserted_string_pairs,
self.command_repl_items
)
result = self.get_replaced_substr(self.full_span, span_repl_dict)
extended_label_span_list = [ if self.tex_environment:
span result = "\n".join([
if span in self.script_content_spans f"\\begin{{{self.tex_environment}}}",
else (span[0], self.rslide(span[1], self.script_spans)) result,
for span in self.label_span_list f"\\end{{{self.tex_environment}}}"
] ])
return [ if self.alignment:
(span, ( result = "\n".join([self.alignment, result])
self.get_begin_color_command_str(label),
self.get_end_color_command_str()
))
for label, span in enumerate(extended_label_span_list)
]
def get_other_repl_items(
self, use_plain_file: bool
) -> list[tuple[Span, str]]:
if use_plain_file: if use_plain_file:
return [] result = "\n".join([
return self.command_repl_items.copy() self.get_color_command_str(self.hex_to_int(self.base_color)),
result
])
return result
@property @property
def has_predefined_colors(self) -> bool: def has_predefined_local_colors(self) -> bool:
return bool(self.command_repl_items) return bool(self.command_repl_items)
# Post-parsing # Post-parsing

View File

@ -254,27 +254,6 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items() for key, val in attr_dict.items()
]) ])
@staticmethod
def get_begin_tag_str(attr_dict: dict[str, str]) -> str:
return f"<span {MarkupText.get_attr_dict_str(attr_dict)}>"
@staticmethod
def get_end_tag_str() -> str:
return "</span>"
@staticmethod
def rgb_int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def get_begin_color_command_str(rgb_int: int):
color_hex = MarkupText.rgb_int_to_hex(rgb_int)
return MarkupText.get_begin_tag_str({"foreground": color_hex})
@staticmethod
def get_end_color_command_str() -> str:
return MarkupText.get_end_tag_str()
@staticmethod @staticmethod
def merge_attr_dicts( def merge_attr_dicts(
attr_dict_items: list[Span, str, typing.Any] attr_dict_items: list[Span, str, typing.Any]
@ -452,6 +431,14 @@ class MarkupText(LabelledString):
] ]
return result return result
def get_extra_entity_spans(self) -> list[Span]:
if not self.is_markup:
return []
return self.find_spans(r"&.*?;")
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_internal_specified_spans(self) -> list[Span]: def get_internal_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_markup] return [span for span, _ in self.local_dicts_from_markup]
@ -464,13 +451,10 @@ class MarkupText(LabelledString):
self.find_spans(r"\b"), self.find_spans(r"\b"),
self.specified_spans self.specified_spans
)))) ))))
entity_spans = self.command_spans.copy()
if self.is_markup:
entity_spans += self.find_spans(r"&.*?;")
breakup_indices = sorted(filter( breakup_indices = sorted(filter(
lambda index: not any([ lambda index: not any([
span[0] < index < span[1] span[0] < index < span[1]
for span in entity_spans for span in self.entity_spans
]), ]),
breakup_indices breakup_indices
)) ))
@ -479,40 +463,45 @@ class MarkupText(LabelledString):
self.get_neighbouring_pairs(breakup_indices) self.get_neighbouring_pairs(breakup_indices)
)) ))
def get_inserted_string_pairs( def get_content(self, use_plain_file: bool) -> str:
self, use_plain_file: bool if use_plain_file:
) -> list[tuple[Span, tuple[str, str]]]:
if not use_plain_file:
attr_dict_items = [ attr_dict_items = [
(span, { (self.full_span, {"foreground": self.base_color}),
key: WHITE if key in COLOR_RELATED_KEYS else val *self.predefined_attr_dicts,
for key, val in attr_dict.items() *[
}) (span, {})
for span, attr_dict in self.predefined_attr_dicts for span in self.label_span_list
] + [ ]
(span, {"foreground": self.rgb_int_to_hex(label)})
for label, span in enumerate(self.label_span_list)
] ]
else: else:
attr_dict_items = self.predefined_attr_dicts + [ attr_dict_items = [
(span, {}) (self.full_span, {"foreground": BLACK}),
for span in self.label_span_list *[
(span, {
key: BLACK if key in COLOR_RELATED_KEYS else val
for key, val in attr_dict.items()
})
for span, attr_dict in self.predefined_attr_dicts
],
*[
(span, {"foreground": self.int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list)
]
] ]
return [ inserted_string_pairs = [
(span, ( (span, (
self.get_begin_tag_str(attr_dict), f"<span {self.get_attr_dict_str(attr_dict)}>",
self.get_end_tag_str() "</span>"
)) ))
for span, attr_dict in self.merge_attr_dicts(attr_dict_items) for span, attr_dict in self.merge_attr_dicts(attr_dict_items)
] ]
span_repl_dict = self.generate_span_repl_dict(
def get_other_repl_items( inserted_string_pairs, self.command_repl_items
self, use_plain_file: bool )
) -> list[tuple[Span, str]]: return self.get_replaced_substr(self.full_span, span_repl_dict)
return self.command_repl_items.copy()
@property @property
def has_predefined_colors(self) -> bool: def has_predefined_local_colors(self) -> bool:
return any([ return any([
key in COLOR_RELATED_KEYS key in COLOR_RELATED_KEYS
for _, attr_dict in self.predefined_attr_dicts for _, attr_dict in self.predefined_attr_dicts