Refactor svg reading

This commit is contained in:
Grant Sanderson
2020-02-06 10:02:42 -08:00
parent 8c07fcca24
commit ccef2485b2

View File

@ -34,9 +34,9 @@ class SVGMobject(VMobject):
# Must be filled in in a subclass, or when called
"file_name": None,
"unpack_groups": True, # if False, creates a hierarchy of VGroups
# TODO, style components should be read in, not defaulted
"stroke_width": DEFAULT_STROKE_WIDTH,
"fill_opacity": 1.0,
# "fill_color" : LIGHT_GREY,
}
def __init__(self, file_name=None, **kwargs):
@ -47,24 +47,25 @@ class SVGMobject(VMobject):
self.move_into_position()
def ensure_valid_file(self):
if self.file_name is None:
file_name = self.file_name
if file_name is None:
raise Exception("Must specify file for SVGMobject")
possible_paths = [
os.path.join(os.path.join("assets", "svg_images"), self.file_name),
os.path.join(os.path.join("assets", "svg_images"), self.file_name + ".svg"),
os.path.join(os.path.join("assets", "svg_images"), self.file_name + ".xdv"),
self.file_name,
os.path.join(os.path.join("assets", "svg_images"), file_name),
os.path.join(os.path.join("assets", "svg_images"), file_name + ".svg"),
os.path.join(os.path.join("assets", "svg_images"), file_name + ".xdv"),
file_name,
]
for path in possible_paths:
if os.path.exists(path):
self.file_path = path
return
raise IOError("No file matching %s in image directory" %
self.file_name)
raise IOError(f"No file matching {file_name} in image directory")
def generate_points(self):
doc = minidom.parse(self.file_path)
self.ref_to_element = {}
for svg in doc.getElementsByTagName("svg"):
mobjects = self.get_mobjects_from(svg)
if self.unpack_groups:
@ -122,7 +123,7 @@ class SVGMobject(VMobject):
# Remove initial "#" character
ref = use_element.getAttribute("xlink:href")[1:]
if ref not in self.ref_to_element:
warnings.warn("%s not recognized" % ref)
warnings.warn(f"{ref} not recognized")
return VGroup()
return self.get_mobjects_from(
self.ref_to_element[ref]
@ -136,15 +137,12 @@ class SVGMobject(VMobject):
return float(stripped_attr)
def polygon_to_mobject(self, polygon_element):
# TODO, This seems hacky...
path_string = polygon_element.getAttribute("points")
for digit in string.digits:
path_string = path_string.replace(" " + digit, " L" + digit)
path_string = path_string.replace(f" {digit}", f"{digit} L")
path_string = "M" + path_string
return self.path_string_to_mobject(path_string)
# <circle class="st1" cx="143.8" cy="268" r="22.6"/>
def circle_to_mobject(self, circle_element):
x, y, r = [
self.attribute_to_float(
@ -321,111 +319,76 @@ class VMobjectFromSVGPathstring(VMobject):
digest_locals(self)
VMobject.__init__(self, **kwargs)
def get_path_commands(self):
result = [
"M", # moveto
"L", # lineto
"H", # horizontal lineto
"V", # vertical lineto
"C", # curveto
"S", # smooth curveto
"Q", # quadratic Bezier curve
"T", # smooth quadratic Bezier curveto
"A", # elliptical Arc
"Z", # closepath
]
result += [s.lower() for s in result]
return result
def generate_points(self):
pattern = "[%s]" % ("".join(self.get_path_commands()))
pairs = list(zip(
self.relative_point = ORIGIN
for command, coord_string in self.get_commands_and_coord_strings():
new_points = self.string_to_points(command, coord_string)
self.handle_command(command, new_points)
# SVG treats y-coordinate differently
self.stretch(-1, 1, about_point=ORIGIN)
def get_commands_and_coord_strings(self):
all_commands = list(self.get_command_to_function_map().keys())
all_commands += [c.lower() for c in all_commands]
pattern = "[{}]".format("".join(all_commands))
return zip(
re.findall(pattern, self.path_string),
re.split(pattern, self.path_string)[1:]
))
# Which mobject should new points be added to
self = self
for command, coord_string in pairs:
self.handle_command(command, coord_string)
# people treat y-coordinate differently
self.rotate(np.pi, RIGHT, about_point=ORIGIN)
)
def handle_command(self, command, coord_string):
isLower = command.islower()
command = command.upper()
# new_points are the points that will be added to the curr_points
# list. This variable may get modified in the conditionals below.
points = self.points
new_points = self.string_to_points(coord_string)
def handle_command(self, command, new_points):
if command.islower(): # Treat it as a relative command
new_points += self.relative_point
if isLower and len(points) > 0:
new_points += points[-1]
func, n_points = self.command_to_function(command)
func(*new_points[:n_points])
leftover_points = new_points[n_points:]
if command == "M": # moveto
self.start_new_path(new_points[0])
if len(new_points) <= 1:
return
# Recursively handle the rest of the points
if len(leftover_points) > 0:
if command.upper() == "M":
command = "l" # Treat following points as relative line coordinates
self.handle_command(command, leftover_points)
else:
# Command is over, reset for future relative commands
self.relative_point = self.points[-1]
# Draw relative line-to values.
points = self.points
new_points = new_points[1:]
command = "L"
for p in new_points:
if isLower:
# Treat everything as relative line-to until empty
p[0] += self.points[-1, 0]
p[1] += self.points[-1, 1]
self.add_line_to(p)
return
elif command in ["L", "H", "V"]: # lineto
if command == "H":
new_points[0, 1] = points[-1, 1]
elif command == "V":
if isLower:
new_points[0, 0] -= points[-1, 0]
new_points[0, 0] += points[-1, 1]
new_points[0, 1] = new_points[0, 0]
new_points[0, 0] = points[-1, 0]
self.add_line_to(new_points[0])
return
if command == "C": # curveto
pass # Yay! No action required
elif command in ["S", "T"]: # smooth curveto
self.add_smooth_curve_to(*new_points)
# handle1 = points[-1] + (points[-1] - points[-2])
# new_points = np.append([handle1], new_points, axis=0)
return
elif command == "Q": # quadratic Bezier curve
# TODO, this is a suboptimal approximation
new_points = np.append([new_points[0]], new_points, axis=0)
elif command == "A": # elliptical Arc
raise Exception("Not implemented")
elif command == "Z": # closepath
return
# Add first three points
self.add_cubic_bezier_curve_to(*new_points[0:3])
# Handle situations where there's multiple relative control points
if len(new_points) > 3:
# Add subsequent offset points relatively.
for i in range(3, len(new_points), 3):
if isLower:
new_points[i:i + 3] -= points[-1]
new_points[i:i + 3] += new_points[i - 1]
self.add_cubic_bezier_curve_to(*new_points[i:i+3])
def string_to_points(self, coord_string):
def string_to_points(self, command, coord_string):
numbers = string_to_numbers(coord_string)
if len(numbers) % 2 == 1:
numbers.append(0)
num_points = len(numbers) // 2
result = np.zeros((num_points, self.dim))
result[:, :2] = np.array(numbers).reshape((num_points, 2))
if command.upper() in ["H", "V"]:
i = {"H": 0, "V": 1}[command.upper()]
xy = np.zeros((len(numbers), 2))
xy[:, i] = numbers
if command.isupper():
xy[:, 1 - i] = self.relative_point[1 - i]
elif command.upper() == "A":
raise Exception("Not implemented")
else:
xy = np.array(numbers).reshape((len(numbers) // 2, 2))
result = np.zeros((xy.shape[0], self.dim))
result[:, :2] = xy
return result
def command_to_function(self, command):
return self.get_command_to_function_map()[command.upper()]
def get_command_to_function_map(self):
"""
Associates svg command to VMobject function, and
the number of arguments it takes in
"""
return {
"M": (self.start_new_path, 1),
"L": (self.add_line_to, 1),
"H": (self.add_line_to, 1),
"V": (self.add_line_to, 1),
"C": (self.add_cubic_bezier_curve_to, 3),
"S": (self.add_smooth_cubic_curve_to, 2),
"Q": (self.add_quadratic_bezier_curve_to, 2),
"T": (self.add_smooth_curve_to, 1),
"A": (self.add_quadratic_bezier_curve_to, 2), # TODO
"Z": (lambda: self.add_line_to(self.points[0]), 0),
}
def get_original_path_string(self):
return self.path_string