from __future__ import annotations

import os
from xml.etree import ElementTree as ET

import numpy as np
import svgelements as se

from manimlib.constants import RIGHT
from manimlib.logger import log
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.iterables import hash_obj
from manimlib.utils.simple_functions import hash_string


# In-memory cache of already generated SVG mobjects, keyed by the hash of
# `SVGMobject.hash_seed`.
SVG_HASH_TO_MOB_MAP: dict[int, VMobject] = {}
# Namespace prefixes in ElementTree's "{uri}tag" notation.
SVG_XMLNS = "{http://www.w3.org/2000/svg}"
SVG_XLINK = "{http://www.w3.org/1999/xlink}"


def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
    """Return the 2d point ``(x, y)`` as a 3d array with ``z = 0``."""
    return np.array([x, y, 0.0])


class SVGMobject(VMobject):
    """A VMobject whose submobjects are built from the shapes of an svg file."""

    CONFIG = {
        "should_center": True,
        "height": 2,
        "width": None,
        "file_name": None,
        # Style that overrides the original svg
        "color": None,
        "opacity": None,
        "fill_color": None,
        "fill_opacity": None,
        "stroke_width": None,
        "stroke_color": None,
        "stroke_opacity": None,
        # Style that fills only when not specified
        # If None, regarded as default values from svg standard
        "svg_default": {
            "color": None,
            "opacity": None,
            "fill_color": None,
            "fill_opacity": None,
            "stroke_width": None,
            "stroke_color": None,
            "stroke_opacity": None,
        },
        "path_string_config": {},
    }

    def __init__(self, file_name: str | None = None, **kwargs):
        """Build the mobject from `file_name`, falling back to the configured
        `file_name` attribute when the argument is empty or None."""
        super().__init__(**kwargs)
        self.file_name = file_name or self.file_name
        self.init_svg_mobject()
        self.init_colors()
        self.move_into_position()

    def init_svg_mobject(self) -> None:
        """Add the submobjects for this svg, reusing a cached copy if one
        exists for the same `hash_seed`."""
        hash_val = hash_obj(self.hash_seed)
        if hash_val in SVG_HASH_TO_MOB_MAP:
            mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
            self.add(*mob)
            return

        self.generate_mobject()
        SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()

    @property
    def hash_seed(self) -> tuple:
        # Returns data which can uniquely represent the result of
        # `generate_mobject`. The hashed value of it is stored as a key
        # in `SVG_HASH_TO_MOB_MAP`.
        return (
            self.__class__.__name__,
            self.svg_default,
            self.path_string_config,
            self.file_name
        )

    def generate_mobject(self) -> None:
        """Parse the svg file and add one submobject per supported shape."""
        file_path = self.get_file_path()
        element_tree = ET.parse(file_path)
        new_tree = self.modify_xml_tree(element_tree)
        # Create a temporary svg file to dump modified svg to be parsed
        root, ext = os.path.splitext(file_path)
        modified_file_path = root + "_" + ext
        try:
            new_tree.write(modified_file_path)
            svg = se.SVG.parse(modified_file_path)
        finally:
            # Always clean up, so a failed write or parse does not leave
            # the temporary file next to the source svg.
            if os.path.exists(modified_file_path):
                os.remove(modified_file_path)

        mobjects = self.get_mobjects_from(svg)
        self.add(*mobjects)
        self.flip(RIGHT)  # Flip y

    def get_file_path(self) -> str:
        """Resolve `self.file_name` to a full path.

        Raises an Exception when no file name has been specified.
        """
        if self.file_name is None:
            raise Exception("Must specify file for SVGMobject")
        return get_full_vector_image_path(self.file_name)

    def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
        """Return a new tree in which the original root's children are nested
        inside two `g` nodes: an outer one carrying the `svg_default` styles
        and an inner one carrying the style attributes of the original root.
        """
        element_tree = self.expand_use_elements(element_tree)
        config_style_attrs = self.generate_config_style_dict()
        style_keys = (
            "fill",
            "fill-opacity",
            "stroke",
            "stroke-opacity",
            "stroke-width",
            "style"
        )
        root = element_tree.getroot()
        style_attrs = {
            k: v
            for k, v in root.attrib.items()
            if k in style_keys
        }

        # Ignore other attributes in case that svgelements cannot parse them
        new_root = ET.Element("svg")
        config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs)
        root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs)
        root_style_node.extend(root)
        return ET.ElementTree(new_root)

    @staticmethod
    def expand_use_elements(element_tree: ET.ElementTree) -> ET.ElementTree:
        """Replace `use` elements with copies of the elements they refer to,
        then empty every `defs` element. The tree is modified in place and
        returned.
        """
        while True:
            element = element_tree.find(f".//{SVG_XMLNS}use[@{SVG_XLINK}href]")
            if element is None:
                break

            # Retagging (and popping the href below) guarantees this element
            # is not matched again, so the loop always makes progress.
            element.tag = f"{SVG_XMLNS}g"
            attrs = element.attrib
            # Strip the leading "#" of the reference
            href_id = attrs.pop(f"{SVG_XLINK}href")[1:]
            href_element = element_tree.find(f".//{SVG_XMLNS}*[@id='{href_id}']")
            if href_element is None:
                continue
            attrs.pop("width", None)
            attrs.pop("height", None)
            x = attrs.pop("x", "0")
            y = attrs.pop("y", "0")
            if not x == y == "0":
                translate_str = f"translate({x}, {y})"
                if "transform" in attrs:
                    # Per the SVG specification, translate(x, y) is appended
                    # to the end (right side) of the existing transform.
                    attrs["transform"] = attrs["transform"] + " " + translate_str
                else:
                    attrs["transform"] = translate_str
            shadow_node = ET.SubElement(element, href_element.tag, href_element.attrib)
            shadow_node.extend(href_element)

        for defs_element in element_tree.iterfind(f".//{SVG_XMLNS}defs"):
            defs_element.clear()
        return element_tree

    def generate_config_style_dict(self) -> dict[str, str]:
        """Translate the non-None entries of `svg_default` into svg style
        attributes. When several config keys map to one svg attribute, the
        later (more specific) key wins.
        """
        keys_converting_dict = {
            "fill": ("color", "fill_color"),
            "fill-opacity": ("opacity", "fill_opacity"),
            "stroke": ("color", "stroke_color"),
            "stroke-opacity": ("opacity", "stroke_opacity"),
            "stroke-width": ("stroke_width",)
        }
        svg_default_dict = self.svg_default
        result = {}
        for svg_key, style_keys in keys_converting_dict.items():
            for style_key in style_keys:
                if svg_default_dict[style_key] is None:
                    continue
                result[svg_key] = str(svg_default_dict[style_key])
        return result

    def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
        """Convert every supported shape of `svg` into a styled mobject.

        Groups and bare `SVGElement`s are skipped silently, unsupported
        types are skipped with a warning, and mobjects without points are
        dropped.
        """
        result = []
        for shape in svg.elements():
            if isinstance(shape, se.Group):
                continue
            elif isinstance(shape, se.Path):
                mob = self.path_to_mobject(shape)
            elif isinstance(shape, se.SimpleLine):
                mob = self.line_to_mobject(shape)
            elif isinstance(shape, se.Rect):
                mob = self.rect_to_mobject(shape)
            elif isinstance(shape, se.Circle):
                mob = self.circle_to_mobject(shape)
            elif isinstance(shape, se.Ellipse):
                mob = self.ellipse_to_mobject(shape)
            elif isinstance(shape, se.Polygon):
                mob = self.polygon_to_mobject(shape)
            elif isinstance(shape, se.Polyline):
                mob = self.polyline_to_mobject(shape)
            # elif isinstance(shape, se.Text):
            #     mob = self.text_to_mobject(shape)
            elif type(shape) is se.SVGElement:
                # Exact type match on purpose: only bare SVGElement
                # instances are skipped silently, not their subclasses.
                continue
            else:
                log.warning(f"Unsupported element type: {type(shape)}")
                continue
            if not mob.has_points():
                continue
            self.apply_style_to_mobject(mob, shape)
            if isinstance(shape, se.Transformable) and shape.apply:
                self.handle_transform(mob, shape.transform)
            result.append(mob)
        return result

    @staticmethod
    def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
        """Apply the affine transform `matrix` to `mob`: first the 2x2
        linear part, then the translation. Returns `mob`.
        """
        mat = np.array([
            [matrix.a, matrix.c],
            [matrix.b, matrix.d]
        ])
        vec = np.array([matrix.e, matrix.f, 0.0])
        mob.apply_matrix(mat)
        mob.shift(vec)
        return mob

    @staticmethod
    def apply_style_to_mobject(
        mob: VMobject,
        shape: se.GraphicObject
    ) -> VMobject:
        """Copy stroke and fill style from `shape` onto `mob`. Returns `mob`."""
        mob.set_style(
            stroke_width=shape.stroke_width,
            stroke_color=shape.stroke.hexrgb,
            stroke_opacity=shape.stroke.opacity,
            fill_color=shape.fill.hexrgb,
            fill_opacity=shape.fill.opacity
        )
        return mob

    def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
        """Build a mobject from an svg path, using `path_string_config`."""
        return VMobjectFromSVGPath(path, **self.path_string_config)

    def line_to_mobject(self, line: se.SimpleLine) -> Line:
        """Build a Line between the two endpoints of an svg line."""
        return Line(
            start=_convert_point_to_3d(line.x1, line.y1),
            end=_convert_point_to_3d(line.x2, line.y2)
        )

    def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
        """Build a (possibly rounded) rectangle positioned like `rect`."""
        if rect.rx == 0 or rect.ry == 0:
            mob = Rectangle(
                width=rect.width,
                height=rect.height,
            )
        else:
            # Build with a corner radius of `rx` at a scaled height, then
            # stretch back to the real height so the corners end up with
            # radii `rx` and `ry`.
            mob = RoundedRectangle(
                width=rect.width,
                height=rect.height * rect.rx / rect.ry,
                corner_radius=rect.rx
            )
            mob.stretch_to_fit_height(rect.height)
        # (rect.x, rect.y) is a corner of the rect; shift by its center.
        mob.shift(_convert_point_to_3d(
            rect.x + rect.width / 2,
            rect.y + rect.height / 2
        ))
        return mob

    def circle_to_mobject(self, circle: se.Circle) -> Circle:
        """Build a Circle with the radius and center of `circle`."""
        # svgelements supports `rx` & `ry` but not `r`
        mob = Circle(radius=circle.rx)
        mob.shift(_convert_point_to_3d(
            circle.cx, circle.cy
        ))
        return mob

    def ellipse_to_mobject(self, ellipse: se.Ellipse) -> Circle:
        """Build an ellipse as a Circle of radius `rx` stretched to a
        height of `2 * ry`, centered like `ellipse`."""
        mob = Circle(radius=ellipse.rx)
        mob.stretch_to_fit_height(2 * ellipse.ry)
        mob.shift(_convert_point_to_3d(
            ellipse.cx, ellipse.cy
        ))
        return mob

    def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
        """Build a Polygon through the points of `polygon`."""
        points = [
            _convert_point_to_3d(*point)
            for point in polygon
        ]
        return Polygon(*points)

    def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
        """Build a Polyline through the points of `polyline`."""
        points = [
            _convert_point_to_3d(*point)
            for point in polyline
        ]
        return Polyline(*points)

    def text_to_mobject(self, text: se.Text):
        """Placeholder: text elements are not converted yet."""
        pass

    def move_into_position(self) -> None:
        """Center and resize according to `should_center`, `height` and
        `width`."""
        if self.should_center:
            self.center()
        if self.height is not None:
            self.set_height(self.height)
        if self.width is not None:
            self.set_width(self.width)


