Don't distinguish stroke uniforms from fill uniforms

This commit is contained in:
Grant Sanderson
2022-12-28 19:17:52 -08:00
parent 8fc243e398
commit a92a506224

View File

@ -122,6 +122,8 @@ class VMobject(Mobject):
def init_uniforms(self):
super().init_uniforms()
self.uniforms["anti_alias_width"] = self.anti_alias_width
self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
self.uniforms["flat_stroke"] = float(self.flat_stroke)
# These are here just to make type checkers happy
def get_family(self, recurse: bool = True) -> list[VMobject]:
@ -396,19 +398,19 @@ class VMobject(Mobject):
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True):
for mob in self.get_family(recurse):
mob.flat_stroke = flat_stroke
mob.uniforms["flat_stroke"] = float(flat_stroke)
return self
def get_flat_stroke(self) -> bool:
return self.flat_stroke
return self.uniforms["flat_stroke"] == 1.0
def set_joint_type(self, joint_type: str, recurse: bool = True):
for mob in self.get_family(recurse):
mob.joint_type = joint_type
mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type]
return self
def get_joint_type(self) -> str:
return self.joint_type
def get_joint_type(self) -> float:
return self.uniforms["joint_type"]
# Points
def set_anchors_and_handles(
@ -1066,7 +1068,7 @@ class VMobject(Mobject):
)
self.stroke_shader_wrapper = ShaderWrapper(
vert_data=self.stroke_data,
uniforms=self.get_stroke_uniforms(),
uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder,
render_primitive=self.render_primitive,
)
@ -1089,7 +1091,7 @@ class VMobject(Mobject):
def get_stroke_shader_wrapper(self) -> ShaderWrapper:
self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data()
self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms()
self.stroke_shader_wrapper.uniforms = self.get_shader_uniforms()
self.stroke_shader_wrapper.depth_test = self.depth_test
return self.stroke_shader_wrapper
@ -1116,13 +1118,13 @@ class VMobject(Mobject):
]
for i, sw in enumerate(result):
sw.depth_test = self.depth_test
return list(filter(lambda sw: len(sw.vert_data) > 0, result))
def get_stroke_uniforms(self) -> dict[str, float]:
result = dict(super().get_shader_uniforms())
result["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
result["flat_stroke"] = float(self.flat_stroke)
return result
sw.uniforms = self.uniforms
return list(filter(lambda sw: len(sw.vert_data) > 0, self.shader_wrapper_list))
def get_stroke_shader_data(self) -> np.ndarray:
points = self.get_points()