Refactor away from treating shader_info as a dictionary, and make it a proper type as ShaderWrapper. This also includes some cleanup in hos Camera renders

This commit is contained in:
Grant Sanderson
2020-06-29 18:17:18 -07:00
parent 165bf2fe6e
commit 2671817ae9
5 changed files with 224 additions and 256 deletions

View File

@ -22,10 +22,7 @@ from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose
from manimlib.utils.shaders import refresh_shader_info_id
from manimlib.utils.shaders import get_shader_info
from manimlib.utils.shaders import shader_info_to_id
from manimlib.utils.shaders import is_valid_shader_info
from manimlib.utils.shaders import ShaderWrapper
# TODO: Explain array_attrs
@ -217,6 +214,8 @@ class Mobject(Container):
setattr(copy_mobject, attr, value.copy())
if isinstance(value, np.ndarray):
setattr(copy_mobject, attr, np.array(value))
if isinstance(value, ShaderWrapper):
setattr(copy_mobject, attr, value.copy())
return copy_mobject
def deepcopy(self):
@ -235,6 +234,7 @@ class Mobject(Container):
return self.target
# Updating
def init_updaters(self):
self.time_based_updaters = []
self.non_time_updaters = []
@ -1194,7 +1194,7 @@ class Mobject(Container):
def wrapper(self):
for mob in self.get_family():
func(mob)
mob.refresh_shader_info_template_id()
mob.refresh_shader_wrapper_id()
return wrapper
@affects_shader_info_id
@ -1221,8 +1221,8 @@ class Mobject(Container):
def init_shader_data(self):
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
self.shader_indices = None
self.shader_info_template = get_shader_info(
attributes=self.shader_data.dtype.names,
self.shader_wrapper = ShaderWrapper(
vert_data=self.shader_data,
vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file,
@ -1231,8 +1231,8 @@ class Mobject(Container):
render_primative=self.render_primative,
)
def refresh_shader_info_template_id(self):
refresh_shader_info_id(self.shader_info_template)
def refresh_shader_wrapper_id(self):
self.shader_wrapper.refresh_id()
return self
def get_blank_shader_data_array(self, size, name="shader_data"):
@ -1245,41 +1245,30 @@ class Mobject(Container):
return new_arr
return arr
def get_shader_info_list(self):
shader_infos = it.chain(
[self.get_shader_info()],
*[sm.get_shader_info_list() for sm in self.submobjects]
def get_shader_wrapper(self):
self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.uniforms = self.get_shader_uniforms()
self.shader_wrapper.depth_test = self.depth_test
return self.shader_wrapper
def get_shader_wrapper_list(self):
shader_wrappers = it.chain(
[self.get_shader_wrapper()],
*[sm.get_shader_wrapper_list() for sm in self.submobjects]
)
batches = batch_by_property(shader_infos, shader_info_to_id)
batches = batch_by_property(shader_wrappers, lambda sw: sw.get_id())
result = []
for info_group, sid in batches:
combined_info = info_group[0]
if not is_valid_shader_info(combined_info):
for wrapper_group, sid in batches:
shader_wrapper = wrapper_group[0]
if not shader_wrapper.is_valid():
continue
data_list = []
indices_list = []
num_verts = 0
for info in info_group:
data_list.append(info["vert_data"])
if info["vert_indices"] is not None:
indices_list.append(info["vert_indices"] + num_verts)
num_verts += len(info["vert_data"])
# Combine lists
combined_info["vert_data"] = np.hstack(data_list)
if combined_info["vert_indices"] is not None:
combined_info["vert_indices"] = np.hstack(indices_list)
if len(combined_info["vert_indices"]) > 0:
result.append(combined_info)
shader_wrapper.combine_with(*wrapper_group[1:])
if len(shader_wrapper.vert_data) > 0:
result.append(shader_wrapper)
return result
def get_shader_info(self):
shader_info = dict(self.shader_info_template)
shader_info["vert_data"] = self.get_shader_data()
shader_info["vert_indices"] = self.get_shader_vert_indices()
shader_info["uniforms"] = self.get_shader_uniforms()
return shader_info
def get_shader_uniforms(self):
return {
"is_fixed_in_frame": float(self.is_fixed_in_frame),