mirror of
https://github.com/3b1b/manim.git
synced 2025-08-02 19:46:21 +08:00
Refactor labelled_string.py
This commit is contained in:
@ -4,10 +4,9 @@ from abc import ABC, abstractmethod
|
||||
import itertools as it
|
||||
import re
|
||||
|
||||
from manimlib.constants import BLACK, WHITE
|
||||
from manimlib.constants import WHITE
|
||||
from manimlib.mobject.svg.svg_mobject import SVGMobject
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
from manimlib.utils.color import color_to_int_rgb
|
||||
from manimlib.utils.color import color_to_rgb
|
||||
from manimlib.utils.color import rgb_to_hex
|
||||
from manimlib.utils.config_ops import digest_config
|
||||
@ -25,7 +24,10 @@ if TYPE_CHECKING:
|
||||
Span = tuple[int, int]
|
||||
|
||||
|
||||
class _StringSVG(SVGMobject):
|
||||
class LabelledString(SVGMobject, ABC):
|
||||
"""
|
||||
An abstract base class for `MTex` and `MarkupText`
|
||||
"""
|
||||
CONFIG = {
|
||||
"height": None,
|
||||
"stroke_width": 0,
|
||||
@ -34,16 +36,6 @@ class _StringSVG(SVGMobject):
|
||||
"should_subdivide_sharp_curves": True,
|
||||
"should_remove_null_curves": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class LabelledString(_StringSVG, ABC):
|
||||
"""
|
||||
An abstract base class for `MTex` and `MarkupText`
|
||||
"""
|
||||
CONFIG = {
|
||||
"base_color": WHITE,
|
||||
"use_plain_file": False,
|
||||
"isolate": [],
|
||||
}
|
||||
|
||||
@ -51,14 +43,11 @@ class LabelledString(_StringSVG, ABC):
|
||||
self.string = string
|
||||
digest_config(self, kwargs)
|
||||
|
||||
# Convert `base_color` to hex code.
|
||||
self.base_color = rgb_to_hex(color_to_rgb(
|
||||
self.base_color \
|
||||
or self.svg_default.get("color", None) \
|
||||
or self.svg_default.get("fill_color", None) \
|
||||
self.base_color_int = self.color_to_int(
|
||||
self.svg_default.get("fill_color") \
|
||||
or self.svg_default.get("color") \
|
||||
or WHITE
|
||||
))
|
||||
self.svg_default["fill_color"] = BLACK
|
||||
)
|
||||
|
||||
self.pre_parse()
|
||||
self.parse()
|
||||
@ -66,7 +55,7 @@ class LabelledString(_StringSVG, ABC):
|
||||
self.post_parse()
|
||||
|
||||
def get_file_path(self) -> str:
|
||||
return self.get_file_path_(use_plain_file=False)
|
||||
return self.get_file_path_(use_plain_file=True)
|
||||
|
||||
def get_file_path_(self, use_plain_file: bool) -> str:
|
||||
content = self.get_content(use_plain_file)
|
||||
@ -79,22 +68,34 @@ class LabelledString(_StringSVG, ABC):
|
||||
def generate_mobject(self) -> None:
|
||||
super().generate_mobject()
|
||||
|
||||
submob_labels = [
|
||||
self.color_to_label(submob.get_fill_color())
|
||||
for submob in self.submobjects
|
||||
]
|
||||
if self.use_plain_file or self.has_predefined_local_colors:
|
||||
file_path = self.get_file_path_(use_plain_file=True)
|
||||
plain_svg = _StringSVG(
|
||||
file_path,
|
||||
svg_default=self.svg_default,
|
||||
path_string_config=self.path_string_config
|
||||
)
|
||||
self.set_submobjects(plain_svg.submobjects)
|
||||
if self.label_span_list:
|
||||
file_path = self.get_file_path_(use_plain_file=False)
|
||||
labelled_svg = SVGMobject(file_path)
|
||||
submob_color_ints = [
|
||||
self.color_to_int(submob.get_fill_color())
|
||||
for submob in labelled_svg.submobjects
|
||||
]
|
||||
else:
|
||||
self.set_fill(self.base_color)
|
||||
for submob, label in zip(self.submobjects, submob_labels):
|
||||
submob.label = label
|
||||
submob_color_ints = [0] * len(self.submobjects)
|
||||
|
||||
if len(self.submobjects) != len(submob_color_ints):
|
||||
raise ValueError(
|
||||
"Cannot align submobjects of the labelled svg "
|
||||
"to the original svg"
|
||||
)
|
||||
|
||||
unrecognized_color_ints = remove_list_redundancies(sorted(filter(
|
||||
lambda color_int: color_int > len(self.label_span_list),
|
||||
submob_color_ints
|
||||
)))
|
||||
if unrecognized_color_ints:
|
||||
raise ValueError(
|
||||
"Unrecognized color label(s) detected: "
|
||||
f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}"
|
||||
)
|
||||
|
||||
for submob, color_int in zip(self.submobjects, submob_color_ints):
|
||||
submob.label = color_int - 1
|
||||
|
||||
def pre_parse(self) -> None:
|
||||
self.string_len = len(self.string)
|
||||
@ -283,31 +284,14 @@ class LabelledString(_StringSVG, ABC):
|
||||
return index
|
||||
|
||||
@staticmethod
|
||||
def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int:
|
||||
r, g, b = rgb_tuple
|
||||
rg = r * 256 + g
|
||||
return rg * 256 + b
|
||||
|
||||
@staticmethod
|
||||
def int_to_rgb(rgb_int: int) -> tuple[int, int, int]:
|
||||
rg, b = divmod(rgb_int, 256)
|
||||
r, g = divmod(rg, 256)
|
||||
return r, g, b
|
||||
def color_to_int(color: ManimColor) -> int:
|
||||
hex_code = rgb_to_hex(color_to_rgb(color))
|
||||
return int(hex_code[1:], 16)
|
||||
|
||||
@staticmethod
|
||||
def int_to_hex(rgb_int: int) -> str:
|
||||
return "#{:06x}".format(rgb_int).upper()
|
||||
|
||||
@staticmethod
|
||||
def hex_to_int(rgb_hex: str) -> int:
|
||||
return int(rgb_hex[1:], 16)
|
||||
|
||||
@staticmethod
|
||||
def color_to_label(color: ManimColor) -> int:
|
||||
rgb_tuple = color_to_int_rgb(color)
|
||||
rgb = LabelledString.rgb_to_int(rgb_tuple)
|
||||
return rgb - 1
|
||||
|
||||
# Parsing
|
||||
|
||||
@abstractmethod
|
||||
@ -387,10 +371,6 @@ class LabelledString(_StringSVG, ABC):
|
||||
def get_content(self, use_plain_file: bool) -> str:
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def has_predefined_local_colors(self) -> bool:
|
||||
return False
|
||||
|
||||
# Post-parsing
|
||||
|
||||
def get_labelled_submobjects(self) -> list[VMobject]:
|
||||
|
@ -47,8 +47,6 @@ class MTex(LabelledString):
|
||||
self.__class__.__name__,
|
||||
self.svg_default,
|
||||
self.path_string_config,
|
||||
self.base_color,
|
||||
self.use_plain_file,
|
||||
self.isolate,
|
||||
self.tex_string,
|
||||
self.alignment,
|
||||
@ -78,13 +76,9 @@ class MTex(LabelledString):
|
||||
|
||||
@staticmethod
|
||||
def get_color_command_str(rgb_int: int) -> str:
|
||||
rgb_tuple = MTex.int_to_rgb(rgb_int)
|
||||
return "".join([
|
||||
"\\color[RGB]",
|
||||
"{",
|
||||
",".join(map(str, rgb_tuple)),
|
||||
"}"
|
||||
])
|
||||
rg, b = divmod(rgb_int, 256)
|
||||
r, g = divmod(rg, 256)
|
||||
return f"\\color[RGB]{{{r}, {g}, {b}}}"
|
||||
|
||||
# Pre-parsing
|
||||
|
||||
@ -276,15 +270,11 @@ class MTex(LabelledString):
|
||||
result = "\n".join([self.alignment, result])
|
||||
if use_plain_file:
|
||||
result = "\n".join([
|
||||
self.get_color_command_str(self.hex_to_int(self.base_color)),
|
||||
self.get_color_command_str(self.base_color_int),
|
||||
result
|
||||
])
|
||||
return result
|
||||
|
||||
@property
|
||||
def has_predefined_local_colors(self) -> bool:
|
||||
return bool(self.command_repl_items)
|
||||
|
||||
# Post-parsing
|
||||
|
||||
def get_cleaned_substr(self, span: Span) -> str:
|
||||
|
@ -27,7 +27,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from colour import Color
|
||||
from typing import Any, Union
|
||||
from typing import Union
|
||||
|
||||
from manimlib.mobject.types.vectorized_mobject import VMobject
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
@ -43,7 +43,7 @@ DEFAULT_LINE_SPACING_SCALE = 0.6
|
||||
# See https://docs.gtk.org/Pango/pango_markup.html
|
||||
# A tag containing two aliases will cause warning,
|
||||
# so only use the first key of each group of aliases.
|
||||
SPAN_ATTR_KEY_ALIAS_LIST = (
|
||||
MARKUP_KEY_ALIAS_LIST = (
|
||||
("font", "font_desc"),
|
||||
("font_family", "face"),
|
||||
("font_size", "size"),
|
||||
@ -77,19 +77,14 @@ SPAN_ATTR_KEY_ALIAS_LIST = (
|
||||
("text_transform",),
|
||||
("segment",),
|
||||
)
|
||||
COLOR_RELATED_KEYS = (
|
||||
MARKUP_COLOR_KEYS = (
|
||||
"foreground",
|
||||
"background",
|
||||
"underline_color",
|
||||
"overline_color",
|
||||
"strikethrough_color"
|
||||
"background",
|
||||
"underline_color",
|
||||
"overline_color",
|
||||
"strikethrough_color"
|
||||
)
|
||||
SPAN_ATTR_KEY_CONVERSION = {
|
||||
key: key_alias_list[0]
|
||||
for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
|
||||
for key in key_alias_list
|
||||
}
|
||||
TAG_TO_ATTR_DICT = {
|
||||
MARKUP_TAG_CONVERSION_DICT = {
|
||||
"b": {"font_weight": "bold"},
|
||||
"big": {"font_size": "larger"},
|
||||
"i": {"font_style": "italic"},
|
||||
@ -166,8 +161,6 @@ class MarkupText(LabelledString):
|
||||
self.__class__.__name__,
|
||||
self.svg_default,
|
||||
self.path_string_config,
|
||||
self.base_color,
|
||||
self.use_plain_file,
|
||||
self.isolate,
|
||||
self.text,
|
||||
self.is_markup,
|
||||
@ -258,7 +251,7 @@ class MarkupText(LabelledString):
|
||||
|
||||
@staticmethod
|
||||
def merge_attr_dicts(
|
||||
attr_dict_items: list[Span, str, Any]
|
||||
attr_dict_items: list[tuple[Span, dict[str, str]]]
|
||||
) -> list[tuple[Span, dict[str, str]]]:
|
||||
index_seq = [0]
|
||||
attr_dict_list = [{}]
|
||||
@ -344,12 +337,12 @@ class MarkupText(LabelledString):
|
||||
attr_pattern, begin_match_obj.group(3)
|
||||
)
|
||||
}
|
||||
elif tag_name in TAG_TO_ATTR_DICT.keys():
|
||||
elif tag_name in MARKUP_TAG_CONVERSION_DICT.keys():
|
||||
if begin_match_obj.group(3):
|
||||
raise ValueError(
|
||||
f"Attributes shan't exist in tag '{tag_name}'"
|
||||
)
|
||||
attr_dict = TAG_TO_ATTR_DICT[tag_name].copy()
|
||||
attr_dict = MARKUP_TAG_CONVERSION_DICT[tag_name].copy()
|
||||
else:
|
||||
raise ValueError(f"Unknown tag: '{tag_name}'")
|
||||
|
||||
@ -358,13 +351,13 @@ class MarkupText(LabelledString):
|
||||
)
|
||||
return result
|
||||
|
||||
def get_global_dict_from_config(self) -> dict[str, Any]:
|
||||
def get_global_dict_from_config(self) -> dict[str, str]:
|
||||
result = {
|
||||
"line_height": (
|
||||
"line_height": str((
|
||||
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
|
||||
) * 0.6,
|
||||
) * 0.6),
|
||||
"font_family": self.font,
|
||||
"font_size": self.font_size * 1024,
|
||||
"font_size": str(self.font_size * 1024),
|
||||
"font_style": self.slant,
|
||||
"font_weight": self.weight
|
||||
}
|
||||
@ -382,7 +375,7 @@ class MarkupText(LabelledString):
|
||||
|
||||
def get_local_dicts_from_config(
|
||||
self
|
||||
) -> list[Span, dict[str, Any]]:
|
||||
) -> list[Span, dict[str, str]]:
|
||||
return [
|
||||
(span, {key: val})
|
||||
for t2x_dict, key in (
|
||||
@ -405,9 +398,14 @@ class MarkupText(LabelledString):
|
||||
*self.local_dicts_from_markup,
|
||||
*self.local_dicts_from_config
|
||||
]
|
||||
key_conversion_dict = {
|
||||
key: key_alias_list[0]
|
||||
for key_alias_list in MARKUP_KEY_ALIAS_LIST
|
||||
for key in key_alias_list
|
||||
}
|
||||
return [
|
||||
(span, {
|
||||
SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val)
|
||||
key_conversion_dict[key.lower()]: val
|
||||
for key, val in attr_dict.items()
|
||||
})
|
||||
for span, attr_dict in attr_dict_items
|
||||
@ -442,7 +440,7 @@ class MarkupText(LabelledString):
|
||||
return []
|
||||
|
||||
def get_internal_specified_spans(self) -> list[Span]:
|
||||
return [span for span, _ in self.local_dicts_from_markup]
|
||||
return []
|
||||
|
||||
def get_external_specified_spans(self) -> list[Span]:
|
||||
return [span for span, _ in self.local_dicts_from_config]
|
||||
@ -468,7 +466,9 @@ class MarkupText(LabelledString):
|
||||
def get_content(self, use_plain_file: bool) -> str:
|
||||
if use_plain_file:
|
||||
attr_dict_items = [
|
||||
(self.full_span, {"foreground": self.base_color}),
|
||||
(self.full_span, {
|
||||
"foreground": self.int_to_hex(self.base_color_int)
|
||||
}),
|
||||
*self.predefined_attr_dicts,
|
||||
*[
|
||||
(span, {})
|
||||
@ -480,7 +480,7 @@ class MarkupText(LabelledString):
|
||||
(self.full_span, {"foreground": BLACK}),
|
||||
*[
|
||||
(span, {
|
||||
key: BLACK if key in COLOR_RELATED_KEYS else val
|
||||
key: BLACK if key in MARKUP_COLOR_KEYS else val
|
||||
for key, val in attr_dict.items()
|
||||
})
|
||||
for span, attr_dict in self.predefined_attr_dicts
|
||||
@ -502,14 +502,6 @@ class MarkupText(LabelledString):
|
||||
)
|
||||
return self.get_replaced_substr(self.full_span, span_repl_dict)
|
||||
|
||||
@property
|
||||
def has_predefined_local_colors(self) -> bool:
|
||||
return any([
|
||||
key in COLOR_RELATED_KEYS
|
||||
for _, attr_dict in self.predefined_attr_dicts
|
||||
for key in attr_dict.keys()
|
||||
])
|
||||
|
||||
# Method alias
|
||||
|
||||
def get_parts_by_text(self, text: str, **kwargs) -> VGroup:
|
||||
|
Reference in New Issue
Block a user