Initial implementation of render groups in Scene

This commit is contained in:
Grant Sanderson
2023-01-28 10:11:10 -08:00
parent fc379dab18
commit 8a18967ea4

View File

@ -7,6 +7,7 @@ import platform
import pyperclip
import random
import time
from functools import wraps
from IPython.terminal import pt_inputhooks
from IPython.terminal.embed import InteractiveShellEmbed
@ -37,6 +38,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.scene.scene_file_writer import SceneFileWriter
from manimlib.utils.family_ops import extract_mobject_family_members
from manimlib.utils.family_ops import recursive_mobject_remove
from manimlib.utils.iterables import batch_by_property
from typing import TYPE_CHECKING
@ -110,6 +112,7 @@ class Scene(object):
self.camera: Camera = Camera(**self.camera_config)
self.file_writer = SceneFileWriter(self, **self.file_writer_config)
self.mobjects: list[Mobject] = [self.camera.frame]
self.render_groups: list[Mobject] = []
self.id_to_mobject_map: dict[int, Mobject] = dict()
self.num_plays: int = 0
self.time: float = 0
@ -289,7 +292,7 @@ class Scene(object):
def get_image(self) -> Image:
if self.window is not None:
self.camera.use_window_fbo(False)
self.camera.capture(*self.mobjects)
self.camera.capture(*self.render_groups)
image = self.camera.get_image()
if self.window is not None:
self.camera.use_window_fbo(True)
@ -310,7 +313,7 @@ class Scene(object):
if self.window:
self.window.clear()
self.camera.capture(*self.mobjects)
self.camera.capture(*self.render_groups)
if self.window:
self.window.swap_buffers()
@ -369,6 +372,34 @@ class Scene(object):
def get_mobject_family_members(self) -> list[Mobject]:
return extract_mobject_family_members(self.mobjects)
def assemble_render_groups(self):
"""
Rendering is more efficient when VMobjects are grouped
together, so this function creates VGroups of all
clusters of adjacent VMobjects in the scene's mobject
list.
"""
for group in self.render_groups:
group.clear()
self.render_groups = []
batches = batch_by_property(
self.mobjects,
lambda m: str(m.get_uniforms()) + str(m.apply_depth_test)
)
self.render_groups = [
batch[0].get_group_class()(*batch)
for batch, key in batches
]
def affects_mobject_list(func: Callable):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self.assemble_render_groups()
return self
return wrapper
@affects_mobject_list
def add(self, *new_mobjects: Mobject):
"""
Mobjects will be displayed, from background to
@ -395,6 +426,7 @@ class Scene(object):
))
return self
@affects_mobject_list
def replace(self, mobject: Mobject, *replacements: Mobject):
if mobject in self.mobjects:
index = self.mobjects.index(mobject)
@ -405,6 +437,7 @@ class Scene(object):
]
return self
@affects_mobject_list
def remove(self, *mobjects_to_remove: Mobject):
"""
Removes anything in mobjects from scenes mobject list, but in the event that one
@ -422,11 +455,13 @@ class Scene(object):
self.add(*mobjects)
return self
@affects_mobject_list
def bring_to_back(self, *mobjects: Mobject):
self.remove(*mobjects)
self.mobjects = list(mobjects) + self.mobjects
return self
@affects_mobject_list
def clear(self):
self.mobjects = []
return self