diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 1050c3ee..b1621c4f 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -4,6 +4,7 @@ import random import sys import moderngl from functools import wraps +from collections import Iterable import numpy as np @@ -596,7 +597,10 @@ class Mobject(object): Otherwise, if about_point is given a value, scaling is done with respect to that point. """ - scale_factor = max(scale_factor, min_scale_factor) + if isinstance(scale_factor, Iterable): + scale_factor = np.array(scale_factor).clip(min=min_scale_factor) + else: + scale_factor = max(scale_factor, min_scale_factor) self.apply_points_function( lambda points: scale_factor * points, about_point=about_point, diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 3a5260a7..82afa227 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -1,14 +1,13 @@ import itertools as it import re import string -import warnings import os import hashlib from xml.dom import minidom from manimlib.constants import DEFAULT_STROKE_WIDTH -from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT +from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN from manimlib.constants import BLACK from manimlib.constants import WHITE from manimlib.constants import DEGREES, PI @@ -23,6 +22,7 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_mobject_data_dir from manimlib.utils.images import get_full_vector_image_path from manimlib.utils.simple_functions import clip +from manimlib.logger import log def string_to_numbers(num_string): @@ -71,8 +71,10 @@ class SVGMobject(VMobject): doc = minidom.parse(self.file_path) self.ref_to_element = {} - for svg in doc.getElementsByTagName("svg"): - mobjects = self.get_mobjects_from(svg) + for child in doc.childNodes: + if not isinstance(child, minidom.Element): continue + if child.tagName != 'svg': continue + mobjects = self.get_mobjects_from(child) if self.unpack_groups: self.add(*mobjects) else: @@ -107,8 +109,8 @@ class SVGMobject(VMobject): elif element.tagName in ['polygon', 'polyline']: result.append(self.polygon_to_mobject(element)) else: + log.warning(f"Unsupported element type: {element.tagName}") pass # TODO - # warnings.warn("Unknown element type: " + element.tagName) result = [m for m in result if m is not None] self.handle_transforms(element, VGroup(*result)) if len(result) > 1 and not self.unpack_groups: @@ -131,7 +133,7 @@ class SVGMobject(VMobject): # Remove initial "#" character ref = use_element.getAttribute("xlink:href")[1:] if ref not in self.ref_to_element: - warnings.warn(f"{ref} not recognized") + log.warning(f"{ref} not recognized") return VGroup() return self.get_mobjects_from( self.ref_to_element[ref] @@ -227,7 +229,7 @@ class SVGMobject(VMobject): stroke_width=stroke_width, stroke_color=stroke_color, fill_color=fill_color, - fill_opacity=opacity, + fill_opacity=fill_opacity, corner_radius=corner_radius ) @@ -235,66 +237,94 @@ class SVGMobject(VMobject): return mob def handle_transforms(self, element, mobject): - # TODO, this could use some cleaning... - x, y = 0, 0 - try: - x = self.attribute_to_float(element.getAttribute('x')) - # Flip y - y = -self.attribute_to_float(element.getAttribute('y')) - mobject.shift([x, y, 0]) - except Exception: - pass + x, y = ( + self.attribute_to_float(element.getAttribute(key)) + if element.hasAttribute(key) + else 0.0 + for key in ("x", "y") + ) + mobject.shift(x * RIGHT + y * DOWN) - transform = element.getAttribute('transform') + transform_names = [ + "matrix", + "translate", "translateX", "translateY", + "scale", "scaleX", "scaleY", + "rotate", + "skewX", "skewY" + ] + transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names])) + number_pattern = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?") + transforms = transform_pattern.findall(element.getAttribute('transform'))[::-1] - try: # transform matrix - prefix = "matrix(" - suffix = ")" - if not transform.startswith(prefix) or not transform.endswith(suffix): - raise Exception() - transform = transform[len(prefix):-len(suffix)] - transform = string_to_numbers(transform) - transform = np.array(transform).reshape([3, 2]) - x = transform[2][0] - y = -transform[2][1] - matrix = np.identity(self.dim) - matrix[:2, :2] = transform[:2, :] - matrix[1] *= -1 - matrix[:, 1] *= -1 + for transform in transforms: + op_name, op_args = transform.split("(") + op_name = op_name.strip() + op_args = [float(x) for x in number_pattern.findall(op_args)] + + if op_name == "matrix": + self._handle_matrix_transform(mobject, op_name, op_args) + elif op_name.startswith("translate"): + self._handle_translate_transform(mobject, op_name, op_args) + elif op_name.startswith("scale"): + self._handle_scale_transform(mobject, op_name, op_args) + elif op_name == "rotate": + self._handle_rotate_transform(mobject, op_name, op_args) + elif op_name.startswith("skew"): + self._handle_skew_transform(mobject, op_name, op_args) + + def _handle_matrix_transform(self, mobject, op_name, op_args): + transform = np.array(op_args).reshape([3, 2]) + x = transform[2][0] + y = -transform[2][1] + matrix = np.identity(self.dim) + matrix[:2, :2] = transform[:2, :] + matrix[1] *= -1 + matrix[:, 1] *= -1 + for mob in mobject.family_members_with_points(): + mob.apply_matrix(matrix.T) + mobject.shift(x * RIGHT + y * UP) - for mob in mobject.family_members_with_points(): - mob.apply_matrix(matrix.T) - mobject.shift(x * RIGHT + y * UP) - except: - pass - - try: # transform scale - prefix = "scale(" - suffix = ")" - if not transform.startswith(prefix) or not transform.endswith(suffix): - raise Exception() - transform = transform[len(prefix):-len(suffix)] - scale_values = string_to_numbers(transform) - if len(scale_values) == 2: - scale_x, scale_y = scale_values - mobject.scale(np.array([scale_x, scale_y, 1]), about_point=ORIGIN) - elif len(scale_values) == 1: - scale = scale_values[0] - mobject.scale(np.array([scale, scale, 1]), about_point=ORIGIN) - except: - pass - - try: # transform translate - prefix = "translate(" - suffix = ")" - if not transform.startswith(prefix) or not transform.endswith(suffix): - raise Exception() - transform = transform[len(prefix):-len(suffix)] - x, y = string_to_numbers(transform) - mobject.shift(x * RIGHT + y * DOWN) - except: - pass - # TODO, ... + def _handle_translate_transform(self, mobject, op_name, op_args): + if op_name.endswith("X"): + x, y = op_args[0], 0 + elif op_name.endswith("Y"): + x, y = 0, op_args[0] + else: + x, y = op_args + mobject.shift(x * RIGHT + y * DOWN) + + def _handle_scale_transform(self, mobject, op_name, op_args): + if op_name.endswith("X"): + sx, sy = op_args[0], 1 + elif op_name.endswith("Y"): + sx, sy = 1, op_args[0] + elif len(op_args) == 2: + sx, sy = op_args + else: + sx = sy = op_args[0] + if sx < 0: + mobject.flip(UP) + sx = -sx + if sy < 0: + mobject.flip(RIGHT) + sy = -sy + mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN) + + def _handle_rotate_transform(self, mobject, op_name, op_args): + if len(op_args) == 1: + mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN) + else: + deg, x, y = op_args + mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, y, 0])) + + def _handle_skew_transform(self, mobject, op_name, op_args): + rad = op_args[0] * DEGREES + if op_name == "skewX": + tana = np.tan(rad) + self._handle_matrix_transform(mobject, None, [1., 0., tana, 1., 0., 0.]) + elif op_name == "skewY": + tana = np.tan(rad) + self._handle_matrix_transform(mobject, None, [1., tana, 0., 1., 0., 0.]) def flatten(self, input_list): output_list = [] @@ -378,7 +408,8 @@ class VMobjectFromSVGPathstring(VMobject): number_types = np.array(list(number_types_str)) n_numbers = len(number_types_str) - number_groups = np.array(string_to_numbers(coord_string)).reshape((-1, n_numbers)) + number_list = _PathStringParser(coord_string, number_types_str).args + number_groups = np.array(number_list).reshape((-1, n_numbers)) for numbers in number_groups: if command.islower(): @@ -520,9 +551,67 @@ class VMobjectFromSVGPathstring(VMobject): "S": (self.add_smooth_cubic_curve_to, "xyxy"), "Q": (self.add_quadratic_bezier_curve_to, "xyxy"), "T": (self.add_smooth_curve_to, "xy"), - "A": (self.add_elliptical_arc_to, "-----xy"), + "A": (self.add_elliptical_arc_to, "uuaffxy"), "Z": (self.close_path, ""), } def get_original_path_string(self): return self.path_string + + +class InvalidPathError(ValueError): + pass + + +class _PathStringParser: + # modified from https://github.com/regebro/svg.path/ + def __init__(self, arguments, rules): + self.args = [] + arguments = bytearray(arguments, "ascii") + self._strip_array(arguments) + while arguments: + for rule in rules: + self._rule_to_function_map[rule](arguments) + + @property + def _rule_to_function_map(self): + return { + "x": self._get_number, + "y": self._get_number, + "a": self._get_number, + "u": self._get_unsigned_number, + "f": self._get_flag, + } + + def _strip_array(self, arg_array): + # wsp: (0x9, 0x20, 0xA, 0xC, 0xD) with comma 0x2C + # https://www.w3.org/TR/SVG/paths.html#PathDataBNF + while arg_array and arg_array[0] in [0x9, 0x20, 0xA, 0xC, 0xD, 0x2C]: + arg_array[0:1] = b"" + + def _get_number(self, arg_array): + pattern = re.compile(rb"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?") + res = pattern.search(arg_array) + if not res: + raise InvalidPathError(f"Expected a number, got '{arg_array}'") + number = float(res.group()) + self.args.append(number) + arg_array[res.start():res.end()] = b"" + self._strip_array(arg_array) + return number + + def _get_unsigned_number(self, arg_array): + number = self._get_number(arg_array) + if number < 0: + raise InvalidPathError(f"Expected an unsigned number, got '{number}'") + return number + + def _get_flag(self, arg_array): + flag = arg_array[0] + if flag != 48 and flag != 49: + raise InvalidPathError(f"Expected a flag (0/1), got '{chr(flag)}'") + flag -= 48 + self.args.append(flag) + arg_array[0:1] = b"" + self._strip_array(arg_array) + return flag diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index a63a190b..3c7a4326 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -382,7 +382,10 @@ class VMobject(Mobject): def add_smooth_cubic_curve_to(self, handle, point): self.throw_error_if_no_points() - new_handle = self.get_reflection_of_last_handle() + if self.get_num_points() == 1: + new_handle = self.get_points()[-1] + else: + new_handle = self.get_reflection_of_last_handle() self.add_cubic_bezier_curve_to(new_handle, handle, point) def has_new_path_started(self):