class VMobjectFromSVGPath(VMobject):
    """A VMobject whose points trace the segments of an svgelements Path."""

    CONFIG = {
        "long_lines": False,
        "should_subdivide_sharp_curves": False,
        "should_remove_null_curves": False,
    }

    def __init__(self, path_obj: se.Path, **kwargs):
        """Store `path_obj` (its arcs are approximated in place) and
        initialize the mobject."""
        # Get rid of arcs
        path_obj.approximate_arcs_with_quads()
        # Must be set before super().__init__, presumably because the base
        # initializer triggers init_points, which reads it.
        self.path_obj = path_obj
        super().__init__(**kwargs)

    def init_points(self) -> None:
        """Set the points (and triangulation) for the path, loading them
        from the on-disk cache when available."""
        # After a given svg_path has been converted into points, the result
        # will be saved to a file so that future calls for the same path
        # don't need to retrace the same computation.
        path_string = self.path_obj.d()
        path_hash = hash_string(path_string)
        points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy")
        tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy")

        if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
            self.set_points(np.load(points_filepath))
            self.triangulation = np.load(tris_filepath)
            self.needs_new_triangulation = False
        else:
            self.handle_commands()
            if self.should_subdivide_sharp_curves:
                # For a healthy triangulation later
                self.subdivide_sharp_curves()
            if self.should_remove_null_curves:
                # Get rid of any null curves
                self.set_points(self.get_points_without_null_curves())
            # Save to a file for future use
            np.save(points_filepath, self.get_points())
            np.save(tris_filepath, self.get_triangulation())

    def handle_commands(self) -> None:
        """Replay every segment of the path through the matching
        path-building method of this mobject."""
        # Maps each segment class to the method that draws it and the
        # names of the segment attributes passed to it as 3d points.
        segment_class_to_func_map = {
            se.Move: (self.start_new_path, ("end",)),
            se.Close: (self.close_path, ()),
            se.Line: (self.add_line_to, ("end",)),
            se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
            se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
        }
        for segment in self.path_obj:
            segment_class = segment.__class__
            func, attr_names = segment_class_to_func_map[segment_class]
            points = [
                _convert_point_to_3d(*getattr(segment, attr_name))
                for attr_name in attr_names
            ]
            func(*points)

        # Get rid of the side effect of trailing "Z M" commands.
        if self.has_new_path_started():
            self.resize_points(self.get_num_points() - 1)