Merge pull request #1745 from YishiMichael/master

Reorganize inheriting order and refactor SVGMobject
This commit is contained in:
Grant Sanderson
2022-02-15 10:05:53 -08:00
committed by GitHub
7 changed files with 227 additions and 219 deletions

View File

@ -318,9 +318,6 @@ class Bubble(SVGMobject):
self.content = Mobject() self.content = Mobject()
self.refresh_triangulation() self.refresh_triangulation()
def init_colors(self):
VMobject.init_colors(self)
def get_tip(self): def get_tip(self):
# TODO, find a better way # TODO, find a better way
return self.get_corner(DOWN + self.direction) - 0.6 * self.direction return self.get_corner(DOWN + self.direction) - 0.6 * self.direction

View File

@ -2,11 +2,11 @@ import itertools as it
import re import re
from types import MethodType from types import MethodType
from manimlib.constants import BLACK from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_to_svg_file from manimlib.utils.tex_file_writing import tex_to_svg_file
@ -18,25 +18,10 @@ from manimlib.logger import log
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_HASH_TO_MOB_MAP = {}
def _get_neighbouring_pairs(iterable): def _get_neighbouring_pairs(iterable):
return list(adjacent_pairs(iterable))[:-1] return list(adjacent_pairs(iterable))[:-1]
class _TexSVG(SVGMobject):
CONFIG = {
"color": BLACK,
"stroke_width": 0,
"height": None,
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class _TexParser(object): class _TexParser(object):
def __init__(self, tex_string, additional_substrings): def __init__(self, tex_string, additional_substrings):
self.tex_string = tex_string self.tex_string = tex_string
@ -400,10 +385,21 @@ class _TexParser(object):
]) ])
class MTex(VMobject): class _TexSVG(SVGMobject):
CONFIG = { CONFIG = {
"height": None,
"fill_opacity": 1.0, "fill_opacity": 1.0,
"stroke_width": 0, "stroke_width": 0,
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class MTex(_TexSVG):
CONFIG = {
"color": WHITE,
"font_size": 48, "font_size": 48,
"alignment": "\\centering", "alignment": "\\centering",
"tex_environment": "align*", "tex_environment": "align*",
@ -413,65 +409,49 @@ class MTex(VMobject):
} }
def __init__(self, tex_string, **kwargs): def __init__(self, tex_string, **kwargs):
super().__init__(**kwargs) digest_config(self, kwargs)
tex_string = tex_string.strip() tex_string = tex_string.strip()
# Prevent from passing an empty string. # Prevent from passing an empty string.
if not tex_string: if not tex_string:
tex_string = "\\quad" tex_string = "\\quad"
self.tex_string = tex_string self.tex_string = tex_string
self.parser = _TexParser(
self.__parser = _TexParser(
self.tex_string, self.tex_string,
[*self.tex_to_color_map.keys(), *self.isolate] [*self.tex_to_color_map.keys(), *self.isolate]
) )
mob = self.generate_mobject() super().__init__(**kwargs)
self.add(*mob.copy())
self.init_colors()
self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@staticmethod @property
def color_to_label(color): def hash_seed(self):
r, g, b = color_to_int_rgb(color) return (
rg = r * 256 + g self.__class__.__name__,
return rg * 256 + b self.svg_default,
self.path_string_config,
def generate_mobject(self): self.tex_string,
labelled_tex_string = self.__parser.get_labelled_tex_string() self.parser.specified_substrings,
labelled_tex_content = self.get_tex_file_content(labelled_tex_string) self.alignment,
hash_val = hash((labelled_tex_content, self.use_plain_tex)) self.tex_environment,
self.use_plain_tex
if hash_val in TEX_HASH_TO_MOB_MAP:
return TEX_HASH_TO_MOB_MAP[hash_val]
if not self.use_plain_tex:
with display_during_execution(f"Writing \"{self.tex_string}\""):
labelled_svg_glyphs = self.tex_content_to_glyphs(
labelled_tex_content
) )
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(labelled_svg_glyphs, glyph_labels)
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
def get_file_path(self):
return self._get_file_path(self.use_plain_tex)
def _get_file_path(self, use_plain_tex):
if use_plain_tex:
tex_string = self.tex_string
else:
tex_string = self.parser.get_labelled_tex_string()
full_tex = self.get_tex_file_body(tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""): with display_during_execution(f"Writing \"{self.tex_string}\""):
labelled_svg_glyphs = self.tex_content_to_glyphs( file_path = self.tex_to_svg_file_path(full_tex)
labelled_tex_content return file_path
)
tex_content = self.get_tex_file_content(self.tex_string)
svg_glyphs = self.tex_content_to_glyphs(tex_content)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(svg_glyphs, glyph_labels)
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
def get_tex_file_content(self, tex_string): def get_tex_file_body(self, tex_string):
if self.tex_environment: if self.tex_environment:
tex_string = "\n".join([ tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}", f"\\begin{{{self.tex_environment}}}",
@ -480,17 +460,38 @@ class MTex(VMobject):
]) ])
if self.alignment: if self.alignment:
tex_string = "\n".join([self.alignment, tex_string]) tex_string = "\n".join([self.alignment, tex_string])
return tex_string
tex_config = get_tex_config()
return tex_config["tex_body"].replace(
tex_config["text_to_replace"],
tex_string
)
@staticmethod @staticmethod
def tex_content_to_glyphs(tex_content): def tex_to_svg_file_path(tex_file_content):
tex_config = get_tex_config() return tex_to_svg_file(tex_file_content)
full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"], def generate_mobject(self):
tex_content super().generate_mobject()
)
filename = tex_to_svg_file(full_tex) if not self.use_plain_tex:
return _TexSVG(filename) labelled_svg_glyphs = self
else:
file_path = self._get_file_path(use_plain_tex=False)
labelled_svg_glyphs = _TexSVG(file_path)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(self, glyph_labels)
self.set_submobjects(mob.submobjects)
@staticmethod
def color_to_label(color):
r, g, b = color_to_int_rgb(color)
rg = r * 256 + g
return rg * 256 + b
def build_mobject(self, svg_glyphs, glyph_labels): def build_mobject(self, svg_glyphs, glyph_labels):
if not svg_glyphs: if not svg_glyphs:
@ -514,11 +515,11 @@ class MTex(VMobject):
submob_labels.append(current_glyph_label) submob_labels.append(current_glyph_label)
submobjects.append(submobject) submobjects.append(submobject)
indices = self.__parser.get_sorted_submob_indices(submob_labels) indices = self.parser.get_sorted_submob_indices(submob_labels)
rearranged_submobjects = [submobjects[index] for index in indices] rearranged_submobjects = [submobjects[index] for index in indices]
rearranged_labels = [submob_labels[index] for index in indices] rearranged_labels = [submob_labels[index] for index in indices]
submob_tex_strings = self.__parser.get_submob_tex_strings( submob_tex_strings = self.parser.get_submob_tex_strings(
rearranged_labels rearranged_labels
) )
for submob, label, submob_tex in zip( for submob, label, submob_tex in zip(
@ -531,14 +532,14 @@ class MTex(VMobject):
return VGroup(*rearranged_submobjects) return VGroup(*rearranged_submobjects)
def get_part_by_tex_spans(self, tex_spans): def get_part_by_tex_spans(self, tex_spans):
labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans) labels = self.parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter( return VGroup(*filter(
lambda submob: submob.submob_label in labels, lambda submob: submob.submob_label in labels,
self.submobjects self.submobjects
)) ))
def get_part_by_custom_span(self, custom_span): def get_part_by_custom_span(self, custom_span):
tex_spans = self.__parser.find_span_components_of_custom_span( tex_spans = self.parser.find_span_components_of_custom_span(
custom_span custom_span
) )
if tex_spans is None: if tex_spans is None:
@ -590,10 +591,10 @@ class MTex(VMobject):
] ]
def get_specified_substrings(self): def get_specified_substrings(self):
return self.__parser.get_specified_substrings() return self.parser.get_specified_substrings()
def get_isolated_substrings(self): def get_isolated_substrings(self):
return self.__parser.get_isolated_substrings() return self.parser.get_isolated_substrings()
class MTexText(MTex): class MTexText(MTex):

