diff --git a/manimlib/mobject/svg/drawings.py b/manimlib/mobject/svg/drawings.py
index d6d87fff..41e9e907 100644
--- a/manimlib/mobject/svg/drawings.py
+++ b/manimlib/mobject/svg/drawings.py
@@ -318,9 +318,6 @@ class Bubble(SVGMobject):
self.content = Mobject()
self.refresh_triangulation()
- def init_colors(self):
- VMobject.init_colors(self)
-
def get_tip(self):
# TODO, find a better way
return self.get_corner(DOWN + self.direction) - 0.6 * self.direction
diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py
index 84c0cbf5..f08dcee7 100644
--- a/manimlib/mobject/svg/mtex_mobject.py
+++ b/manimlib/mobject/svg/mtex_mobject.py
@@ -2,11 +2,11 @@ import itertools as it
import re
from types import MethodType
-from manimlib.constants import BLACK
+from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
-from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
+from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_to_svg_file
@@ -18,25 +18,10 @@ from manimlib.logger import log
SCALE_FACTOR_PER_FONT_POINT = 0.001
-TEX_HASH_TO_MOB_MAP = {}
-
-
def _get_neighbouring_pairs(iterable):
return list(adjacent_pairs(iterable))[:-1]
-class _TexSVG(SVGMobject):
- CONFIG = {
- "color": BLACK,
- "stroke_width": 0,
- "height": None,
- "path_string_config": {
- "should_subdivide_sharp_curves": True,
- "should_remove_null_curves": True,
- },
- }
-
-
class _TexParser(object):
def __init__(self, tex_string, additional_substrings):
self.tex_string = tex_string
@@ -400,10 +385,21 @@ class _TexParser(object):
])
-class MTex(VMobject):
+class _TexSVG(SVGMobject):
CONFIG = {
+ "height": None,
"fill_opacity": 1.0,
"stroke_width": 0,
+ "path_string_config": {
+ "should_subdivide_sharp_curves": True,
+ "should_remove_null_curves": True,
+ },
+ }
+
+
+class MTex(_TexSVG):
+ CONFIG = {
+ "color": WHITE,
"font_size": 48,
"alignment": "\\centering",
"tex_environment": "align*",
@@ -413,65 +409,49 @@ class MTex(VMobject):
}
def __init__(self, tex_string, **kwargs):
- super().__init__(**kwargs)
+ digest_config(self, kwargs)
tex_string = tex_string.strip()
# Prevent from passing an empty string.
if not tex_string:
tex_string = "\\quad"
self.tex_string = tex_string
-
- self.__parser = _TexParser(
+ self.parser = _TexParser(
self.tex_string,
[*self.tex_to_color_map.keys(), *self.isolate]
)
- mob = self.generate_mobject()
- self.add(*mob.copy())
- self.init_colors()
+ super().__init__(**kwargs)
+
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
- @staticmethod
- def color_to_label(color):
- r, g, b = color_to_int_rgb(color)
- rg = r * 256 + g
- return rg * 256 + b
+ @property
+ def hash_seed(self):
+ return (
+ self.__class__.__name__,
+ self.svg_default,
+ self.path_string_config,
+ self.tex_string,
+ self.parser.specified_substrings,
+ self.alignment,
+ self.tex_environment,
+ self.use_plain_tex
+ )
- def generate_mobject(self):
- labelled_tex_string = self.__parser.get_labelled_tex_string()
- labelled_tex_content = self.get_tex_file_content(labelled_tex_string)
- hash_val = hash((labelled_tex_content, self.use_plain_tex))
+ def get_file_path(self):
+ return self._get_file_path(self.use_plain_tex)
- if hash_val in TEX_HASH_TO_MOB_MAP:
- return TEX_HASH_TO_MOB_MAP[hash_val]
-
- if not self.use_plain_tex:
- with display_during_execution(f"Writing \"{self.tex_string}\""):
- labelled_svg_glyphs = self.tex_content_to_glyphs(
- labelled_tex_content
- )
- glyph_labels = [
- self.color_to_label(labelled_glyph.get_fill_color())
- for labelled_glyph in labelled_svg_glyphs
- ]
- mob = self.build_mobject(labelled_svg_glyphs, glyph_labels)
- TEX_HASH_TO_MOB_MAP[hash_val] = mob
- return mob
+ def _get_file_path(self, use_plain_tex):
+ if use_plain_tex:
+ tex_string = self.tex_string
+ else:
+ tex_string = self.parser.get_labelled_tex_string()
+ full_tex = self.get_tex_file_body(tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""):
- labelled_svg_glyphs = self.tex_content_to_glyphs(
- labelled_tex_content
- )
- tex_content = self.get_tex_file_content(self.tex_string)
- svg_glyphs = self.tex_content_to_glyphs(tex_content)
- glyph_labels = [
- self.color_to_label(labelled_glyph.get_fill_color())
- for labelled_glyph in labelled_svg_glyphs
- ]
- mob = self.build_mobject(svg_glyphs, glyph_labels)
- TEX_HASH_TO_MOB_MAP[hash_val] = mob
- return mob
+ file_path = self.tex_to_svg_file_path(full_tex)
+ return file_path
- def get_tex_file_content(self, tex_string):
+ def get_tex_file_body(self, tex_string):
if self.tex_environment:
tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}",
@@ -480,17 +460,38 @@ class MTex(VMobject):
])
if self.alignment:
tex_string = "\n".join([self.alignment, tex_string])
- return tex_string
+
+ tex_config = get_tex_config()
+ return tex_config["tex_body"].replace(
+ tex_config["text_to_replace"],
+ tex_string
+ )
@staticmethod
- def tex_content_to_glyphs(tex_content):
- tex_config = get_tex_config()
- full_tex = tex_config["tex_body"].replace(
- tex_config["text_to_replace"],
- tex_content
- )
- filename = tex_to_svg_file(full_tex)
- return _TexSVG(filename)
+ def tex_to_svg_file_path(tex_file_content):
+ return tex_to_svg_file(tex_file_content)
+
+ def generate_mobject(self):
+ super().generate_mobject()
+
+ if not self.use_plain_tex:
+ labelled_svg_glyphs = self
+ else:
+ file_path = self._get_file_path(use_plain_tex=False)
+ labelled_svg_glyphs = _TexSVG(file_path)
+
+ glyph_labels = [
+ self.color_to_label(labelled_glyph.get_fill_color())
+ for labelled_glyph in labelled_svg_glyphs
+ ]
+ mob = self.build_mobject(self, glyph_labels)
+ self.set_submobjects(mob.submobjects)
+
+ @staticmethod
+ def color_to_label(color):
+ r, g, b = color_to_int_rgb(color)
+ rg = r * 256 + g
+ return rg * 256 + b
def build_mobject(self, svg_glyphs, glyph_labels):
if not svg_glyphs:
@@ -514,11 +515,11 @@ class MTex(VMobject):
submob_labels.append(current_glyph_label)
submobjects.append(submobject)
- indices = self.__parser.get_sorted_submob_indices(submob_labels)
+ indices = self.parser.get_sorted_submob_indices(submob_labels)
rearranged_submobjects = [submobjects[index] for index in indices]
rearranged_labels = [submob_labels[index] for index in indices]
- submob_tex_strings = self.__parser.get_submob_tex_strings(
+ submob_tex_strings = self.parser.get_submob_tex_strings(
rearranged_labels
)
for submob, label, submob_tex in zip(
@@ -531,14 +532,14 @@ class MTex(VMobject):
return VGroup(*rearranged_submobjects)
def get_part_by_tex_spans(self, tex_spans):
- labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans)
+ labels = self.parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter(
lambda submob: submob.submob_label in labels,
self.submobjects
))
def get_part_by_custom_span(self, custom_span):
- tex_spans = self.__parser.find_span_components_of_custom_span(
+ tex_spans = self.parser.find_span_components_of_custom_span(
custom_span
)
if tex_spans is None:
@@ -590,10 +591,10 @@ class MTex(VMobject):
]
def get_specified_substrings(self):
- return self.__parser.get_specified_substrings()
+ return self.parser.get_specified_substrings()
def get_isolated_substrings(self):
- return self.__parser.get_isolated_substrings()
+ return self.parser.get_isolated_substrings()
class MTexText(MTex):
diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py
index fd79dffa..3b8bb6ad 100644
--- a/manimlib/mobject/svg/svg_mobject.py
+++ b/manimlib/mobject/svg/svg_mobject.py
@@ -1,7 +1,7 @@
import os
-import re
import hashlib
import itertools as it
+from xml.etree import ElementTree as ET
import svgelements as se
import numpy as np
@@ -17,9 +17,13 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
+from manimlib.utils.iterables import hash_obj
from manimlib.logger import log
+SVG_HASH_TO_MOB_MAP = {}
+
+
def _convert_point_to_3d(x, y):
return np.array([x, y, 0.0])
@@ -29,8 +33,8 @@ class SVGMobject(VMobject):
"should_center": True,
"height": 2,
"width": None,
- # Must be filled in a subclass, or when called
"file_name": None,
+ # Style that overrides the original svg
"color": None,
"opacity": None,
"fill_color": None,
@@ -38,127 +42,119 @@ class SVGMobject(VMobject):
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
- "path_string_config": {}
+ # Style that fills only when not specified
+ # If None, regarded as default values from svg standard
+ "svg_default": {
+ "color": None,
+ "opacity": None,
+ "fill_color": None,
+ "fill_opacity": None,
+ "stroke_width": None,
+ "stroke_color": None,
+ "stroke_opacity": None,
+ },
+ "path_string_config": {},
}
def __init__(self, file_name=None, **kwargs):
- digest_config(self, kwargs)
- self.file_name = file_name or self.file_name
- if file_name is None:
- raise Exception("Must specify file for SVGMobject")
- self.file_path = get_full_vector_image_path(file_name)
-
super().__init__(**kwargs)
+ self.file_name = file_name or self.file_name
+ self.init_svg_mobject()
+ self.init_colors()
self.move_into_position()
- def move_into_position(self):
- if self.should_center:
- self.center()
- if self.height is not None:
- self.set_height(self.height)
- if self.width is not None:
- self.set_width(self.width)
+ def init_svg_mobject(self):
+ hash_val = hash_obj(self.hash_seed)
+ if hash_val in SVG_HASH_TO_MOB_MAP:
+ mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
+ self.add(*mob)
+ return
- def init_colors(self):
- # Remove fill_color, fill_opacity,
- # stroke_width, stroke_color, stroke_opacity
- # as each submobject may have those values specified in svg file
- self.set_stroke(background=self.draw_stroke_behind_fill)
- self.set_gloss(self.gloss)
- self.set_flat_stroke(self.flat_stroke)
- return self
+ self.generate_mobject()
+ SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()
- def init_points(self):
- with open(self.file_path, "r") as svg_file:
- svg_string = svg_file.read()
-
- # Create a temporary svg file to dump modified svg to be parsed
- modified_svg_string = self.modify_svg_file(svg_string)
- modified_file_path = self.file_path.replace(".svg", "_.svg")
- with open(modified_file_path, "w") as modified_svg_file:
- modified_svg_file.write(modified_svg_string)
-
- # `color` attribute handles `currentColor` keyword
- if self.fill_color:
- color = self.fill_color
- elif self.color:
- color = self.color
- else:
- color = "black"
- shapes = se.SVG.parse(
- modified_file_path,
- color=color
+ @property
+ def hash_seed(self):
+ # Returns data which can uniquely represent the result of `init_points`.
+ # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
+ return (
+ self.__class__.__name__,
+ self.svg_default,
+ self.path_string_config,
+ self.file_name
)
+
+ def generate_mobject(self):
+ file_path = self.get_file_path()
+ element_tree = ET.parse(file_path)
+ new_tree = self.modify_xml_tree(element_tree)
+ # Create a temporary svg file to dump modified svg to be parsed
+ modified_file_path = file_path.replace(".svg", "_.svg")
+ new_tree.write(modified_file_path)
+
+ svg = se.SVG.parse(modified_file_path)
os.remove(modified_file_path)
- mobjects = self.get_mobjects_from(shapes)
+ mobjects = self.get_mobjects_from(svg)
self.add(*mobjects)
self.flip(RIGHT) # Flip y
- self.scale(0.75)
- def modify_svg_file(self, svg_string):
- # svgelements cannot handle em, ex units
- # Convert them using 1em = 16px, 1ex = 0.5em = 8px
- def convert_unit(match_obj):
- number = float(match_obj.group(1))
- unit = match_obj.group(2)
- factor = 16 if unit == "em" else 8
- return str(number * factor) + "px"
+ def get_file_path(self):
+ if self.file_name is None:
+ raise Exception("Must specify file for SVGMobject")
+ return get_full_vector_image_path(self.file_name)
- number_pattern = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)(ex|em)(?![a-zA-Z])"
- result = re.sub(number_pattern, convert_unit, svg_string)
+ def modify_xml_tree(self, element_tree):
+ config_style_dict = self.generate_config_style_dict()
+ style_keys = (
+ "fill",
+ "fill-opacity",
+ "stroke",
+ "stroke-opacity",
+ "stroke-width",
+ "style"
+ )
+ root = element_tree.getroot()
+ root_style_dict = {
+ k: v for k, v in root.attrib.items()
+ if k in style_keys
+ }
- # Add a group tag to set style from configuration
- style_dict = self.generate_context_values_from_config()
- group_tag_begin = ""
- group_tag_end = ""
- begin_insert_index = re.search(r")", result).start(1)
- result = "".join([
- result[:begin_insert_index],
- group_tag_begin,
- result[begin_insert_index:end_insert_index],
- group_tag_end,
- result[end_insert_index:]
- ])
+ new_root = ET.Element("svg", {})
+ config_style_node = ET.SubElement(new_root, "g", config_style_dict)
+ root_style_node = ET.SubElement(config_style_node, "g", root_style_dict)
+ root_style_node.extend(root)
+ return ET.ElementTree(new_root)
- return result
-
- def generate_context_values_from_config(self):
+ def generate_config_style_dict(self):
+ keys_converting_dict = {
+ "fill": ("color", "fill_color"),
+ "fill-opacity": ("opacity", "fill_opacity"),
+ "stroke": ("color", "stroke_color"),
+ "stroke-opacity": ("opacity", "stroke_opacity"),
+ "stroke-width": ("stroke_width",)
+ }
+ svg_default_dict = self.svg_default
result = {}
- if self.stroke_width is not None:
- result["stroke-width"] = self.stroke_width
- if self.color is not None:
- result["fill"] = result["stroke"] = self.color
- if self.fill_color is not None:
- result["fill"] = self.fill_color
- if self.stroke_color is not None:
- result["stroke"] = self.stroke_color
- if self.opacity is not None:
- result["fill-opacity"] = result["stroke-opacity"] = self.opacity
- if self.fill_opacity is not None:
- result["fill-opacity"] = self.fill_opacity
- if self.stroke_opacity is not None:
- result["stroke-opacity"] = self.stroke_opacity
+ for svg_key, style_keys in keys_converting_dict.items():
+ for style_key in style_keys:
+ if svg_default_dict[style_key] is None:
+ continue
+ result[svg_key] = str(svg_default_dict[style_key])
return result
- def get_mobjects_from(self, shape):
- if isinstance(shape, se.Group):
- return list(it.chain(*(
- self.get_mobjects_from(child)
- for child in shape
- )))
-
- mob = self.get_mobject_from(shape)
- if mob is None:
- return []
-
- if isinstance(shape, se.Transformable) and shape.apply:
- self.handle_transform(mob, shape.transform)
- return [mob]
+ def get_mobjects_from(self, svg):
+ result = []
+ for shape in svg.elements():
+ if isinstance(shape, se.Group):
+ continue
+ mob = self.get_mobject_from(shape)
+ if mob is None:
+ continue
+ if isinstance(shape, se.Transformable) and shape.apply:
+ self.handle_transform(mob, shape.transform)
+ result.append(mob)
+ return result
@staticmethod
def handle_transform(mob, matrix):
@@ -265,6 +261,14 @@ class SVGMobject(VMobject):
def text_to_mobject(self, text):
pass
+ def move_into_position(self):
+ if self.should_center:
+ self.center()
+ if self.height is not None:
+ self.set_height(self.height)
+ if self.width is not None:
+ self.set_width(self.width)
+
class VMobjectFromSVGPath(VMobject):
CONFIG = {
@@ -320,4 +324,4 @@ class VMobjectFromSVGPath(VMobject):
_convert_point_to_3d(*segment.__getattribute__(attr_name))
for attr_name in attr_names
]
- func(*points)
\ No newline at end of file
+ func(*points)
diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py
index c81a781b..03dc00d2 100644
--- a/manimlib/mobject/svg/tex_mobject.py
+++ b/manimlib/mobject/svg/tex_mobject.py
@@ -5,7 +5,6 @@ import re
from manimlib.constants import *
from manimlib.mobject.geometry import Line
from manimlib.mobject.svg.svg_mobject import SVGMobject
-from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.config_ops import digest_config
from manimlib.utils.tex_file_writing import tex_to_svg_file
@@ -16,55 +15,50 @@ from manimlib.utils.tex_file_writing import display_during_execution
SCALE_FACTOR_PER_FONT_POINT = 0.001
-tex_string_with_color_to_mob_map = {}
-
-
-class SingleStringTex(VMobject):
+class SingleStringTex(SVGMobject):
CONFIG = {
+ "height": None,
"fill_opacity": 1.0,
"stroke_width": 0,
- "should_center": True,
+ "svg_default": {
+ "color": WHITE,
+ },
+ "path_string_config": {
+ "should_subdivide_sharp_curves": True,
+ "should_remove_null_curves": True,
+ },
"font_size": 48,
- "height": None,
- "organize_left_to_right": False,
"alignment": "\\centering",
"math_mode": True,
+ "organize_left_to_right": False,
}
def __init__(self, tex_string, **kwargs):
- super().__init__(**kwargs)
- assert(isinstance(tex_string, str))
+ assert isinstance(tex_string, str)
self.tex_string = tex_string
- if tex_string not in tex_string_with_color_to_mob_map:
- full_tex = self.get_tex_file_body(tex_string)
- filename = tex_to_svg_file(full_tex)
- svg_mob = SVGMobject(
- filename,
- height=None,
- color=self.color,
- stroke_width=self.stroke_width,
- path_string_config={
- "should_subdivide_sharp_curves": True,
- "should_remove_null_curves": True,
- }
- )
- tex_string_with_color_to_mob_map[(self.color, tex_string)] = svg_mob
- self.add(*(
- sm.copy()
- for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)]
- ))
- self.init_colors()
+ super().__init__(**kwargs)
if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
if self.organize_left_to_right:
self.organize_submobjects_left_to_right()
- def init_colors(self):
- self.set_stroke(background=self.draw_stroke_behind_fill)
- self.set_gloss(self.gloss)
- self.set_flat_stroke(self.flat_stroke)
- return self
+ @property
+ def hash_seed(self):
+ return (
+ self.__class__.__name__,
+ self.svg_default,
+ self.path_string_config,
+ self.tex_string,
+ self.alignment,
+ self.math_mode
+ )
+
+ def get_file_path(self):
+ full_tex = self.get_tex_file_body(self.tex_string)
+ with display_during_execution(f"Writing \"{self.tex_string}\""):
+ file_path = tex_to_svg_file(full_tex)
+ return file_path
def get_tex_file_body(self, tex_string):
new_tex = self.get_modified_expression(tex_string)
diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py
index e412f08d..6052648d 100644
--- a/manimlib/mobject/svg/text_mobject.py
+++ b/manimlib/mobject/svg/text_mobject.py
@@ -71,8 +71,6 @@ class Text(SVGMobject):
PangoUtils.remove_last_M(file_name)
self.remove_empty_path(file_name)
SVGMobject.__init__(self, file_name, **kwargs)
- if self.color:
- self.set_fill(self.color)
self.text = text
if self.disable_ligatures:
self.apply_space_chars()
diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py
index d94af506..8858a3c5 100644
--- a/manimlib/utils/iterables.py
+++ b/manimlib/utils/iterables.py
@@ -139,3 +139,14 @@ def remove_nones(sequence):
def concatenate_lists(*list_of_lists):
return [item for l in list_of_lists for item in l]
+
+
+def hash_obj(obj):
+ if isinstance(obj, dict):
+ new_obj = {k: hash_obj(v) for k, v in obj.items()}
+ return hash(tuple(frozenset(sorted(new_obj.items()))))
+
+ if isinstance(obj, (set, tuple, list)):
+ return hash(tuple(hash_obj(e) for e in obj))
+
+ return hash(obj)
diff --git a/manimlib/utils/tex_file_writing.py b/manimlib/utils/tex_file_writing.py
index 0b0e0ba9..cfc157f4 100644
--- a/manimlib/utils/tex_file_writing.py
+++ b/manimlib/utils/tex_file_writing.py
@@ -126,6 +126,9 @@ def dvi_to_svg(dvi_file, regen_if_exists=False):
def display_during_execution(message):
# Only show top line
to_print = message.split("\n")[0]
+ max_characters = os.get_terminal_size().columns - 1
+ if len(to_print) > max_characters:
+ to_print = to_print[:max_characters - 3] + "..."
try:
print(to_print, end="\r")
yield