Refactor labelled_string.py

This commit is contained in:
YishiMichael
2022-04-14 21:07:31 +08:00
parent 0c1e5b337b
commit eec6b01a72
3 changed files with 71 additions and 109 deletions

View File

@ -4,10 +4,9 @@ from abc import ABC, abstractmethod
import itertools as it
import re
from manimlib.constants import BLACK, WHITE
from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config
@ -25,7 +24,10 @@ if TYPE_CHECKING:
Span = tuple[int, int]
class _StringSVG(SVGMobject):
class LabelledString(SVGMobject, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
CONFIG = {
"height": None,
"stroke_width": 0,
@ -34,16 +36,6 @@ class _StringSVG(SVGMobject):
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class LabelledString(_StringSVG, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
CONFIG = {
"base_color": WHITE,
"use_plain_file": False,
"isolate": [],
}
@ -51,14 +43,11 @@ class LabelledString(_StringSVG, ABC):
self.string = string
digest_config(self, kwargs)
# Convert `base_color` to hex code.
self.base_color = rgb_to_hex(color_to_rgb(
self.base_color \
or self.svg_default.get("color", None) \
or self.svg_default.get("fill_color", None) \
self.base_color_int = self.color_to_int(
self.svg_default.get("fill_color") \
or self.svg_default.get("color") \
or WHITE
))
self.svg_default["fill_color"] = BLACK
)
self.pre_parse()
self.parse()
@ -66,7 +55,7 @@ class LabelledString(_StringSVG, ABC):
self.post_parse()
def get_file_path(self) -> str:
return self.get_file_path_(use_plain_file=False)
return self.get_file_path_(use_plain_file=True)
def get_file_path_(self, use_plain_file: bool) -> str:
content = self.get_content(use_plain_file)
@ -79,22 +68,34 @@ class LabelledString(_StringSVG, ABC):
def generate_mobject(self) -> None:
super().generate_mobject()
submob_labels = [
self.color_to_label(submob.get_fill_color())
for submob in self.submobjects
]
if self.use_plain_file or self.has_predefined_local_colors:
file_path = self.get_file_path_(use_plain_file=True)
plain_svg = _StringSVG(
file_path,
svg_default=self.svg_default,
path_string_config=self.path_string_config
)
self.set_submobjects(plain_svg.submobjects)
if self.label_span_list:
file_path = self.get_file_path_(use_plain_file=False)
labelled_svg = SVGMobject(file_path)
submob_color_ints = [
self.color_to_int(submob.get_fill_color())
for submob in labelled_svg.submobjects
]
else:
self.set_fill(self.base_color)
for submob, label in zip(self.submobjects, submob_labels):
submob.label = label
submob_color_ints = [0] * len(self.submobjects)
if len(self.submobjects) != len(submob_color_ints):
raise ValueError(
"Cannot align submobjects of the labelled svg "
"to the original svg"
)
unrecognized_color_ints = remove_list_redundancies(sorted(filter(
lambda color_int: color_int > len(self.label_span_list),
submob_color_ints
)))
if unrecognized_color_ints:
raise ValueError(
"Unrecognized color label(s) detected: "
f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}"
)
for submob, color_int in zip(self.submobjects, submob_color_ints):
submob.label = color_int - 1
def pre_parse(self) -> None:
self.string_len = len(self.string)
@ -283,31 +284,14 @@ class LabelledString(_StringSVG, ABC):
return index
@staticmethod
def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int:
r, g, b = rgb_tuple
rg = r * 256 + g
return rg * 256 + b
@staticmethod
def int_to_rgb(rgb_int: int) -> tuple[int, int, int]:
rg, b = divmod(rgb_int, 256)
r, g = divmod(rg, 256)
return r, g, b
def color_to_int(color: ManimColor) -> int:
hex_code = rgb_to_hex(color_to_rgb(color))
return int(hex_code[1:], 16)
@staticmethod
def int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod
def color_to_label(color: ManimColor) -> int:
rgb_tuple = color_to_int_rgb(color)
rgb = LabelledString.rgb_to_int(rgb_tuple)
return rgb - 1
# Parsing
@abstractmethod
@ -387,10 +371,6 @@ class LabelledString(_StringSVG, ABC):
def get_content(self, use_plain_file: bool) -> str:
return ""
@abstractmethod
def has_predefined_local_colors(self) -> bool:
return False
# Post-parsing
def get_labelled_submobjects(self) -> list[VMobject]:

View File

@ -47,8 +47,6 @@ class MTex(LabelledString):
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.base_color,
self.use_plain_file,
self.isolate,
self.tex_string,
self.alignment,
@ -78,13 +76,9 @@ class MTex(LabelledString):
@staticmethod
def get_color_command_str(rgb_int: int) -> str:
rgb_tuple = MTex.int_to_rgb(rgb_int)
return "".join([
"\\color[RGB]",
"{",
",".join(map(str, rgb_tuple)),
"}"
])
rg, b = divmod(rgb_int, 256)
r, g = divmod(rg, 256)
return f"\\color[RGB]{{{r}, {g}, {b}}}"
# Pre-parsing
@ -276,15 +270,11 @@ class MTex(LabelledString):
result = "\n".join([self.alignment, result])
if use_plain_file:
result = "\n".join([
self.get_color_command_str(self.hex_to_int(self.base_color)),
self.get_color_command_str(self.base_color_int),
result
])
return result
@property
def has_predefined_local_colors(self) -> bool:
return bool(self.command_repl_items)
# Post-parsing
def get_cleaned_substr(self, span: Span) -> str:

View File

@ -27,7 +27,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from colour import Color
from typing import Any, Union
from typing import Union
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
@ -43,7 +43,7 @@ DEFAULT_LINE_SPACING_SCALE = 0.6
# See https://docs.gtk.org/Pango/pango_markup.html
# A tag containing two aliases will cause warning,
# so only use the first key of each group of aliases.
SPAN_ATTR_KEY_ALIAS_LIST = (
MARKUP_KEY_ALIAS_LIST = (
("font", "font_desc"),
("font_family", "face"),
("font_size", "size"),
@ -77,19 +77,14 @@ SPAN_ATTR_KEY_ALIAS_LIST = (
("text_transform",),
("segment",),
)
COLOR_RELATED_KEYS = (
MARKUP_COLOR_KEYS = (
"foreground",
"background",
"underline_color",
"overline_color",
"strikethrough_color"
"background",
"underline_color",
"overline_color",
"strikethrough_color"
)
SPAN_ATTR_KEY_CONVERSION = {
key: key_alias_list[0]
for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
for key in key_alias_list
}
TAG_TO_ATTR_DICT = {
MARKUP_TAG_CONVERSION_DICT = {
"b": {"font_weight": "bold"},
"big": {"font_size": "larger"},
"i": {"font_style": "italic"},
@ -166,8 +161,6 @@ class MarkupText(LabelledString):
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.base_color,
self.use_plain_file,
self.isolate,
self.text,
self.is_markup,
@ -258,7 +251,7 @@ class MarkupText(LabelledString):
@staticmethod
def merge_attr_dicts(
attr_dict_items: list[Span, str, Any]
attr_dict_items: list[tuple[Span, dict[str, str]]]
) -> list[tuple[Span, dict[str, str]]]:
index_seq = [0]
attr_dict_list = [{}]
@ -344,12 +337,12 @@ class MarkupText(LabelledString):
attr_pattern, begin_match_obj.group(3)
)
}
elif tag_name in TAG_TO_ATTR_DICT.keys():
elif tag_name in MARKUP_TAG_CONVERSION_DICT.keys():
if begin_match_obj.group(3):
raise ValueError(
f"Attributes shan't exist in tag '{tag_name}'"
)
attr_dict = TAG_TO_ATTR_DICT[tag_name].copy()
attr_dict = MARKUP_TAG_CONVERSION_DICT[tag_name].copy()
else:
raise ValueError(f"Unknown tag: '{tag_name}'")
@ -358,13 +351,13 @@ class MarkupText(LabelledString):
)
return result
def get_global_dict_from_config(self) -> dict[str, Any]:
def get_global_dict_from_config(self) -> dict[str, str]:
result = {
"line_height": (
"line_height": str((
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6,
) * 0.6),
"font_family": self.font,
"font_size": self.font_size * 1024,
"font_size": str(self.font_size * 1024),
"font_style": self.slant,
"font_weight": self.weight
}
@ -382,7 +375,7 @@ class MarkupText(LabelledString):
def get_local_dicts_from_config(
self
) -> list[Span, dict[str, Any]]:
) -> list[Span, dict[str, str]]:
return [
(span, {key: val})
for t2x_dict, key in (
@ -405,9 +398,14 @@ class MarkupText(LabelledString):
*self.local_dicts_from_markup,
*self.local_dicts_from_config
]
key_conversion_dict = {
key: key_alias_list[0]
for key_alias_list in MARKUP_KEY_ALIAS_LIST
for key in key_alias_list
}
return [
(span, {
SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val)
key_conversion_dict[key.lower()]: val
for key, val in attr_dict.items()
})
for span, attr_dict in attr_dict_items
@ -442,7 +440,7 @@ class MarkupText(LabelledString):
return []
def get_internal_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_markup]
return []
def get_external_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_config]
@ -468,7 +466,9 @@ class MarkupText(LabelledString):
def get_content(self, use_plain_file: bool) -> str:
if use_plain_file:
attr_dict_items = [
(self.full_span, {"foreground": self.base_color}),
(self.full_span, {
"foreground": self.int_to_hex(self.base_color_int)
}),
*self.predefined_attr_dicts,
*[
(span, {})
@ -480,7 +480,7 @@ class MarkupText(LabelledString):
(self.full_span, {"foreground": BLACK}),
*[
(span, {
key: BLACK if key in COLOR_RELATED_KEYS else val
key: BLACK if key in MARKUP_COLOR_KEYS else val
for key, val in attr_dict.items()
})
for span, attr_dict in self.predefined_attr_dicts
@ -502,14 +502,6 @@ class MarkupText(LabelledString):
)
return self.get_replaced_substr(self.full_span, span_repl_dict)
@property
def has_predefined_local_colors(self) -> bool:
return any([
key in COLOR_RELATED_KEYS
for _, attr_dict in self.predefined_attr_dicts
for key in attr_dict.keys()
])
# Method alias
def get_parts_by_text(self, text: str, **kwargs) -> VGroup: