Refactor of capture_mobjects to be cleaner about batching

This commit is contained in:
Grant Sanderson
2018-02-16 10:57:04 -08:00
parent df6aae36a6
commit 4236215da7

View File

@ -177,29 +177,30 @@ class Camera(object):
def capture_mobjects(self, mobjects, **kwargs): def capture_mobjects(self, mobjects, **kwargs):
mobjects = self.get_mobjects_to_display(mobjects, **kwargs) mobjects = self.get_mobjects_to_display(mobjects, **kwargs)
vmobjects = []
for mobject in mobjects: # Organize this list into batches of the same type, and
if isinstance(mobject, VMobject): # apply corresponding function to those batches
vmobjects.append(mobject) type_func_pairs = [
elif len(vmobjects) > 0: (VMobject, self.display_multiple_vectorized_mobjects),
self.display_multiple_vectorized_mobjects(vmobjects) (PMobject, self.display_multiple_point_cloud_mobjects),
vmobjects = [] (ImageMobject, self.display_multiple_image_mobjects),
(Mobject, lambda batch : batch), #Do nothing
if isinstance(mobject, PMobject): ]
self.display_point_cloud( def get_mobject_type(mobject):
mobject.points, mobject.rgbas, for mobject_type, func in type_func_pairs:
self.adjusted_thickness(mobject.stroke_width) if isinstance(mobject, mobject_type):
) return mobject_type
elif isinstance(mobject, ImageMobject): raise Exception(
self.display_image_mobject(mobject) "Trying to display something which is not of type Mobject"
elif isinstance(mobject, Mobject): )
pass #Remainder of loop will handle submobjects batches = batch_by_property(mobjects, get_mobject_type)
else:
raise Exception( #Display in these batches
"Unknown mobject type: " + mobject.__class__.__name__ for batch in batches:
) #check what the type is, and call the appropriate function
#TODO, more? Call out if it's unknown? for mobject_type, func in type_func_pairs:
self.display_multiple_vectorized_mobjects(vmobjects) if isinstance(batch[0], mobject_type):
func(batch)
## Methods associated with svg rendering ## Methods associated with svg rendering
@ -312,6 +313,14 @@ class Camera(object):
## Methods for other rendering ## Methods for other rendering
def display_multiple_point_cloud_mobjects(self, pmobjects):
for pmobject in pmobjects:
self.display_point_cloud(
pmobject.points,
pmobject.rgbas,
self.adjusted_thickness(pmobject.stroke_width)
)
def display_point_cloud(self, points, rgbas, thickness): def display_point_cloud(self, points, rgbas, thickness):
if len(points) == 0: if len(points) == 0:
return return
@ -342,6 +351,10 @@ class Camera(object):
new_pa[indices] = rgbas new_pa[indices] = rgbas
self.pixel_array = new_pa.reshape((ph, pw, rgba_len)) self.pixel_array = new_pa.reshape((ph, pw, rgba_len))
def display_multiple_image_mobjects(self, image_mobjects):
for image_mobject in image_mobjects:
self.display_image_mobject(image_mobject)
def display_image_mobject(self, image_mobject): def display_image_mobject(self, image_mobject):
corner_coords = self.points_to_pixel_coords(image_mobject.points) corner_coords = self.points_to_pixel_coords(image_mobject.points)
ul_coords, ur_coords, dl_coords = corner_coords ul_coords, ur_coords, dl_coords = corner_coords