mirror of
https://github.com/3b1b/manim.git
synced 2025-08-02 19:46:21 +08:00

One could argue that a pattern of "arg: dict | None = None" followed by "self.param = arg or dict()" is better, but that would make for an inconsistent pattern in cases where the default argument is not None.
350 lines
12 KiB
Python
350 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from xml.etree import ElementTree as ET
|
|
|
|
import numpy as np
|
|
import svgelements as se
|
|
|
|
from manimlib.constants import RIGHT
|
|
from manimlib.logger import log
|
|
from manimlib.mobject.geometry import Circle
|
|
from manimlib.mobject.geometry import Line
|
|
from manimlib.mobject.geometry import Polygon
|
|
from manimlib.mobject.geometry import Polyline
|
|
from manimlib.mobject.geometry import Rectangle
|
|
from manimlib.mobject.geometry import RoundedRectangle
|
|
from manimlib.mobject.types.vectorized_mobject import VMobject
|
|
from manimlib.utils.directories import get_mobject_data_dir
|
|
from manimlib.utils.images import get_full_vector_image_path
|
|
from manimlib.utils.iterables import hash_obj
|
|
from manimlib.utils.simple_functions import hash_string
|
|
|
|
from typing import TYPE_CHECKING
|
|
if TYPE_CHECKING:
|
|
from manimlib.typing import ManimColor
|
|
|
|
|
|
|
|
# Process-wide cache of parsed SVGs. Keys are `hash_obj(SVGMobject.hash_seed)`;
# values are template mobjects stored by `SVGMobject.init_svg_mobject`, which
# adds the submobjects of a copy of the template instead of re-parsing the file.
SVG_HASH_TO_MOB_MAP: dict[int, VMobject] = {}
|
|
|
|
|
|
def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
|
|
return np.array([x, y, 0.0])
|
|
|
|
|
|
class SVGMobject(VMobject):
    """A VMobject assembled from the shapes of an SVG file.

    The file is parsed with ``svgelements``. Each supported shape (path,
    line, rect, circle/ellipse, polygon, polyline) becomes a submobject that
    carries the style and transform found in the SVG. Parsed results are
    cached in ``SVG_HASH_TO_MOB_MAP`` keyed by ``hash_seed``, so identical
    SVGs are parsed only once per process.
    """

    # Fallback used when no file name is passed to ``__init__``; subclasses
    # may override it.
    file_name: str = ""

    def __init__(
        self,
        file_name: str = "",
        should_center: bool = True,
        height: float | None = 2.0,
        width: float | None = None,
        # Style that overrides the original svg
        color: ManimColor = None,
        fill_color: ManimColor = None,
        fill_opacity: float | None = None,
        stroke_width: float | None = 0.0,
        stroke_color: ManimColor = None,
        stroke_opacity: float | None = None,
        # Style that fills only when not specified
        # If None, regarded as default values from svg standard
        svg_default: dict = dict(
            color=None,
            opacity=None,
            fill_color=None,
            fill_opacity=None,
            stroke_width=None,
            stroke_color=None,
            stroke_opacity=None,
        ),
        path_string_config: dict = dict(),
        **kwargs
    ):
        """Load the SVG, apply override styles, then center and size it.

        Args:
            file_name: SVG to load; falls back to the class attribute
                ``file_name`` when empty.
            should_center: if True, center the result after loading.
            height: if not None, final height (applied after centering).
            width: if not None, final width (applied after ``height``, so it
                takes precedence when both are given).
            color: if given, overrides both ``fill_color`` and
                ``stroke_color``.
            fill_color, fill_opacity, stroke_width, stroke_color,
            stroke_opacity: styles applied on top of the SVG's own styles.
            svg_default: styles written onto an outer ``<g>`` wrapping the
                SVG content, used only where the SVG does not specify them.
                Missing keys are treated as None.
            path_string_config: keyword arguments forwarded to every
                ``VMobjectFromSVGPath``.
            **kwargs: forwarded to ``VMobject.__init__``.
        """
        self.file_name = file_name or self.file_name
        # Copy the shared default dicts so instances never mutate them.
        self.svg_default = dict(svg_default)
        self.path_string_config = dict(path_string_config)
        self.height = height

        super().__init__(**kwargs)
        self.init_svg_mobject()

        # Rather than passing style into super().__init__
        # do it after svg has been taken in
        self.set_style(
            fill_color=color or fill_color,
            fill_opacity=fill_opacity,
            stroke_color=color or stroke_color,
            stroke_width=stroke_width,
            stroke_opacity=stroke_opacity,
        )

        # Initialize position
        if should_center:
            self.center()
        if height is not None:
            self.set_height(height)
        if width is not None:
            self.set_width(width)

    def init_svg_mobject(self) -> None:
        """Populate submobjects, reusing a cached parse when one exists."""
        hash_val = hash_obj(self.hash_seed)
        if hash_val in SVG_HASH_TO_MOB_MAP:
            mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
            self.add(*mob)
            return

        self.generate_mobject()
        SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()

    @property
    def hash_seed(self) -> tuple:
        # Returns data which can uniquely represent the result of `init_points`.
        # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
        return (
            self.__class__.__name__,
            self.svg_default,
            self.path_string_config,
            self.file_name
        )

    def generate_mobject(self) -> None:
        """Parse the SVG file and add one submobject per supported shape."""
        file_path = self.get_file_path()
        element_tree = ET.parse(file_path)
        new_tree = self.modify_xml_tree(element_tree)
        # Create a temporary svg file (next to the original) holding the
        # modified tree for svgelements to parse. It is removed even if
        # writing or parsing fails, so no stray "*_.svg" files are left behind.
        root, ext = os.path.splitext(file_path)
        modified_file_path = root + "_" + ext
        try:
            new_tree.write(modified_file_path)
            svg = se.SVG.parse(modified_file_path)
        finally:
            if os.path.exists(modified_file_path):
                os.remove(modified_file_path)

        mobjects = self.get_mobjects_from(svg)
        self.add(*mobjects)
        self.flip(RIGHT)  # Flip y

    def get_file_path(self) -> str:
        """Resolve ``self.file_name`` to a full vector image path.

        Raises:
            Exception: if no file name was given (None or empty string).
        """
        # `file_name` defaults to "" rather than None, so test for falsiness;
        # an `is None` check would let the empty default through.
        if not self.file_name:
            raise Exception("Must specify file for SVGMobject")
        return get_full_vector_image_path(self.file_name)

    def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
        """Wrap the SVG content in style-carrying groups.

        The new tree is ``<svg><g config-style><g root-style>...children``.
        The outer group holds styles from ``svg_default``. The inner group
        holds the style attributes that were on the original root
        ``<svg>``. Content-level attributes are nested innermost, so they
        take precedence.
        """
        config_style_attrs = self.generate_config_style_dict()
        style_keys = (
            "fill",
            "fill-opacity",
            "stroke",
            "stroke-opacity",
            "stroke-width",
            "style"
        )
        root = element_tree.getroot()
        style_attrs = {
            k: v
            for k, v in root.attrib.items()
            if k in style_keys
        }

        # Ignore other attributes in case that svgelements cannot parse them
        SVG_XMLNS = "{http://www.w3.org/2000/svg}"
        new_root = ET.Element("svg")
        config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs)
        root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs)
        root_style_node.extend(root)
        return ET.ElementTree(new_root)

    def generate_config_style_dict(self) -> dict[str, str]:
        """Translate ``svg_default`` into SVG presentation attributes.

        For each SVG attribute, the listed manim keys are checked in order
        and the last non-None value wins. For example, ``fill_color``
        overrides ``color`` for ``fill``. Keys absent from ``svg_default``
        are treated as None.
        """
        keys_converting_dict = {
            "fill": ("color", "fill_color"),
            "fill-opacity": ("opacity", "fill_opacity"),
            "stroke": ("color", "stroke_color"),
            "stroke-opacity": ("opacity", "stroke_opacity"),
            "stroke-width": ("stroke_width",)
        }
        svg_default_dict = self.svg_default
        result = {}
        for svg_key, style_keys in keys_converting_dict.items():
            for style_key in style_keys:
                # `.get` so a partial user-supplied svg_default is accepted.
                value = svg_default_dict.get(style_key)
                if value is None:
                    continue
                result[svg_key] = str(value)
        return result

    def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
        """Convert the flattened svgelements shapes into styled mobjects.

        Groups, ``<use>`` elements and bare ``SVGElement`` nodes are skipped.
        Unsupported types are skipped with a warning. Shapes that produce no
        points are dropped.
        """
        result = []
        for shape in svg.elements():
            if isinstance(shape, (se.Group, se.Use)):
                continue
            elif isinstance(shape, se.Path):
                mob = self.path_to_mobject(shape)
            elif isinstance(shape, se.SimpleLine):
                mob = self.line_to_mobject(shape)
            elif isinstance(shape, se.Rect):
                mob = self.rect_to_mobject(shape)
            elif isinstance(shape, (se.Circle, se.Ellipse)):
                mob = self.ellipse_to_mobject(shape)
            elif isinstance(shape, se.Polygon):
                mob = self.polygon_to_mobject(shape)
            elif isinstance(shape, se.Polyline):
                mob = self.polyline_to_mobject(shape)
            # elif isinstance(shape, se.Text):
            #     mob = self.text_to_mobject(shape)
            elif type(shape) is se.SVGElement:
                # Exact-type check on purpose: only the bare base element is
                # ignored silently; unknown subclasses fall through to a warning.
                continue
            else:
                log.warning("Unsupported element type: %s", type(shape))
                continue
            if not mob.has_points():
                continue
            if isinstance(shape, se.GraphicObject):
                self.apply_style_to_mobject(mob, shape)
            if isinstance(shape, se.Transformable) and shape.apply:
                self.handle_transform(mob, shape.transform)
            result.append(mob)
        return result

    @staticmethod
    def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
        """Apply the affine SVG matrix (a b c d e f) to ``mob`` in place.

        The linear part ``[[a, c], [b, d]]`` is applied first, then the
        translation ``(e, f)``. Returns ``mob``.
        """
        mat = np.array([
            [matrix.a, matrix.c],
            [matrix.b, matrix.d]
        ])
        vec = np.array([matrix.e, matrix.f, 0.0])
        mob.apply_matrix(mat)
        mob.shift(vec)
        return mob

    @staticmethod
    def apply_style_to_mobject(
        mob: VMobject,
        shape: se.GraphicObject
    ) -> VMobject:
        """Copy the shape's stroke/fill colors, opacities and stroke width."""
        mob.set_style(
            stroke_width=shape.stroke_width,
            stroke_color=shape.stroke.hexrgb,
            stroke_opacity=shape.stroke.opacity,
            fill_color=shape.fill.hexrgb,
            fill_opacity=shape.fill.opacity
        )
        return mob

    def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
        """Build a path mobject, forwarding ``path_string_config``."""
        return VMobjectFromSVGPath(path, **self.path_string_config)

    def line_to_mobject(self, line: se.SimpleLine) -> Line:
        """Build a Line from (x1, y1) to (x2, y2)."""
        return Line(
            start=_convert_point_to_3d(line.x1, line.y1),
            end=_convert_point_to_3d(line.x2, line.y2)
        )

    def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
        """Build a (possibly rounded) rectangle matching the SVG rect.

        When rx != ry, a RoundedRectangle with circular corners of radius rx
        is built at height ``height * rx / ry``. It is then stretched back
        to ``height``, which turns the corners into rx-by-ry ellipses.
        """
        if rect.rx == 0 or rect.ry == 0:
            mob = Rectangle(
                width=rect.width,
                height=rect.height,
            )
        else:
            mob = RoundedRectangle(
                width=rect.width,
                height=rect.height * rect.rx / rect.ry,
                corner_radius=rect.rx
            )
            mob.stretch_to_fit_height(rect.height)
        # (x, y) is the top-left corner in SVG; shift so the center matches.
        mob.shift(_convert_point_to_3d(
            rect.x + rect.width / 2,
            rect.y + rect.height / 2
        ))
        return mob

    def ellipse_to_mobject(self, ellipse: se.Circle | se.Ellipse) -> Circle:
        """Build a circle of radius rx stretched to height 2*ry at (cx, cy)."""
        mob = Circle(radius=ellipse.rx)
        mob.stretch_to_fit_height(2 * ellipse.ry)
        mob.shift(_convert_point_to_3d(
            ellipse.cx, ellipse.cy
        ))
        return mob

    def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
        """Build a closed Polygon through the SVG polygon's vertices."""
        points = [
            _convert_point_to_3d(*point)
            for point in polygon
        ]
        return Polygon(*points)

    def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
        """Build an open Polyline through the SVG polyline's vertices."""
        points = [
            _convert_point_to_3d(*point)
            for point in polyline
        ]
        return Polyline(*points)

    def text_to_mobject(self, text: se.Text):
        """Placeholder: text elements are not converted yet.

        The matching branch in ``get_mobjects_from`` is commented out.
        """
        pass
