Add refresh_shader_info_id insetead of having create_shader_info_id called all the time

This commit is contained in:
Grant Sanderson
2020-06-27 00:01:45 -07:00
parent 26ce1d86ab
commit 10c6bfe3ad
3 changed files with 44 additions and 29 deletions

View File

@ -22,7 +22,7 @@ from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose
from manimlib.utils.shaders import create_shader_info_id
from manimlib.utils.shaders import refresh_shader_info_id
from manimlib.utils.shaders import get_shader_info
from manimlib.utils.shaders import shader_info_to_id
from manimlib.utils.shaders import is_valid_shader_info
@ -124,7 +124,7 @@ class Mobject(Container):
return self.family
def family_members_with_points(self):
return [m for m in self.get_family() if m.get_num_points() > 0]
return [m for m in self.get_family() if m.points.size > 0]
def add(self, *mobjects):
if self in mobjects:
@ -469,30 +469,6 @@ class Mobject(Container):
# Redundant with default behavior of scale now.
return self.scale(scale_factor, about_point=point)
def fix_in_frame(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.is_fixed_in_frame = True
return self
def unfix_from_frame(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.is_fixed_in_frame = False
return self
def apply_depth_test(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.depth_test = True
return self
def deactivate_depth_test(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.depth_test = False
return self
# Positioning methods
def center(self):
@ -1202,6 +1178,34 @@ class Mobject(Container):
def cleanup_from_animation(self):
pass
# Operations touching shader uniforms
def affects_shader_info_id(func):
def wrapper(self):
for mob in self.get_family():
func(mob)
mob.refresh_shader_info_template_id()
return wrapper
@affects_shader_info_id
def fix_in_frame(self):
self.is_fixed_in_frame = True
return self
@affects_shader_info_id
def unfix_from_frame(self):
self.is_fixed_in_frame = False
return self
@affects_shader_info_id
def apply_depth_test(self):
self.depth_test = True
return self
@affects_shader_info_id
def deactivate_depth_test(self):
self.depth_test = False
return self
# For shaders
def init_shader_data(self):
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
@ -1214,6 +1218,10 @@ class Mobject(Container):
render_primative=self.render_primative,
)
def refresh_shader_info_template_id(self):
refresh_shader_info_id(self.shader_info_template)
return self
def get_blank_shader_data_array(self, size, name="shader_data"):
# If possible, try to populate an existing array, rather
# than recreating it each frame
@ -1245,7 +1253,6 @@ class Mobject(Container):
shader_info["raw_data"] = data.tobytes()
shader_info["attributes"] = data.dtype.names
shader_info["uniforms"] = self.get_shader_uniforms()
shader_info["id"] = create_shader_info_id(shader_info)
return shader_info
def get_shader_uniforms(self):

View File

@ -26,7 +26,7 @@ from manimlib.utils.space_ops import earclip_triangulation
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import get_unit_normal
from manimlib.utils.space_ops import z_to_vector
from manimlib.utils.shaders import create_shader_info_id
from manimlib.utils.shaders import refresh_shader_info_id
from manimlib.utils.shaders import get_shader_info
@ -864,6 +864,11 @@ class VMobject(Mobject):
render_primative=self.render_primative,
)
def refresh_shader_info_template_id(self):
for template in [self.fill_shader_info_template, self.stroke_shader_info_template]:
refresh_shader_info_id(template)
return self
def get_shader_info_list(self):
fill_info = dict(self.fill_shader_info_template)
stroke_info = dict(self.stroke_shader_info_template)
@ -871,7 +876,6 @@ class VMobject(Mobject):
stroke_info["uniforms"] = self.get_stroke_uniforms()
for info in fill_info, stroke_info:
info["depth_test"] = self.depth_test
info["id"] = create_shader_info_id(info)
back_stroke_data = []
stroke_data = []

View File

@ -91,6 +91,10 @@ def create_shader_info_id(shader_info):
return "|".join([str(shader_info[key]) for key in SHADER_KEYS_FOR_ID])
def refresh_shader_info_id(shader_info):
shader_info["id"] = create_shader_info_id(shader_info)
def shader_info_program_id(shader_info):
return "|".join([str(shader_info[key]) for key in ["vert", "geom", "frag"]])