from __future__ import annotations from xml.etree import ElementTree as ET import numpy as np import svgelements as se import io from pathlib import Path from manimlib.constants import RIGHT from manimlib.constants import TAU from manimlib.logger import log from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Polygon from manimlib.mobject.geometry import Polyline from manimlib.mobject.geometry import Rectangle from manimlib.mobject.geometry import RoundedRectangle from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import quadratic_bezier_points_for_arc from manimlib.utils.images import get_full_vector_image_path from manimlib.utils.iterables import hash_obj from manimlib.utils.space_ops import rotation_about_z from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.typing import ManimColor, Vect3Array SVG_HASH_TO_MOB_MAP: dict[int, list[VMobject]] = {} PATH_TO_POINTS: dict[str, Vect3Array] = {} def _convert_point_to_3d(x: float, y: float) -> np.ndarray: return np.array([x, y, 0.0]) class SVGMobject(VMobject): file_name: str = "" height: float | None = 2.0 width: float | None = None def __init__( self, file_name: str = "", svg_string: str = "", should_center: bool = True, height: float | None = None, width: float | None = None, # Style that overrides the original svg color: ManimColor = None, fill_color: ManimColor = None, fill_opacity: float | None = None, stroke_width: float | None = 0.0, stroke_color: ManimColor = None, stroke_opacity: float | None = None, # Style that fills only when not specified # If None, regarded as default values from svg standard svg_default: dict = dict( color=None, opacity=None, fill_color=None, fill_opacity=None, stroke_width=None, stroke_color=None, stroke_opacity=None, ), path_string_config: dict = dict(), **kwargs ): if svg_string != "": self.svg_string = svg_string elif file_name != "": self.svg_string = self.file_name_to_svg_string(file_name) elif self.file_name != "": self.file_name_to_svg_string(self.file_name) else: raise Exception("Must specify either a file_name or svg_string SVGMobject") self.svg_default = dict(svg_default) self.path_string_config = dict(path_string_config) super().__init__(**kwargs) self.init_svg_mobject() self.ensure_positive_orientation() # Rather than passing style into super().__init__ # do it after svg has been taken in self.set_style( fill_color=color or fill_color, fill_opacity=fill_opacity, stroke_color=color or stroke_color, stroke_width=stroke_width, stroke_opacity=stroke_opacity, ) # Initialize position height = height or self.height width = width or self.width if should_center: self.center() if height is not None: self.set_height(height) if width is not None: self.set_width(width) def init_svg_mobject(self) -> None: hash_val = hash_obj(self.hash_seed) if hash_val in SVG_HASH_TO_MOB_MAP: submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]] else: submobs = self.mobjects_from_svg_string(self.svg_string) SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs] self.add(*submobs) self.flip(RIGHT) # Flip y @property def hash_seed(self) -> tuple: # Returns data which can uniquely represent the result of `init_points`. # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`. return ( self.__class__.__name__, self.svg_default, self.path_string_config, self.svg_string ) def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]: element_tree = ET.ElementTree(ET.fromstring(svg_string)) new_tree = self.modify_xml_tree(element_tree) # New svg based on tree contents data_stream = io.BytesIO() new_tree.write(data_stream) data_stream.seek(0) svg = se.SVG.parse(data_stream) data_stream.close() return self.mobjects_from_svg(svg) def file_name_to_svg_string(self, file_name: str) -> str: return Path(get_full_vector_image_path(file_name)).read_text() def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree: config_style_attrs = self.generate_config_style_dict() style_keys = ( "fill", "fill-opacity", "stroke", "stroke-opacity", "stroke-width", "style" ) root = element_tree.getroot() style_attrs = { k: v for k, v in root.attrib.items() if k in style_keys } # Ignore other attributes in case that svgelements cannot parse them SVG_XMLNS = "{http://www.w3.org/2000/svg}" new_root = ET.Element("svg") config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs) root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs) root_style_node.extend(root) return ET.ElementTree(new_root) def generate_config_style_dict(self) -> dict[str, str]: keys_converting_dict = { "fill": ("color", "fill_color"), "fill-opacity": ("opacity", "fill_opacity"), "stroke": ("color", "stroke_color"), "stroke-opacity": ("opacity", "stroke_opacity"), "stroke-width": ("stroke_width",) } svg_default_dict = self.svg_default result = {} for svg_key, style_keys in keys_converting_dict.items(): for style_key in style_keys: if svg_default_dict[style_key] is None: continue result[svg_key] = str(svg_default_dict[style_key]) return result def mobjects_from_svg(self, svg: se.SVG) -> list[VMobject]: result = [] for shape in svg.elements(): if isinstance(shape, (se.Group, se.Use)): continue elif isinstance(shape, se.Path): mob = self.path_to_mobject(shape) elif isinstance(shape, se.SimpleLine): mob = self.line_to_mobject(shape) elif isinstance(shape, se.Rect): mob = self.rect_to_mobject(shape) elif isinstance(shape, (se.Circle, se.Ellipse)): mob = self.ellipse_to_mobject(shape) elif isinstance(shape, se.Polygon): mob = self.polygon_to_mobject(shape) elif isinstance(shape, se.Polyline): mob = self.polyline_to_mobject(shape) # elif isinstance(shape, se.Text): # mob = self.text_to_mobject(shape) elif type(shape) == se.SVGElement: continue else: log.warning("Unsupported element type: %s", type(shape)) continue if not mob.has_points(): continue if isinstance(shape, se.GraphicObject): self.apply_style_to_mobject(mob, shape) if isinstance(shape, se.Transformable) and shape.apply: self.handle_transform(mob, shape.transform) result.append(mob) return result @staticmethod def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject: mat = np.array([ [matrix.a, matrix.c], [matrix.b, matrix.d] ]) vec = np.array([matrix.e, matrix.f, 0.0]) mob.apply_matrix(mat) mob.shift(vec) return mob @staticmethod def apply_style_to_mobject( mob: VMobject, shape: se.GraphicObject ) -> VMobject: mob.set_style( stroke_width=shape.stroke_width, stroke_color=shape.stroke.hexrgb, stroke_opacity=shape.stroke.opacity, fill_color=shape.fill.hexrgb, fill_opacity=shape.fill.opacity ) return mob def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath: return VMobjectFromSVGPath(path, **self.path_string_config) def line_to_mobject(self, line: se.SimpleLine) -> Line: return Line( start=_convert_point_to_3d(line.x1, line.y1), end=_convert_point_to_3d(line.x2, line.y2) ) def rect_to_mobject(self, rect: se.Rect) -> Rectangle: if rect.rx == 0 or rect.ry == 0: mob = Rectangle( width=rect.width, height=rect.height, ) else: mob = RoundedRectangle( width=rect.width, height=rect.height * rect.rx / rect.ry, corner_radius=rect.rx ) mob.stretch_to_fit_height(rect.height) mob.shift(_convert_point_to_3d( rect.x + rect.width / 2, rect.y + rect.height / 2 )) return mob def ellipse_to_mobject(self, ellipse: se.Circle | se.Ellipse) -> Circle: mob = Circle(radius=ellipse.rx) mob.stretch_to_fit_height(2 * ellipse.ry) mob.shift(_convert_point_to_3d( ellipse.cx, ellipse.cy )) return mob def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon: points = [ _convert_point_to_3d(*point) for point in polygon ] return Polygon(*points) def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline: points = [ _convert_point_to_3d(*point) for point in polyline ] return Polyline(*points) def text_to_mobject(self, text: se.Text): pass class VMobjectFromSVGPath(VMobject): def __init__( self, path_obj: se.Path, **kwargs ): # caches (transform.inverse(), rot, shift) self.transform_cache: tuple[se.Matrix, np.ndarray, np.ndarray] | None = None self.path_obj = path_obj super().__init__(**kwargs) def init_points(self) -> None: # After a given svg_path has been converted into points, the result # will be saved so that future calls for the same pathdon't need to # retrace the same computation. path_string = self.path_obj.d() if path_string not in PATH_TO_POINTS: self.handle_commands() # Save for future use PATH_TO_POINTS[path_string] = self.get_points().copy() else: points = PATH_TO_POINTS[path_string] self.set_points(points) def handle_commands(self) -> None: segment_class_to_func_map = { se.Move: (self.start_new_path, ("end",)), se.Close: (self.close_path, ()), se.Line: (lambda p: self.add_line_to(p, allow_null_line=False), ("end",)), se.QuadraticBezier: (lambda c, e: self.add_quadratic_bezier_curve_to(c, e, allow_null_curve=False), ("control", "end")), se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end")) } for segment in self.path_obj: segment_class = segment.__class__ if segment_class is se.Arc: self.handle_arc(segment) else: func, attr_names = segment_class_to_func_map[segment_class] points = [ _convert_point_to_3d(*segment.__getattribute__(attr_name)) for attr_name in attr_names ] func(*points) # Get rid of the side effect of trailing "Z M" commands. if self.has_new_path_started(): self.resize_points(self.get_num_points() - 2) def handle_arc(self, arc: se.Arc) -> None: if self.transform_cache is not None: transform, rot, shift = self.transform_cache else: # The transform obtained in this way considers the combined effect # of all parent group transforms in the SVG. # Therefore, the arc can be transformed inversely using this transform # to correctly compute the arc path before transforming it back. transform = se.Matrix(self.path_obj.values.get('transform', '')) rot = np.array([ [transform.a, transform.c], [transform.b, transform.d] ]) shift = np.array([transform.e, transform.f, 0]) transform.inverse() self.transform_cache = (transform, rot, shift) # Apply inverse transformation to the arc so that its path can be correctly computed arc *= transform # The value of n_components is chosen based on the implementation of VMobject.arc_to n_components = int(np.ceil(8 * abs(arc.sweep) / TAU)) # Obtain the required angular segments on the unit circle arc_points = quadratic_bezier_points_for_arc(arc.sweep, n_components) arc_points @= np.array(rotation_about_z(arc.get_start_t())).T # Transform to an ellipse, considering rotation and translating the ellipse center arc_points[:, 0] *= arc.rx arc_points[:, 1] *= arc.ry arc_points @= np.array(rotation_about_z(arc.get_rotation().as_radians)).T arc_points += [*arc.center, 0] # Transform back arc_points[:, :2] @= rot.T arc_points += shift self.append_points(arc_points[1:])