|
|
|
|
|
|
class VMobjectFromSVGPath(VMobject):
    """A VMobject whose points trace an ``svgelements`` Path.

    The computed points and triangulation are cached as ``.npy`` files in
    the mobject data directory, so rebuilding the same path with the same
    options skips the conversion.
    """

    def __init__(
        self,
        path_obj: se.Path,
        long_lines: bool = False,
        should_subdivide_sharp_curves: bool = False,
        should_remove_null_curves: bool = False,
        **kwargs
    ):
        """Store the path and options, then let VMobject build the points.

        Args:
            path_obj: path to trace. It is modified in place: arcs are
                replaced by quadratic approximations.
            long_lines: stored on the instance; not used in this class.
            should_subdivide_sharp_curves: subdivide sharp curves after
                tracing (for a healthier triangulation).
            should_remove_null_curves: drop null curves after tracing.
            **kwargs: forwarded to ``VMobject.__init__``.
        """
        # Get rid of arcs; handle_commands only knows moves, lines and Béziers.
        path_obj.approximate_arcs_with_quads()
        self.path_obj = path_obj
        self.long_lines = long_lines
        self.should_subdivide_sharp_curves = should_subdivide_sharp_curves
        self.should_remove_null_curves = should_remove_null_curves
        super().__init__(**kwargs)

    def init_points(self) -> None:
        # After a given svg_path has been converted into points, the result
        # will be saved to a file so that future calls for the same path
        # don't need to retrace the same computation.
        path_string = self.path_obj.d()
        # The post-processing flags change the stored points, so they are part
        # of the cache key. Otherwise the same path built with different flags
        # would silently reuse whichever result happened to be cached first.
        cache_key = "{}|subdivide_sharp_curves={}|remove_null_curves={}".format(
            path_string,
            bool(self.should_subdivide_sharp_curves),
            bool(self.should_remove_null_curves),
        )
        path_hash = hash_string(cache_key)
        data_dir = get_mobject_data_dir()
        points_filepath = os.path.join(data_dir, f"{path_hash}_points.npy")
        tris_filepath = os.path.join(data_dir, f"{path_hash}_tris.npy")

        if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
            self.set_points(np.load(points_filepath))
            self.triangulation = np.load(tris_filepath)
            self.needs_new_triangulation = False
        else:
            self.handle_commands()
            if self.should_subdivide_sharp_curves:
                # For a healthy triangulation later
                self.subdivide_sharp_curves()
            if self.should_remove_null_curves:
                # Get rid of any null curves
                self.set_points(self.get_points_without_null_curves())
            # Save to a file for future use
            np.save(points_filepath, self.get_points())
            np.save(tris_filepath, self.get_triangulation())

    def handle_commands(self) -> None:
        """Trace each path segment with the matching VMobject path builder.

        Raises:
            KeyError: for a segment type not in the dispatch table (arcs are
                removed in ``__init__``).
        """
        # segment class -> (builder method, segment attributes giving its points)
        segment_class_to_func_map = {
            se.Move: (self.start_new_path, ("end",)),
            se.Close: (self.close_path, ()),
            se.Line: (self.add_line_to, ("end",)),
            se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
            se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
        }
        for segment in self.path_obj:
            func, attr_names = segment_class_to_func_map[type(segment)]
            points = [
                _convert_point_to_3d(*getattr(segment, attr_name))
                for attr_name in attr_names
            ]
            func(*points)

        # Get rid of the side effect of trailing "Z M" commands: drop the
        # single point that a final move left behind.
        if self.has_new_path_started():
            self.resize_points(self.get_num_points() - 1)
|