Use three shader wrappers to account for backstroke

This commit is contained in:
Grant Sanderson
2023-01-10 16:41:03 -08:00
parent bfaf81c6b3
commit 886fd193f0

View File

@ -1150,6 +1150,7 @@ class VMobject(Mobject):
shader_folder=self.stroke_shader_folder, shader_folder=self.stroke_shader_folder,
render_primitive=self.render_primitive, render_primitive=self.render_primitive,
) )
self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy()
def refresh_shader_wrapper_id(self): def refresh_shader_wrapper_id(self):
for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]: for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]:
@ -1171,18 +1172,26 @@ class VMobject(Mobject):
def get_shader_wrapper_list(self) -> list[ShaderWrapper]: def get_shader_wrapper_list(self) -> list[ShaderWrapper]:
# Build up data lists # Build up data lists
fill_shader_wrappers = [] fill_sws = []
stroke_shader_wrappers = [] stroke_sws = []
bstroke_sws = []
for submob in self.family_members_with_points(): for submob in self.family_members_with_points():
if submob.has_fill(): if submob.has_fill():
fill_shader_wrappers.append(submob.get_fill_shader_wrapper()) fill_sws.append(submob.get_fill_shader_wrapper())
if submob.has_stroke(): if submob.has_stroke():
stroke_shader_wrappers.append(submob.get_stroke_shader_wrapper()) lst = bstroke_sws if submob.draw_stroke_behind_fill else stroke_sws
if submob.draw_stroke_behind_fill: lst.append(submob.get_stroke_shader_wrapper())
self.draw_stroke_behind_fill = True
self_sws = [self.fill_shader_wrapper, self.stroke_shader_wrapper] self_sws = [
sw_lists = [fill_shader_wrappers, stroke_shader_wrappers] self.back_stroke_shader_wrapper,
self.fill_shader_wrapper,
self.stroke_shader_wrapper
]
sw_lists = [
bstroke_sws,
fill_sws,
stroke_sws
]
for sw, sw_list in zip(self_sws, sw_lists): for sw, sw_list in zip(self_sws, sw_lists):
if not sw_list: if not sw_list:
sw.vert_data = resize_array(sw.vert_data, 0) sw.vert_data = resize_array(sw.vert_data, 0)
@ -1193,8 +1202,6 @@ class VMobject(Mobject):
sw.read_in(*sw_list) sw.read_in(*sw_list)
sw.depth_test = any(sw.depth_test for sw in sw_list) sw.depth_test = any(sw.depth_test for sw in sw_list)
sw.uniforms.update(sw_list[0].uniforms) sw.uniforms.update(sw_list[0].uniforms)
if self.draw_stroke_behind_fill:
self_sws.reverse()
return [sw for sw in self_sws if len(sw.vert_data) > 0] return [sw for sw in self_sws if len(sw.vert_data) > 0]
def get_stroke_shader_data(self) -> np.ndarray: def get_stroke_shader_data(self) -> np.ndarray: