mirror of
https://github.com/3b1b/manim.git
synced 2025-07-29 13:03:31 +08:00
Merge pull request #1712 from 3b1b/fix-svg
Improve handling of SVG transform and Some refactors
This commit is contained in:
@ -4,6 +4,7 @@ import random
|
||||
import sys
|
||||
import moderngl
|
||||
from functools import wraps
|
||||
from collections import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -596,7 +597,10 @@ class Mobject(object):
|
||||
Otherwise, if about_point is given a value, scaling is done with
|
||||
respect to that point.
|
||||
"""
|
||||
scale_factor = max(scale_factor, min_scale_factor)
|
||||
if isinstance(scale_factor, Iterable):
|
||||
scale_factor = np.array(scale_factor).clip(min=min_scale_factor)
|
||||
else:
|
||||
scale_factor = max(scale_factor, min_scale_factor)
|
||||
self.apply_points_function(
|
||||
lambda points: scale_factor * points,
|
||||
about_point=about_point,
|
||||
|
@ -1,14 +1,13 @@
|
||||
import itertools as it
|
||||
import re
|
||||
import string
|
||||
import warnings
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from xml.dom import minidom
|
||||
|
||||
from manimlib.constants import DEFAULT_STROKE_WIDTH
|
||||
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT
|
||||
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
|
||||
from manimlib.constants import BLACK
|
||||
from manimlib.constants import WHITE
|
||||
from manimlib.constants import DEGREES, PI
|
||||
@ -23,6 +22,7 @@ from manimlib.utils.config_ops import digest_config
|
||||
from manimlib.utils.directories import get_mobject_data_dir
|
||||
from manimlib.utils.images import get_full_vector_image_path
|
||||
from manimlib.utils.simple_functions import clip
|
||||
from manimlib.logger import log
|
||||
|
||||
|
||||
def string_to_numbers(num_string):
|
||||
@ -71,8 +71,10 @@ class SVGMobject(VMobject):
|
||||
doc = minidom.parse(self.file_path)
|
||||
self.ref_to_element = {}
|
||||
|
||||
for svg in doc.getElementsByTagName("svg"):
|
||||
mobjects = self.get_mobjects_from(svg)
|
||||
for child in doc.childNodes:
|
||||
if not isinstance(child, minidom.Element): continue
|
||||
if child.tagName != 'svg': continue
|
||||
mobjects = self.get_mobjects_from(child)
|
||||
if self.unpack_groups:
|
||||
self.add(*mobjects)
|
||||
else:
|
||||
@ -107,8 +109,8 @@ class SVGMobject(VMobject):
|
||||
elif element.tagName in ['polygon', 'polyline']:
|
||||
result.append(self.polygon_to_mobject(element))
|
||||
else:
|
||||
log.warning(f"Unsupported element type: {element.tagName}")
|
||||
pass # TODO
|
||||
# warnings.warn("Unknown element type: " + element.tagName)
|
||||
result = [m for m in result if m is not None]
|
||||
self.handle_transforms(element, VGroup(*result))
|
||||
if len(result) > 1 and not self.unpack_groups:
|
||||
@ -131,7 +133,7 @@ class SVGMobject(VMobject):
|
||||
# Remove initial "#" character
|
||||
ref = use_element.getAttribute("xlink:href")[1:]
|
||||
if ref not in self.ref_to_element:
|
||||
warnings.warn(f"{ref} not recognized")
|
||||
log.warning(f"{ref} not recognized")
|
||||
return VGroup()
|
||||
return self.get_mobjects_from(
|
||||
self.ref_to_element[ref]
|
||||
@ -227,7 +229,7 @@ class SVGMobject(VMobject):
|
||||
stroke_width=stroke_width,
|
||||
stroke_color=stroke_color,
|
||||
fill_color=fill_color,
|
||||
fill_opacity=opacity,
|
||||
fill_opacity=fill_opacity,
|
||||
corner_radius=corner_radius
|
||||
)
|
||||
|
||||
@ -235,66 +237,94 @@ class SVGMobject(VMobject):
|
||||
return mob
|
||||
|
||||
def handle_transforms(self, element, mobject):
|
||||
# TODO, this could use some cleaning...
|
||||
x, y = 0, 0
|
||||
try:
|
||||
x = self.attribute_to_float(element.getAttribute('x'))
|
||||
# Flip y
|
||||
y = -self.attribute_to_float(element.getAttribute('y'))
|
||||
mobject.shift([x, y, 0])
|
||||
except Exception:
|
||||
pass
|
||||
x, y = (
|
||||
self.attribute_to_float(element.getAttribute(key))
|
||||
if element.hasAttribute(key)
|
||||
else 0.0
|
||||
for key in ("x", "y")
|
||||
)
|
||||
mobject.shift(x * RIGHT + y * DOWN)
|
||||
|
||||
transform = element.getAttribute('transform')
|
||||
transform_names = [
|
||||
"matrix",
|
||||
"translate", "translateX", "translateY",
|
||||
"scale", "scaleX", "scaleY",
|
||||
"rotate",
|
||||
"skewX", "skewY"
|
||||
]
|
||||
transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names]))
|
||||
number_pattern = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
|
||||
transforms = transform_pattern.findall(element.getAttribute('transform'))[::-1]
|
||||
|
||||
try: # transform matrix
|
||||
prefix = "matrix("
|
||||
suffix = ")"
|
||||
if not transform.startswith(prefix) or not transform.endswith(suffix):
|
||||
raise Exception()
|
||||
transform = transform[len(prefix):-len(suffix)]
|
||||
transform = string_to_numbers(transform)
|
||||
transform = np.array(transform).reshape([3, 2])
|
||||
x = transform[2][0]
|
||||
y = -transform[2][1]
|
||||
matrix = np.identity(self.dim)
|
||||
matrix[:2, :2] = transform[:2, :]
|
||||
matrix[1] *= -1
|
||||
matrix[:, 1] *= -1
|
||||
for transform in transforms:
|
||||
op_name, op_args = transform.split("(")
|
||||
op_name = op_name.strip()
|
||||
op_args = [float(x) for x in number_pattern.findall(op_args)]
|
||||
|
||||
if op_name == "matrix":
|
||||
self._handle_matrix_transform(mobject, op_name, op_args)
|
||||
elif op_name.startswith("translate"):
|
||||
self._handle_translate_transform(mobject, op_name, op_args)
|
||||
elif op_name.startswith("scale"):
|
||||
self._handle_scale_transform(mobject, op_name, op_args)
|
||||
elif op_name == "rotate":
|
||||
self._handle_rotate_transform(mobject, op_name, op_args)
|
||||
elif op_name.startswith("skew"):
|
||||
self._handle_skew_transform(mobject, op_name, op_args)
|
||||
|
||||
def _handle_matrix_transform(self, mobject, op_name, op_args):
|
||||
transform = np.array(op_args).reshape([3, 2])
|
||||
x = transform[2][0]
|
||||
y = -transform[2][1]
|
||||
matrix = np.identity(self.dim)
|
||||
matrix[:2, :2] = transform[:2, :]
|
||||
matrix[1] *= -1
|
||||
matrix[:, 1] *= -1
|
||||
for mob in mobject.family_members_with_points():
|
||||
mob.apply_matrix(matrix.T)
|
||||
mobject.shift(x * RIGHT + y * UP)
|
||||
|
||||
for mob in mobject.family_members_with_points():
|
||||
mob.apply_matrix(matrix.T)
|
||||
mobject.shift(x * RIGHT + y * UP)
|
||||
except:
|
||||
pass
|
||||
|
||||
try: # transform scale
|
||||
prefix = "scale("
|
||||
suffix = ")"
|
||||
if not transform.startswith(prefix) or not transform.endswith(suffix):
|
||||
raise Exception()
|
||||
transform = transform[len(prefix):-len(suffix)]
|
||||
scale_values = string_to_numbers(transform)
|
||||
if len(scale_values) == 2:
|
||||
scale_x, scale_y = scale_values
|
||||
mobject.scale(np.array([scale_x, scale_y, 1]), about_point=ORIGIN)
|
||||
elif len(scale_values) == 1:
|
||||
scale = scale_values[0]
|
||||
mobject.scale(np.array([scale, scale, 1]), about_point=ORIGIN)
|
||||
except:
|
||||
pass
|
||||
|
||||
try: # transform translate
|
||||
prefix = "translate("
|
||||
suffix = ")"
|
||||
if not transform.startswith(prefix) or not transform.endswith(suffix):
|
||||
raise Exception()
|
||||
transform = transform[len(prefix):-len(suffix)]
|
||||
x, y = string_to_numbers(transform)
|
||||
mobject.shift(x * RIGHT + y * DOWN)
|
||||
except:
|
||||
pass
|
||||
# TODO, ...
|
||||
def _handle_translate_transform(self, mobject, op_name, op_args):
|
||||
if op_name.endswith("X"):
|
||||
x, y = op_args[0], 0
|
||||
elif op_name.endswith("Y"):
|
||||
x, y = 0, op_args[0]
|
||||
else:
|
||||
x, y = op_args
|
||||
mobject.shift(x * RIGHT + y * DOWN)
|
||||
|
||||
def _handle_scale_transform(self, mobject, op_name, op_args):
|
||||
if op_name.endswith("X"):
|
||||
sx, sy = op_args[0], 1
|
||||
elif op_name.endswith("Y"):
|
||||
sx, sy = 1, op_args[0]
|
||||
elif len(op_args) == 2:
|
||||
sx, sy = op_args
|
||||
else:
|
||||
sx = sy = op_args[0]
|
||||
if sx < 0:
|
||||
mobject.flip(UP)
|
||||
sx = -sx
|
||||
if sy < 0:
|
||||
mobject.flip(RIGHT)
|
||||
sy = -sy
|
||||
mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN)
|
||||
|
||||
def _handle_rotate_transform(self, mobject, op_name, op_args):
|
||||
if len(op_args) == 1:
|
||||
mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN)
|
||||
else:
|
||||
deg, x, y = op_args
|
||||
mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, y, 0]))
|
||||
|
||||
def _handle_skew_transform(self, mobject, op_name, op_args):
|
||||
rad = op_args[0] * DEGREES
|
||||
if op_name == "skewX":
|
||||
tana = np.tan(rad)
|
||||
self._handle_matrix_transform(mobject, None, [1., 0., tana, 1., 0., 0.])
|
||||
elif op_name == "skewY":
|
||||
tana = np.tan(rad)
|
||||
self._handle_matrix_transform(mobject, None, [1., tana, 0., 1., 0., 0.])
|
||||
|
||||
def flatten(self, input_list):
|
||||
output_list = []
|
||||
@ -378,7 +408,8 @@ class VMobjectFromSVGPathstring(VMobject):
|
||||
|
||||
number_types = np.array(list(number_types_str))
|
||||
n_numbers = len(number_types_str)
|
||||
number_groups = np.array(string_to_numbers(coord_string)).reshape((-1, n_numbers))
|
||||
number_list = _PathStringParser(coord_string, number_types_str).args
|
||||
number_groups = np.array(number_list).reshape((-1, n_numbers))
|
||||
|
||||
for numbers in number_groups:
|
||||
if command.islower():
|
||||
@ -520,9 +551,67 @@ class VMobjectFromSVGPathstring(VMobject):
|
||||
"S": (self.add_smooth_cubic_curve_to, "xyxy"),
|
||||
"Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
|
||||
"T": (self.add_smooth_curve_to, "xy"),
|
||||
"A": (self.add_elliptical_arc_to, "-----xy"),
|
||||
"A": (self.add_elliptical_arc_to, "uuaffxy"),
|
||||
"Z": (self.close_path, ""),
|
||||
}
|
||||
|
||||
def get_original_path_string(self):
|
||||
return self.path_string
|
||||
|
||||
|
||||
class InvalidPathError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class _PathStringParser:
|
||||
# modified from https://github.com/regebro/svg.path/
|
||||
def __init__(self, arguments, rules):
|
||||
self.args = []
|
||||
arguments = bytearray(arguments, "ascii")
|
||||
self._strip_array(arguments)
|
||||
while arguments:
|
||||
for rule in rules:
|
||||
self._rule_to_function_map[rule](arguments)
|
||||
|
||||
@property
|
||||
def _rule_to_function_map(self):
|
||||
return {
|
||||
"x": self._get_number,
|
||||
"y": self._get_number,
|
||||
"a": self._get_number,
|
||||
"u": self._get_unsigned_number,
|
||||
"f": self._get_flag,
|
||||
}
|
||||
|
||||
def _strip_array(self, arg_array):
|
||||
# wsp: (0x9, 0x20, 0xA, 0xC, 0xD) with comma 0x2C
|
||||
# https://www.w3.org/TR/SVG/paths.html#PathDataBNF
|
||||
while arg_array and arg_array[0] in [0x9, 0x20, 0xA, 0xC, 0xD, 0x2C]:
|
||||
arg_array[0:1] = b""
|
||||
|
||||
def _get_number(self, arg_array):
|
||||
pattern = re.compile(rb"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
|
||||
res = pattern.search(arg_array)
|
||||
if not res:
|
||||
raise InvalidPathError(f"Expected a number, got '{arg_array}'")
|
||||
number = float(res.group())
|
||||
self.args.append(number)
|
||||
arg_array[res.start():res.end()] = b""
|
||||
self._strip_array(arg_array)
|
||||
return number
|
||||
|
||||
def _get_unsigned_number(self, arg_array):
|
||||
number = self._get_number(arg_array)
|
||||
if number < 0:
|
||||
raise InvalidPathError(f"Expected an unsigned number, got '{number}'")
|
||||
return number
|
||||
|
||||
def _get_flag(self, arg_array):
|
||||
flag = arg_array[0]
|
||||
if flag != 48 and flag != 49:
|
||||
raise InvalidPathError(f"Expected a flag (0/1), got '{chr(flag)}'")
|
||||
flag -= 48
|
||||
self.args.append(flag)
|
||||
arg_array[0:1] = b""
|
||||
self._strip_array(arg_array)
|
||||
return flag
|
||||
|
@ -382,7 +382,10 @@ class VMobject(Mobject):
|
||||
|
||||
def add_smooth_cubic_curve_to(self, handle, point):
|
||||
self.throw_error_if_no_points()
|
||||
new_handle = self.get_reflection_of_last_handle()
|
||||
if self.get_num_points() == 1:
|
||||
new_handle = self.get_points()[-1]
|
||||
else:
|
||||
new_handle = self.get_reflection_of_last_handle()
|
||||
self.add_cubic_bezier_curve_to(new_handle, handle, point)
|
||||
|
||||
def has_new_path_started(self):
|
||||
|
Reference in New Issue
Block a user