Refactor of capture_mobjects to be cleaner about batching

This commit is contained in:
Grant Sanderson
2018-02-16 10:57:04 -08:00
parent df6aae36a6
commit 4236215da7

View File

@ -177,29 +177,30 @@ class Camera(object):
def capture_mobjects(self, mobjects, **kwargs):
mobjects = self.get_mobjects_to_display(mobjects, **kwargs)
vmobjects = []
for mobject in mobjects:
if isinstance(mobject, VMobject):
vmobjects.append(mobject)
elif len(vmobjects) > 0:
self.display_multiple_vectorized_mobjects(vmobjects)
vmobjects = []
if isinstance(mobject, PMobject):
self.display_point_cloud(
mobject.points, mobject.rgbas,
self.adjusted_thickness(mobject.stroke_width)
)
elif isinstance(mobject, ImageMobject):
self.display_image_mobject(mobject)
elif isinstance(mobject, Mobject):
pass #Remainder of loop will handle submobjects
else:
raise Exception(
"Unknown mobject type: " + mobject.__class__.__name__
)
#TODO, more? Call out if it's unknown?
self.display_multiple_vectorized_mobjects(vmobjects)
# Organize this list into batches of the same type, and
# apply corresponding function to those batches
type_func_pairs = [
(VMobject, self.display_multiple_vectorized_mobjects),
(PMobject, self.display_multiple_point_cloud_mobjects),
(ImageMobject, self.display_multiple_image_mobjects),
(Mobject, lambda batch : batch), #Do nothing
]
def get_mobject_type(mobject):
for mobject_type, func in type_func_pairs:
if isinstance(mobject, mobject_type):
return mobject_type
raise Exception(
"Trying to display something which is not of type Mobject"
)
batches = batch_by_property(mobjects, get_mobject_type)
#Display in these batches
for batch in batches:
#check what the type is, and call the appropriate function
for mobject_type, func in type_func_pairs:
if isinstance(batch[0], mobject_type):
func(batch)
## Methods associated with svg rendering
@ -312,6 +313,14 @@ class Camera(object):
## Methods for other rendering
def display_multiple_point_cloud_mobjects(self, pmobjects):
for pmobject in pmobjects:
self.display_point_cloud(
pmobject.points,
pmobject.rgbas,
self.adjusted_thickness(pmobject.stroke_width)
)
def display_point_cloud(self, points, rgbas, thickness):
if len(points) == 0:
return
@ -342,6 +351,10 @@ class Camera(object):
new_pa[indices] = rgbas
self.pixel_array = new_pa.reshape((ph, pw, rgba_len))
def display_multiple_image_mobjects(self, image_mobjects):
for image_mobject in image_mobjects:
self.display_image_mobject(image_mobject)
def display_image_mobject(self, image_mobject):
corner_coords = self.points_to_pixel_coords(image_mobject.points)
ul_coords, ur_coords, dl_coords = corner_coords