mirror of
https://github.com/3b1b/manim.git
synced 2025-08-02 19:46:21 +08:00
Refactor svg reading
This commit is contained in:
@ -34,9 +34,9 @@ class SVGMobject(VMobject):
|
||||
# Must be filled in in a subclass, or when called
|
||||
"file_name": None,
|
||||
"unpack_groups": True, # if False, creates a hierarchy of VGroups
|
||||
# TODO, style components should be read in, not defaulted
|
||||
"stroke_width": DEFAULT_STROKE_WIDTH,
|
||||
"fill_opacity": 1.0,
|
||||
# "fill_color" : LIGHT_GREY,
|
||||
}
|
||||
|
||||
def __init__(self, file_name=None, **kwargs):
|
||||
@ -47,24 +47,25 @@ class SVGMobject(VMobject):
|
||||
self.move_into_position()
|
||||
|
||||
def ensure_valid_file(self):
|
||||
if self.file_name is None:
|
||||
file_name = self.file_name
|
||||
if file_name is None:
|
||||
raise Exception("Must specify file for SVGMobject")
|
||||
possible_paths = [
|
||||
os.path.join(os.path.join("assets", "svg_images"), self.file_name),
|
||||
os.path.join(os.path.join("assets", "svg_images"), self.file_name + ".svg"),
|
||||
os.path.join(os.path.join("assets", "svg_images"), self.file_name + ".xdv"),
|
||||
self.file_name,
|
||||
os.path.join(os.path.join("assets", "svg_images"), file_name),
|
||||
os.path.join(os.path.join("assets", "svg_images"), file_name + ".svg"),
|
||||
os.path.join(os.path.join("assets", "svg_images"), file_name + ".xdv"),
|
||||
file_name,
|
||||
]
|
||||
for path in possible_paths:
|
||||
if os.path.exists(path):
|
||||
self.file_path = path
|
||||
return
|
||||
raise IOError("No file matching %s in image directory" %
|
||||
self.file_name)
|
||||
raise IOError(f"No file matching {file_name} in image directory")
|
||||
|
||||
def generate_points(self):
|
||||
doc = minidom.parse(self.file_path)
|
||||
self.ref_to_element = {}
|
||||
|
||||
for svg in doc.getElementsByTagName("svg"):
|
||||
mobjects = self.get_mobjects_from(svg)
|
||||
if self.unpack_groups:
|
||||
@ -122,7 +123,7 @@ class SVGMobject(VMobject):
|
||||
# Remove initial "#" character
|
||||
ref = use_element.getAttribute("xlink:href")[1:]
|
||||
if ref not in self.ref_to_element:
|
||||
warnings.warn("%s not recognized" % ref)
|
||||
warnings.warn(f"{ref} not recognized")
|
||||
return VGroup()
|
||||
return self.get_mobjects_from(
|
||||
self.ref_to_element[ref]
|
||||
@ -136,15 +137,12 @@ class SVGMobject(VMobject):
|
||||
return float(stripped_attr)
|
||||
|
||||
def polygon_to_mobject(self, polygon_element):
|
||||
# TODO, This seems hacky...
|
||||
path_string = polygon_element.getAttribute("points")
|
||||
for digit in string.digits:
|
||||
path_string = path_string.replace(" " + digit, " L" + digit)
|
||||
path_string = path_string.replace(f" {digit}", f"{digit} L")
|
||||
path_string = "M" + path_string
|
||||
return self.path_string_to_mobject(path_string)
|
||||
|
||||
# <circle class="st1" cx="143.8" cy="268" r="22.6"/>
|
||||
|
||||
def circle_to_mobject(self, circle_element):
|
||||
x, y, r = [
|
||||
self.attribute_to_float(
|
||||
@ -321,111 +319,76 @@ class VMobjectFromSVGPathstring(VMobject):
|
||||
digest_locals(self)
|
||||
VMobject.__init__(self, **kwargs)
|
||||
|
||||
def get_path_commands(self):
|
||||
result = [
|
||||
"M", # moveto
|
||||
"L", # lineto
|
||||
"H", # horizontal lineto
|
||||
"V", # vertical lineto
|
||||
"C", # curveto
|
||||
"S", # smooth curveto
|
||||
"Q", # quadratic Bezier curve
|
||||
"T", # smooth quadratic Bezier curveto
|
||||
"A", # elliptical Arc
|
||||
"Z", # closepath
|
||||
]
|
||||
result += [s.lower() for s in result]
|
||||
return result
|
||||
|
||||
def generate_points(self):
|
||||
pattern = "[%s]" % ("".join(self.get_path_commands()))
|
||||
pairs = list(zip(
|
||||
self.relative_point = ORIGIN
|
||||
for command, coord_string in self.get_commands_and_coord_strings():
|
||||
new_points = self.string_to_points(command, coord_string)
|
||||
self.handle_command(command, new_points)
|
||||
# SVG treats y-coordinate differently
|
||||
self.stretch(-1, 1, about_point=ORIGIN)
|
||||
|
||||
def get_commands_and_coord_strings(self):
|
||||
all_commands = list(self.get_command_to_function_map().keys())
|
||||
all_commands += [c.lower() for c in all_commands]
|
||||
pattern = "[{}]".format("".join(all_commands))
|
||||
return zip(
|
||||
re.findall(pattern, self.path_string),
|
||||
re.split(pattern, self.path_string)[1:]
|
||||
))
|
||||
# Which mobject should new points be added to
|
||||
self = self
|
||||
for command, coord_string in pairs:
|
||||
self.handle_command(command, coord_string)
|
||||
# people treat y-coordinate differently
|
||||
self.rotate(np.pi, RIGHT, about_point=ORIGIN)
|
||||
)
|
||||
|
||||
def handle_command(self, command, coord_string):
|
||||
isLower = command.islower()
|
||||
command = command.upper()
|
||||
# new_points are the points that will be added to the curr_points
|
||||
# list. This variable may get modified in the conditionals below.
|
||||
points = self.points
|
||||
new_points = self.string_to_points(coord_string)
|
||||
def handle_command(self, command, new_points):
|
||||
if command.islower(): # Treat it as a relative command
|
||||
new_points += self.relative_point
|
||||
|
||||
if isLower and len(points) > 0:
|
||||
new_points += points[-1]
|
||||
func, n_points = self.command_to_function(command)
|
||||
func(*new_points[:n_points])
|
||||
leftover_points = new_points[n_points:]
|
||||
|
||||
if command == "M": # moveto
|
||||
self.start_new_path(new_points[0])
|
||||
if len(new_points) <= 1:
|
||||
return
|
||||
# Recursively handle the rest of the points
|
||||
if len(leftover_points) > 0:
|
||||
if command.upper() == "M":
|
||||
command = "l" # Treat following points as relative line coordinates
|
||||
self.handle_command(command, leftover_points)
|
||||
else:
|
||||
# Command is over, reset for future relative commands
|
||||
self.relative_point = self.points[-1]
|
||||
|
||||
# Draw relative line-to values.
|
||||
points = self.points
|
||||
new_points = new_points[1:]
|
||||
command = "L"
|
||||
|
||||
for p in new_points:
|
||||
if isLower:
|
||||
# Treat everything as relative line-to until empty
|
||||
p[0] += self.points[-1, 0]
|
||||
p[1] += self.points[-1, 1]
|
||||
self.add_line_to(p)
|
||||
return
|
||||
|
||||
elif command in ["L", "H", "V"]: # lineto
|
||||
if command == "H":
|
||||
new_points[0, 1] = points[-1, 1]
|
||||
elif command == "V":
|
||||
if isLower:
|
||||
new_points[0, 0] -= points[-1, 0]
|
||||
new_points[0, 0] += points[-1, 1]
|
||||
new_points[0, 1] = new_points[0, 0]
|
||||
new_points[0, 0] = points[-1, 0]
|
||||
self.add_line_to(new_points[0])
|
||||
return
|
||||
|
||||
if command == "C": # curveto
|
||||
pass # Yay! No action required
|
||||
elif command in ["S", "T"]: # smooth curveto
|
||||
self.add_smooth_curve_to(*new_points)
|
||||
# handle1 = points[-1] + (points[-1] - points[-2])
|
||||
# new_points = np.append([handle1], new_points, axis=0)
|
||||
return
|
||||
elif command == "Q": # quadratic Bezier curve
|
||||
# TODO, this is a suboptimal approximation
|
||||
new_points = np.append([new_points[0]], new_points, axis=0)
|
||||
elif command == "A": # elliptical Arc
|
||||
raise Exception("Not implemented")
|
||||
elif command == "Z": # closepath
|
||||
return
|
||||
|
||||
# Add first three points
|
||||
self.add_cubic_bezier_curve_to(*new_points[0:3])
|
||||
|
||||
# Handle situations where there's multiple relative control points
|
||||
if len(new_points) > 3:
|
||||
# Add subsequent offset points relatively.
|
||||
for i in range(3, len(new_points), 3):
|
||||
if isLower:
|
||||
new_points[i:i + 3] -= points[-1]
|
||||
new_points[i:i + 3] += new_points[i - 1]
|
||||
self.add_cubic_bezier_curve_to(*new_points[i:i+3])
|
||||
|
||||
def string_to_points(self, coord_string):
|
||||
def string_to_points(self, command, coord_string):
|
||||
numbers = string_to_numbers(coord_string)
|
||||
if len(numbers) % 2 == 1:
|
||||
numbers.append(0)
|
||||
num_points = len(numbers) // 2
|
||||
result = np.zeros((num_points, self.dim))
|
||||
result[:, :2] = np.array(numbers).reshape((num_points, 2))
|
||||
if command.upper() in ["H", "V"]:
|
||||
i = {"H": 0, "V": 1}[command.upper()]
|
||||
xy = np.zeros((len(numbers), 2))
|
||||
xy[:, i] = numbers
|
||||
if command.isupper():
|
||||
xy[:, 1 - i] = self.relative_point[1 - i]
|
||||
elif command.upper() == "A":
|
||||
raise Exception("Not implemented")
|
||||
else:
|
||||
xy = np.array(numbers).reshape((len(numbers) // 2, 2))
|
||||
result = np.zeros((xy.shape[0], self.dim))
|
||||
result[:, :2] = xy
|
||||
return result
|
||||
|
||||
def command_to_function(self, command):
|
||||
return self.get_command_to_function_map()[command.upper()]
|
||||
|
||||
def get_command_to_function_map(self):
|
||||
"""
|
||||
Associates svg command to VMobject function, and
|
||||
the number of arguments it takes in
|
||||
"""
|
||||
return {
|
||||
"M": (self.start_new_path, 1),
|
||||
"L": (self.add_line_to, 1),
|
||||
"H": (self.add_line_to, 1),
|
||||
"V": (self.add_line_to, 1),
|
||||
"C": (self.add_cubic_bezier_curve_to, 3),
|
||||
"S": (self.add_smooth_cubic_curve_to, 2),
|
||||
"Q": (self.add_quadratic_bezier_curve_to, 2),
|
||||
"T": (self.add_smooth_curve_to, 1),
|
||||
"A": (self.add_quadratic_bezier_curve_to, 2), # TODO
|
||||
"Z": (lambda: self.add_line_to(self.points[0]), 0),
|
||||
}
|
||||
|
||||
def get_original_path_string(self):
|
||||
return self.path_string
|
||||
|
Reference in New Issue
Block a user