Refactor LabelledString and relevant classes

This commit is contained in:
YishiMichael
2022-04-18 18:47:57 +08:00
parent 0e0244128c
commit cbb7e69f68
6 changed files with 100 additions and 135 deletions

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np
from manimlib.constants import DOWN, LEFT, RIGHT, UP from manimlib.constants import DOWN, LEFT, RIGHT, UP
from manimlib.constants import GREY_B from manimlib.constants import GREY_B
from manimlib.constants import MED_SMALL_BUFF from manimlib.constants import MED_SMALL_BUFF

View File

@ -53,11 +53,16 @@ class LabelledString(SVGMobject, ABC):
digest_config(self, kwargs) digest_config(self, kwargs)
if self.base_color is None: if self.base_color is None:
self.base_color = WHITE self.base_color = WHITE
self.base_color_int = self.color_to_int(self.base_color)
self.pre_parse() self.string_len = len(self.string)
self.full_span = (0, self.string_len)
self.parse() self.parse()
super().__init__() super().__init__()
self.post_parse() self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
def get_file_path(self) -> str: def get_file_path(self) -> str:
return self.get_file_path_(is_labelled=False) return self.get_file_path_(is_labelled=False)
@ -85,7 +90,6 @@ class LabelledString(SVGMobject, ABC):
submob_color_ints = [0] * len(self.submobjects) submob_color_ints = [0] * len(self.submobjects)
if len(self.submobjects) != len(submob_color_ints): if len(self.submobjects) != len(submob_color_ints):
print(len(self.submobjects), len(submob_color_ints))
raise ValueError( raise ValueError(
"Cannot align submobjects of the labelled svg " "Cannot align submobjects of the labelled svg "
"to the original svg" "to the original svg"
@ -104,11 +108,6 @@ class LabelledString(SVGMobject, ABC):
for submob, color_int in zip(self.submobjects, submob_color_ints): for submob, color_int in zip(self.submobjects, submob_color_ints):
submob.label = color_int - 1 submob.label = color_int - 1
def pre_parse(self) -> None:
self.string_len = len(self.string)
self.full_span = (0, self.string_len)
self.base_color_int = self.color_to_int(self.base_color)
def parse(self) -> None: def parse(self) -> None:
self.skippable_indices = self.get_skippable_indices() self.skippable_indices = self.get_skippable_indices()
self.entity_spans = self.get_entity_spans() self.entity_spans = self.get_entity_spans()
@ -121,12 +120,6 @@ class LabelledString(SVGMobject, ABC):
if len(self.label_span_list) >= 16777216: if len(self.label_span_list) >= 16777216:
raise ValueError("Cannot handle that many substrings") raise ValueError("Cannot handle that many substrings")
def post_parse(self) -> None:
self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
def copy(self): def copy(self):
return self.deepcopy() return self.deepcopy()
@ -362,7 +355,7 @@ class LabelledString(SVGMobject, ABC):
def get_content(self, is_labelled: bool) -> str: def get_content(self, is_labelled: bool) -> str:
return "" return ""
# Post-parsing # Selector
@abstractmethod @abstractmethod
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
@ -414,8 +407,6 @@ class LabelledString(SVGMobject, ABC):
for span in self.specified_spans for span in self.specified_spans
] ]
# Selector
def find_span_components( def find_span_components(
self, custom_span: Span, substring: bool = True self, custom_span: Span, substring: bool = True
) -> list[Span]: ) -> list[Span]:

View File

