diff --git a/camera/camera.py b/camera/camera.py index 7c6d5477..0b284e85 100644 --- a/camera/camera.py +++ b/camera/camera.py @@ -177,29 +177,30 @@ class Camera(object): def capture_mobjects(self, mobjects, **kwargs): mobjects = self.get_mobjects_to_display(mobjects, **kwargs) - vmobjects = [] - for mobject in mobjects: - if isinstance(mobject, VMobject): - vmobjects.append(mobject) - elif len(vmobjects) > 0: - self.display_multiple_vectorized_mobjects(vmobjects) - vmobjects = [] - - if isinstance(mobject, PMobject): - self.display_point_cloud( - mobject.points, mobject.rgbas, - self.adjusted_thickness(mobject.stroke_width) - ) - elif isinstance(mobject, ImageMobject): - self.display_image_mobject(mobject) - elif isinstance(mobject, Mobject): - pass #Remainder of loop will handle submobjects - else: - raise Exception( - "Unknown mobject type: " + mobject.__class__.__name__ - ) - #TODO, more? Call out if it's unknown? - self.display_multiple_vectorized_mobjects(vmobjects) + + # Organize this list into batches of the same type, and + # apply corresponding function to those batches + type_func_pairs = [ + (VMobject, self.display_multiple_vectorized_mobjects), + (PMobject, self.display_multiple_point_cloud_mobjects), + (ImageMobject, self.display_multiple_image_mobjects), + (Mobject, lambda batch : batch), #Do nothing + ] + def get_mobject_type(mobject): + for mobject_type, func in type_func_pairs: + if isinstance(mobject, mobject_type): + return mobject_type + raise Exception( + "Trying to display something which is not of type Mobject" + ) + batches = batch_by_property(mobjects, get_mobject_type) + + #Display in these batches + for batch in batches: + #check what the type is, and call the appropriate function + for mobject_type, func in type_func_pairs: + if isinstance(batch[0], mobject_type): + func(batch) ## Methods associated with svg rendering @@ -312,6 +313,14 @@ class Camera(object): ## Methods for other rendering + def display_multiple_point_cloud_mobjects(self, pmobjects): + for pmobject in pmobjects: + self.display_point_cloud( + pmobject.points, + pmobject.rgbas, + self.adjusted_thickness(pmobject.stroke_width) + ) + def display_point_cloud(self, points, rgbas, thickness): if len(points) == 0: return @@ -342,6 +351,10 @@ class Camera(object): new_pa[indices] = rgbas self.pixel_array = new_pa.reshape((ph, pw, rgba_len)) + def display_multiple_image_mobjects(self, image_mobjects): + for image_mobject in image_mobjects: + self.display_image_mobject(image_mobject) + def display_image_mobject(self, image_mobject): corner_coords = self.points_to_pixel_coords(image_mobject.points) ul_coords, ur_coords, dl_coords = corner_coords