View File

@ -1,7 +1,7 @@
import os import os
import re
import hashlib import hashlib
import itertools as it import itertools as it
from xml.etree import ElementTree as ET
import svgelements as se import svgelements as se
import numpy as np import numpy as np
@ -17,9 +17,13 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.iterables import hash_obj
from manimlib.logger import log from manimlib.logger import log
SVG_HASH_TO_MOB_MAP = {}
def _convert_point_to_3d(x, y): def _convert_point_to_3d(x, y):
return np.array([x, y, 0.0]) return np.array([x, y, 0.0])
@ -29,8 +33,8 @@ class SVGMobject(VMobject):
"should_center": True, "should_center": True,
"height": 2, "height": 2,
"width": None, "width": None,
# Must be filled in a subclass, or when called
"file_name": None, "file_name": None,
# Style that overrides the original svg
"color": None, "color": None,
"opacity": None, "opacity": None,
"fill_color": None, "fill_color": None,
@ -38,127 +42,119 @@ class SVGMobject(VMobject):
"stroke_width": None, "stroke_width": None,
"stroke_color": None, "stroke_color": None,
"stroke_opacity": None, "stroke_opacity": None,
"path_string_config": {} # Style that fills only when not specified
# If None, regarded as default values from svg standard
"svg_default": {
"color": None,
"opacity": None,
"fill_color": None,
"fill_opacity": None,
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
},
"path_string_config": {},
} }
def __init__(self, file_name=None, **kwargs): def __init__(self, file_name=None, **kwargs):
digest_config(self, kwargs)
self.file_name = file_name or self.file_name
if file_name is None:
raise Exception("Must specify file for SVGMobject")
self.file_path = get_full_vector_image_path(file_name)
super().__init__(**kwargs) super().__init__(**kwargs)
self.file_name = file_name or self.file_name
self.init_svg_mobject()
self.init_colors()
self.move_into_position() self.move_into_position()
def move_into_position(self): def init_svg_mobject(self):
if self.should_center: hash_val = hash_obj(self.hash_seed)
self.center() if hash_val in SVG_HASH_TO_MOB_MAP:
if self.height is not None: mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
self.set_height(self.height) self.add(*mob)
if self.width is not None: return
self.set_width(self.width)
def init_colors(self): self.generate_mobject()
# Remove fill_color, fill_opacity, SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()
# stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
def init_points(self): @property
with open(self.file_path, "r") as svg_file: def hash_seed(self):
svg_string = svg_file.read() # Returns data which can uniquely represent the result of `init_points`.
# The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
# Create a temporary svg file to dump modified svg to be parsed return (
modified_svg_string = self.modify_svg_file(svg_string) self.__class__.__name__,
modified_file_path = self.file_path.replace(".svg", "_.svg") self.svg_default,
with open(modified_file_path, "w") as modified_svg_file: self.path_string_config,
modified_svg_file.write(modified_svg_string) self.file_name
# `color` attribute handles `currentColor` keyword
if self.fill_color:
color = self.fill_color
elif self.color:
color = self.color
else:
color = "black"
shapes = se.SVG.parse(
modified_file_path,
color=color
) )
def generate_mobject(self):
file_path = self.get_file_path()
element_tree = ET.parse(file_path)
new_tree = self.modify_xml_tree(element_tree)
# Create a temporary svg file to dump modified svg to be parsed
modified_file_path = file_path.replace(".svg", "_.svg")
new_tree.write(modified_file_path)
svg = se.SVG.parse(modified_file_path)
os.remove(modified_file_path) os.remove(modified_file_path)
mobjects = self.get_mobjects_from(shapes) mobjects = self.get_mobjects_from(svg)
self.add(*mobjects) self.add(*mobjects)
self.flip(RIGHT) # Flip y self.flip(RIGHT) # Flip y
self.scale(0.75)
def modify_svg_file(self, svg_string): def get_file_path(self):
# svgelements cannot handle em, ex units if self.file_name is None:
# Convert them using 1em = 16px, 1ex = 0.5em = 8px raise Exception("Must specify file for SVGMobject")
def convert_unit(match_obj): return get_full_vector_image_path(self.file_name)
number = float(match_obj.group(1))
unit = match_obj.group(2)
factor = 16 if unit == "em" else 8
return str(number * factor) + "px"
number_pattern = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)(ex|em)(?![a-zA-Z])" def modify_xml_tree(self, element_tree):
result = re.sub(number_pattern, convert_unit, svg_string) config_style_dict = self.generate_config_style_dict()
style_keys = (
"fill",
"fill-opacity",
"stroke",
"stroke-opacity",
"stroke-width",
"style"
)
root = element_tree.getroot()
root_style_dict = {
k: v for k, v in root.attrib.items()
if k in style_keys
}
# Add a group tag to set style from configuration new_root = ET.Element("svg", {})
style_dict = self.generate_context_values_from_config() config_style_node = ET.SubElement(new_root, "g", config_style_dict)
group_tag_begin = "<g " + " ".join([ root_style_node = ET.SubElement(config_style_node, "g", root_style_dict)
f"{k}=\"{v}\"" root_style_node.extend(root)
for k, v in style_dict.items() return ET.ElementTree(new_root)
]) + ">"
group_tag_end = "</g>"
begin_insert_index = re.search(r"<svg[\s\S]*?>", result).end()
end_insert_index = re.search(r"[\s\S]*(</svg\s*>)", result).start(1)
result = "".join([
result[:begin_insert_index],
group_tag_begin,
result[begin_insert_index:end_insert_index],
group_tag_end,
result[end_insert_index:]
])
return result def generate_config_style_dict(self):
keys_converting_dict = {
def generate_context_values_from_config(self): "fill": ("color", "fill_color"),
"fill-opacity": ("opacity", "fill_opacity"),
"stroke": ("color", "stroke_color"),
"stroke-opacity": ("opacity", "stroke_opacity"),
"stroke-width": ("stroke_width",)
}
svg_default_dict = self.svg_default
result = {} result = {}
if self.stroke_width is not None: for svg_key, style_keys in keys_converting_dict.items():
result["stroke-width"] = self.stroke_width for style_key in style_keys:
if self.color is not None: if svg_default_dict[style_key] is None:
result["fill"] = result["stroke"] = self.color continue
if self.fill_color is not None: result[svg_key] = str(svg_default_dict[style_key])
result["fill"] = self.fill_color
if self.stroke_color is not None:
result["stroke"] = self.stroke_color
if self.opacity is not None:
result["fill-opacity"] = result["stroke-opacity"] = self.opacity
if self.fill_opacity is not None:
result["fill-opacity"] = self.fill_opacity
if self.stroke_opacity is not None:
result["stroke-opacity"] = self.stroke_opacity
return result return result
def get_mobjects_from(self, shape): def get_mobjects_from(self, svg):
result = []
for shape in svg.elements():
if isinstance(shape, se.Group): if isinstance(shape, se.Group):
return list(it.chain(*( continue
self.get_mobjects_from(child)
for child in shape
)))
mob = self.get_mobject_from(shape) mob = self.get_mobject_from(shape)
if mob is None: if mob is None:
return [] continue
if isinstance(shape, se.Transformable) and shape.apply: if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform) self.handle_transform(mob, shape.transform)
return [mob] result.append(mob)
return result
@staticmethod @staticmethod
def handle_transform(mob, matrix): def handle_transform(mob, matrix):
@ -265,6 +261,14 @@ class SVGMobject(VMobject):
def text_to_mobject(self, text): def text_to_mobject(self, text):
pass pass
def move_into_position(self):
if self.should_center:
self.center()
if self.height is not None:
self.set_height(self.height)
if self.width is not None:
self.set_width(self.width)
class VMobjectFromSVGPath(VMobject): class VMobjectFromSVGPath(VMobject):
CONFIG = { CONFIG = {

View File

@ -5,7 +5,6 @@ import re
from manimlib.constants import * from manimlib.constants import *
from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Line
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.tex_file_writing import tex_to_svg_file from manimlib.utils.tex_file_writing import tex_to_svg_file
@ -16,55 +15,50 @@ from manimlib.utils.tex_file_writing import display_during_execution
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
tex_string_with_color_to_mob_map = {} class SingleStringTex(SVGMobject):
class SingleStringTex(VMobject):
CONFIG = { CONFIG = {
"height": None,
"fill_opacity": 1.0, "fill_opacity": 1.0,
"stroke_width": 0, "stroke_width": 0,
"should_center": True, "svg_default": {
"color": WHITE,
},
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
"font_size": 48, "font_size": 48,
"height": None,
"organize_left_to_right": False,
"alignment": "\\centering", "alignment": "\\centering",
"math_mode": True, "math_mode": True,
"organize_left_to_right": False,
} }
def __init__(self, tex_string, **kwargs): def __init__(self, tex_string, **kwargs):
super().__init__(**kwargs) assert isinstance(tex_string, str)
assert(isinstance(tex_string, str))
self.tex_string = tex_string self.tex_string = tex_string
if tex_string not in tex_string_with_color_to_mob_map: super().__init__(**kwargs)
full_tex = self.get_tex_file_body(tex_string)
filename = tex_to_svg_file(full_tex)
svg_mob = SVGMobject(
filename,
height=None,
color=self.color,
stroke_width=self.stroke_width,
path_string_config={
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
}
)
tex_string_with_color_to_mob_map[(self.color, tex_string)] = svg_mob
self.add(*(
sm.copy()
for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)]
))
self.init_colors()
if self.height is None: if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
if self.organize_left_to_right: if self.organize_left_to_right:
self.organize_submobjects_left_to_right() self.organize_submobjects_left_to_right()
def init_colors(self): @property
self.set_stroke(background=self.draw_stroke_behind_fill) def hash_seed(self):
self.set_gloss(self.gloss) return (
self.set_flat_stroke(self.flat_stroke) self.__class__.__name__,
return self self.svg_default,
self.path_string_config,
self.tex_string,
self.alignment,
self.math_mode
)
def get_file_path(self):
full_tex = self.get_tex_file_body(self.tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = tex_to_svg_file(full_tex)
return file_path
def get_tex_file_body(self, tex_string): def get_tex_file_body(self, tex_string):
new_tex = self.get_modified_expression(tex_string) new_tex = self.get_modified_expression(tex_string)

View File

@ -71,8 +71,6 @@ class Text(SVGMobject):
PangoUtils.remove_last_M(file_name) PangoUtils.remove_last_M(file_name)
self.remove_empty_path(file_name) self.remove_empty_path(file_name)
SVGMobject.__init__(self, file_name, **kwargs) SVGMobject.__init__(self, file_name, **kwargs)
if self.color:
self.set_fill(self.color)
self.text = text self.text = text
if self.disable_ligatures: if self.disable_ligatures:
self.apply_space_chars() self.apply_space_chars()

View File

@ -139,3 +139,14 @@ def remove_nones(sequence):
def concatenate_lists(*list_of_lists): def concatenate_lists(*list_of_lists):
return [item for l in list_of_lists for item in l] return [item for l in list_of_lists for item in l]
def hash_obj(obj):
if isinstance(obj, dict):
new_obj = {k: hash_obj(v) for k, v in obj.items()}
return hash(tuple(frozenset(sorted(new_obj.items()))))
if isinstance(obj, (set, tuple, list)):
return hash(tuple(hash_obj(e) for e in obj))
return hash(obj)

View File

@ -126,6 +126,9 @@ def dvi_to_svg(dvi_file, regen_if_exists=False):
def display_during_execution(message): def display_during_execution(message):
# Only show top line # Only show top line
to_print = message.split("\n")[0] to_print = message.split("\n")[0]
max_characters = os.get_terminal_size().columns - 1
if len(to_print) > max_characters:
to_print = to_print[:max_characters - 3] + "..."
try: try:
print(to_print, end="\r") print(to_print, end="\r")
yield yield