@ -84,8 +84,7 @@ class MTex(LabelledString):
file_path = tex_to_svg_file(full_tex) file_path = tex_to_svg_file(full_tex)
return file_path return file_path
def pre_parse(self) -> None: def parse(self) -> None:
super().pre_parse()
self.backslash_indices = self.get_backslash_indices() self.backslash_indices = self.get_backslash_indices()
self.command_spans = self.get_command_spans() self.command_spans = self.get_command_spans()
self.brace_spans = self.get_brace_spans() self.brace_spans = self.get_brace_spans()
@ -93,6 +92,7 @@ class MTex(LabelledString):
self.script_content_spans = self.get_script_content_spans() self.script_content_spans = self.get_script_content_spans()
self.script_spans = self.get_script_spans() self.script_spans = self.get_script_spans()
self.command_repl_items = self.get_command_repl_items() self.command_repl_items = self.get_command_repl_items()
super().parse()
# Toolkits # Toolkits
@ -102,7 +102,7 @@ class MTex(LabelledString):
r, g = divmod(rg, 256) r, g = divmod(rg, 256)
return f"\\color[RGB]{{{r}, {g}, {b}}}" return f"\\color[RGB]{{{r}, {g}, {b}}}"
# Pre-parsing # Parsing
def get_backslash_indices(self) -> list[int]: def get_backslash_indices(self) -> list[int]:
# The latter of `\\` doesn't count. # The latter of `\\` doesn't count.
@ -186,20 +186,18 @@ class MTex(LabelledString):
continue continue
n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name]
span_begin, span_end = cmd_span span_begin, span_end = cmd_span
for _ in n_braces: for _ in range(n_braces):
span_end = brace_spans_dict[min(filter( span_end = brace_spans_dict[min(filter(
lambda index: index >= span_end, lambda index: index >= span_end,
brace_begins brace_begins
))] ))]
if substitute_cmd: if substitute_cmd:
repl_str = "\\" + cmd_name + n_braces * "{black}" repl_str = cmd_name + n_braces * "{black}"
else: else:
repl_str = "" repl_str = ""
result.append(((span_begin, span_end), repl_str)) result.append(((span_begin, span_end), repl_str))
return result return result
# Parsing
def get_skippable_indices(self) -> list[int]: def get_skippable_indices(self) -> list[int]:
return list(it.chain( return list(it.chain(
self.find_indices(r"\s"), self.find_indices(r"\s"),
@ -298,7 +296,7 @@ class MTex(LabelledString):
]) ])
return result return result
# Post-parsing # Selector
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
if not self.brace_spans: if not self.brace_spans:

View File

@ -198,9 +198,9 @@ class SVGMobject(VMobject):
) -> VMobject: ) -> VMobject:
mob.set_style( mob.set_style(
stroke_width=shape.stroke_width, stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex, stroke_color=shape.stroke.hexrgb,
stroke_opacity=shape.stroke.opacity, stroke_opacity=shape.stroke.opacity,
fill_color=shape.fill.hex, fill_color=shape.fill.hexrgb,
fill_opacity=shape.fill.opacity fill_opacity=shape.fill.opacity
) )
return mob return mob

View File

@ -6,13 +6,12 @@ import os
from pathlib import Path from pathlib import Path
import re import re
from manimpango import MarkupUtils import manimpango
import pygments import pygments
import pygments.formatters import pygments.formatters
import pygments.lexers import pygments.lexers
from manimlib.constants import BLACK from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH
from manimlib.constants import DEFAULT_PIXEL_HEIGHT, DEFAULT_PIXEL_WIDTH
from manimlib.constants import NORMAL from manimlib.constants import NORMAL
from manimlib.logger import log from manimlib.logger import log
from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.svg.labelled_string import LabelledString
@ -46,48 +45,15 @@ if TYPE_CHECKING:
TEXT_MOB_SCALE_FACTOR = 0.0076 TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6 DEFAULT_LINE_SPACING_SCALE = 0.6
# Ensure the canvas is large enough to hold all glyphs.
DEFAULT_CANVAS_WIDTH = 16384
DEFAULT_CANVAS_HEIGHT = 16384
# See https://docs.gtk.org/Pango/pango_markup.html # See https://docs.gtk.org/Pango/pango_markup.html
# A tag containing two aliases will cause warning,
# so only use the first key of each group of aliases.
MARKUP_KEY_ALIAS_LIST = (
("font", "font_desc"),
("font_family", "face"),
("font_size", "size"),
("font_style", "style"),
("font_weight", "weight"),
("font_variant", "variant"),
("font_stretch", "stretch"),
("font_features",),
("foreground", "fgcolor", "color"),
("background", "bgcolor"),
("alpha", "fgalpha"),
("background_alpha", "bgalpha"),
("underline",),
("underline_color",),
("overline",),
("overline_color",),
("rise",),
("baseline_shift",),
("font_scale",),
("strikethrough",),
("strikethrough_color",),
("fallback",),
("lang",),
("letter_spacing",),
("gravity",),
("gravity_hint",),
("show",),
("insert_hyphens",),
("allow_breaks",),
("line_height",),
("text_transform",),
("segment",),
)
MARKUP_COLOR_KEYS = ( MARKUP_COLOR_KEYS = (
"foreground", "foreground", "fgcolor", "color",
"background", "background", "bgcolor",
"underline_color", "underline_color",
"overline_color", "overline_color",
"strikethrough_color" "strikethrough_color"
@ -125,7 +91,7 @@ class MarkupText(LabelledString):
"justify": False, "justify": False,
"indent": 0, "indent": 0,
"alignment": "LEFT", "alignment": "LEFT",
"line_width_factor": None, "line_width": None,
"font": "", "font": "",
"slant": NORMAL, "slant": NORMAL,
"weight": NORMAL, "weight": NORMAL,
@ -146,9 +112,7 @@ class MarkupText(LabelledString):
if not self.font: if not self.font:
self.font = get_customization()["style"]["font"] self.font = get_customization()["style"]["font"]
if self.is_markup: if self.is_markup:
validate_error = MarkupUtils.validate(text) self.validate_markup_string(text)
if validate_error:
raise ValueError(validate_error)
self.text = text self.text = text
super().__init__(text, **kwargs) super().__init__(text, **kwargs)
@ -178,7 +142,7 @@ class MarkupText(LabelledString):
self.justify, self.justify,
self.indent, self.indent,
self.alignment, self.alignment,
self.line_width_factor, self.line_width,
self.font, self.font,
self.slant, self.slant,
self.weight, self.weight,
@ -205,23 +169,32 @@ class MarkupText(LabelledString):
kwargs[short_name] = kwargs.pop(long_name) kwargs[short_name] = kwargs.pop(long_name)
def get_file_path_by_content(self, content: str) -> str: def get_file_path_by_content(self, content: str) -> str:
hash_content = str((
content,
self.justify,
self.indent,
self.alignment,
self.line_width
))
svg_file = os.path.join( svg_file = os.path.join(
get_text_dir(), tex_hash(content) + ".svg" get_text_dir(), tex_hash(hash_content) + ".svg"
) )
if not os.path.exists(svg_file): if not os.path.exists(svg_file):
self.markup_to_svg(content, svg_file) self.markup_to_svg(content, svg_file)
return svg_file return svg_file
def markup_to_svg(self, markup_str: str, file_name: str) -> str: def markup_to_svg(self, markup_str: str, file_name: str) -> str:
self.validate_markup_string(markup_str)
# `manimpango` is under construction, # `manimpango` is under construction,
# so the following code is intended to suit its interface # so the following code is intended to suit its interface
alignment = _Alignment(self.alignment) alignment = _Alignment(self.alignment)
if self.line_width_factor is None: if self.line_width is None:
pango_width = -1 pango_width = -1
else: else:
pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH
return MarkupUtils.text2svg( return manimpango.MarkupUtils.text2svg(
text=markup_str, text=markup_str,
font="", # Already handled font="", # Already handled
slant="NORMAL", # Already handled slant="NORMAL", # Already handled
@ -232,8 +205,8 @@ class MarkupText(LabelledString):
file_name=file_name, file_name=file_name,
START_X=0, START_X=0,
START_Y=0, START_Y=0,
width=DEFAULT_PIXEL_WIDTH, width=DEFAULT_CANVAS_WIDTH,
height=DEFAULT_PIXEL_HEIGHT, height=DEFAULT_CANVAS_HEIGHT,
justify=self.justify, justify=self.justify,
indent=self.indent, indent=self.indent,
line_spacing=None, # Already handled line_spacing=None, # Already handled
@ -241,11 +214,22 @@ class MarkupText(LabelledString):
pango_width=pango_width pango_width=pango_width
) )
def pre_parse(self) -> None: @staticmethod
super().pre_parse() def validate_markup_string(markup_str: str) -> None:
validate_error = manimpango.MarkupUtils.validate(markup_str)
if not validate_error:
return
raise ValueError(
f"Invalid markup string \"{markup_str}\"\n"
f"{validate_error}"
)
def parse(self) -> None:
self.global_attr_dict = self.get_global_attr_dict()
self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() self.tag_pairs_from_markup = self.get_tag_pairs_from_markup()
self.tag_spans = self.get_tag_spans() self.tag_spans = self.get_tag_spans()
self.items_from_markup = self.get_items_from_markup() self.items_from_markup = self.get_items_from_markup()
super().parse()
# Toolkits # Toolkits
@ -256,7 +240,24 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items() for key, val in attr_dict.items()
]) ])
# Pre-parsing # Parsing
def get_global_attr_dict(self) -> dict[str, str]:
result = {
"font_size": str(self.font_size * 1024),
"foreground": self.int_to_hex(self.base_color_int),
"font_family": self.font,
"font_style": self.slant,
"font_weight": self.weight,
}
# `line_height` attribute is supported since Pango 1.50.
if tuple(map(int, manimpango.pango_version().split("."))) >= (1, 50):
result.update({
"line_height": str((
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6),
})
return result
def get_tag_pairs_from_markup( def get_tag_pairs_from_markup(
self self
@ -264,8 +265,8 @@ class MarkupText(LabelledString):
if not self.is_markup: if not self.is_markup:
return [] return []
tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>"""
attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2"""
begin_match_obj_stack = [] begin_match_obj_stack = []
match_obj_pairs = [] match_obj_pairs = []
for match_obj in re.finditer(tag_pattern, self.string): for match_obj in re.finditer(tag_pattern, self.string):
@ -275,16 +276,10 @@ class MarkupText(LabelledString):
match_obj_pairs.append( match_obj_pairs.append(
(begin_match_obj_stack.pop(), match_obj) (begin_match_obj_stack.pop(), match_obj)
) )
if begin_match_obj_stack:
raise ValueError("Unclosed tag(s) detected")
result = [] result = []
for begin_match_obj, end_match_obj in match_obj_pairs: for begin_match_obj, end_match_obj in match_obj_pairs:
tag_name = begin_match_obj.group(2) tag_name = begin_match_obj.group(2)
if tag_name != end_match_obj.group(2):
raise ValueError("Unmatched tag names")
if end_match_obj.group(3):
raise ValueError("Attributes shan't exist in ending tags")
if tag_name == "span": if tag_name == "span":
attr_dict = { attr_dict = {
match.group(1): match.group(3) match.group(1): match.group(3)
@ -292,14 +287,8 @@ class MarkupText(LabelledString):
attr_pattern, begin_match_obj.group(3) attr_pattern, begin_match_obj.group(3)
) )
} }
elif tag_name in MARKUP_TAG_CONVERSION_DICT.keys():
if begin_match_obj.group(3):
raise ValueError(
f"Attributes shan't exist in tag '{tag_name}'"
)
attr_dict = MARKUP_TAG_CONVERSION_DICT[tag_name].copy()
else: else:
raise ValueError(f"Unknown tag: '{tag_name}'") attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {})
result.append( result.append(
(begin_match_obj.span(), end_match_obj.span(), attr_dict) (begin_match_obj.span(), end_match_obj.span(), attr_dict)
@ -320,8 +309,6 @@ class MarkupText(LabelledString):
in self.tag_pairs_from_markup in self.tag_pairs_from_markup
] ]
# Parsing
def get_skippable_indices(self) -> list[int]: def get_skippable_indices(self) -> list[int]:
return self.find_indices(r"\s") return self.find_indices(r"\s")
@ -335,20 +322,9 @@ class MarkupText(LabelledString):
return [span for span, _ in self.items_from_markup] return [span for span, _ in self.items_from_markup]
def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]:
result = [ return list(it.chain(
(self.full_span, { self.items_from_markup,
"line_height": str(( [
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6),
"font_family": self.font,
"font_size": str(self.font_size * 1024),
"font_style": self.slant,
"font_weight": self.weight,
"foreground": self.int_to_hex(self.base_color_int)
}),
(self.full_span, self.global_config),
*self.items_from_markup,
*[
(span, {key: val}) (span, {key: val})
for t2x_dict, key in ( for t2x_dict, key in (
(self.t2c, "foreground"), (self.t2c, "foreground"),
@ -359,24 +335,12 @@ class MarkupText(LabelledString):
for selector, val in t2x_dict.items() for selector, val in t2x_dict.items()
for span in self.find_spans_by_selector(selector) for span in self.find_spans_by_selector(selector)
], ],
*[ [
(span, local_config) (span, local_config)
for selector, local_config in self.local_configs.items() for selector, local_config in self.local_configs.items()
for span in self.find_spans_by_selector(selector) for span in self.find_spans_by_selector(selector)
] ]
] ))
key_conversion_dict = {
key: key_alias_list[0]
for key_alias_list in MARKUP_KEY_ALIAS_LIST
for key in key_alias_list
}
return [
(span, {
key_conversion_dict[key.lower()]: val
for key, val in attr_dict.items()
})
for span, attr_dict in result
]
def get_label_span_list(self) -> list[Span]: def get_label_span_list(self) -> list[Span]:
interval_spans = sorted(it.chain( interval_spans = sorted(it.chain(
@ -398,14 +362,20 @@ class MarkupText(LabelledString):
])) ]))
def get_content(self, is_labelled: bool) -> str: def get_content(self, is_labelled: bool) -> str:
predefined_items = [
(self.full_span, self.global_attr_dict),
(self.full_span, self.global_config),
*self.specified_items
]
if is_labelled: if is_labelled:
attr_dict_items = list(it.chain( attr_dict_items = list(it.chain(
[ [
(span, { (span, {
key: BLACK if key in MARKUP_COLOR_KEYS else val key:
"black" if key.lower() in MARKUP_COLOR_KEYS else val
for key, val in attr_dict.items() for key, val in attr_dict.items()
}) })
for span, attr_dict in self.specified_items for span, attr_dict in predefined_items
], ],
[ [
(span, {"foreground": self.int_to_hex(label + 1)}) (span, {"foreground": self.int_to_hex(label + 1)})
@ -414,7 +384,7 @@ class MarkupText(LabelledString):
)) ))
else: else:
attr_dict_items = list(it.chain( attr_dict_items = list(it.chain(
self.specified_items, predefined_items,
[ [
(span, {}) (span, {})
for span in self.label_span_list for span in self.label_span_list
@ -425,7 +395,7 @@ class MarkupText(LabelledString):
f"<span {self.get_attr_dict_str(attr_dict)}>", f"<span {self.get_attr_dict_str(attr_dict)}>",
"</span>" "</span>"
)) ))
for span, attr_dict in attr_dict_items for span, attr_dict in attr_dict_items if attr_dict
] ]
repl_items = [ repl_items = [
(tag_span, "") for tag_span in self.tag_spans (tag_span, "") for tag_span in self.tag_spans
@ -445,7 +415,7 @@ class MarkupText(LabelledString):
) )
return self.get_replaced_substr(self.full_span, span_repl_dict) return self.get_replaced_substr(self.full_span, span_repl_dict)
# Post-parsing # Selector
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
repl_dict = dict.fromkeys(self.tag_spans, "") repl_dict = dict.fromkeys(self.tag_spans, "")

View File

@ -135,10 +135,14 @@ def make_even(
def hash_obj(obj: object) -> int: def hash_obj(obj: object) -> int:
if isinstance(obj, dict): if isinstance(obj, dict):
new_obj = {k: hash_obj(v) for k, v in obj.items()} return hash(tuple(sorted([
return hash(tuple(frozenset(sorted(new_obj.items())))) (hash_obj(k), hash_obj(v)) for k, v in obj.items()
])))
if isinstance(obj, (set, tuple, list)): if isinstance(obj, set):
return hash(tuple(sorted(hash_obj(e) for e in obj)))
if isinstance(obj, (tuple, list)):
return hash(tuple(hash_obj(e) for e in obj)) return hash(tuple(hash_obj(e) for e in obj))
if isinstance(obj, Color): if isinstance(obj, Color):