Allow for passing shader uniforms from mobjects

This commit is contained in:
Grant Sanderson
2020-06-08 15:57:12 -07:00
parent 39230a805c
commit bab809b9a6
6 changed files with 47 additions and 11 deletions

View File

@ -371,13 +371,14 @@ class Camera(object):
for name, path in info["texture_paths"].items(): for name, path in info["texture_paths"].items():
tid = self.get_texture_id(path) tid = self.get_texture_id(path)
shader[name].value = tid shader[name].value = tid
for name, value in info["uniforms"].items():
shader[name].value = value
self.set_shader_uniforms(shader) self.set_shader_uniforms(shader)
self.id_to_shader[sid] = shader self.id_to_shader[sid] = shader
return self.id_to_shader[sid] return self.id_to_shader[sid]
def set_shader_uniforms(self, shader): def set_shader_uniforms(self, shader):
# TODO, think about how uniforms come from mobjects as well.
if shader is None: if shader is None:
return return

View File

@ -1212,6 +1212,7 @@ class Mobject(Container):
def get_shader_info(self): def get_shader_info(self):
return get_shader_info( return get_shader_info(
data=self.get_shader_data(), data=self.get_shader_data(),
uniforms=self.get_shader_uniforms(),
vert_file=self.vert_shader_file, vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file, geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file, frag_file=self.frag_shader_file,
@ -1219,6 +1220,9 @@ class Mobject(Container):
texture_paths=self.texture_paths, texture_paths=self.texture_paths,
) )
def get_shader_uniforms(self):
return {}
def get_shader_data(self): def get_shader_data(self):
# Typically to be implemented by subclasses # Typically to be implemented by subclasses
# Must return a structured numpy array # Must return a structured numpy array

View File

@ -170,8 +170,12 @@ class TexturedSurface(ParametricSurface):
def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs): def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs):
if not isinstance(uv_surface, ParametricSurface): if not isinstance(uv_surface, ParametricSurface):
raise Exception("uv_surface must be of type ParametricSurface") raise Exception("uv_surface must be of type ParametricSurface")
# Set texture information
if dark_image_file is None: if dark_image_file is None:
dark_image_file = image_file dark_image_file = image_file
self.num_textures = 1
else:
self.num_textures = 2
self.texture_paths = { self.texture_paths = {
"LightTexture": get_full_raster_image_path(image_file), "LightTexture": get_full_raster_image_path(image_file),
"DarkTexture": get_full_raster_image_path(dark_image_file), "DarkTexture": get_full_raster_image_path(dark_image_file),
@ -209,6 +213,9 @@ class TexturedSurface(ParametricSurface):
sm.set_opacity(opacity, family) sm.set_opacity(opacity, family)
return self return self
def get_shader_uniforms(self):
return {"num_textures": self.num_textures}
def fill_in_shader_color_info(self, data): def fill_in_shader_color_info(self, data):
data["im_coords"] = self.get_triangle_ready_array(self.im_coords) data["im_coords"] = self.get_triangle_ready_array(self.im_coords)
data["opacity"] = self.opacity data["opacity"] = self.opacity

View File

@ -824,12 +824,14 @@ class VMobject(Mobject):
return self.saved_shader_info_list return self.saved_shader_info_list
stroke_info = get_shader_info( stroke_info = get_shader_info(
uniforms=self.get_stroke_uniforms(),
vert_file=self.stroke_vert_shader_file, vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file, geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file, frag_file=self.stroke_frag_shader_file,
render_primative=self.render_primative, render_primative=self.render_primative,
) )
fill_info = get_shader_info( fill_info = get_shader_info(
uniforms={},
vert_file=self.fill_vert_shader_file, vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file, geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file, frag_file=self.fill_frag_shader_file,
@ -868,6 +870,16 @@ class VMobject(Mobject):
result.append(stroke_info) result.append(stroke_info)
return result return result
def get_stroke_uniforms(self):
joint_type_to_code = {
"auto": 0,
"round": 1,
"bevel": 2,
"miter": 3,
}
# return {"joint_type": joint_type_to_code[self.joint_type]}
return {} # TODO
def get_stroke_shader_data(self): def get_stroke_shader_data(self):
joint_type_to_code = { joint_type_to_code = {
"auto": 0, "auto": 0,

View File

@ -2,6 +2,7 @@
uniform sampler2D LightTexture; uniform sampler2D LightTexture;
uniform sampler2D DarkTexture; uniform sampler2D DarkTexture;
uniform float num_textures;
uniform vec3 light_source_position; uniform vec3 light_source_position;
in vec3 xyz_coords; in vec3 xyz_coords;
@ -15,15 +16,19 @@ out vec4 frag_color;
#INSERT add_light.glsl #INSERT add_light.glsl
const float dark_shift = 0.2;
void main() { void main() {
vec4 light_color = texture(LightTexture, v_im_coords); vec4 color = texture(LightTexture, v_im_coords);
vec4 dark_color = texture(DarkTexture, v_im_coords); if(num_textures == 2.0){
float dp = dot( vec4 dark_color = texture(DarkTexture, v_im_coords);
normalize(light_source_position - xyz_coords), float dp = dot(
normalize(v_normal) normalize(light_source_position - xyz_coords),
); normalize(v_normal)
float alpha = smoothstep(-0.1, 0.1, dp); );
vec4 color = mix(dark_color, light_color, alpha); float alpha = smoothstep(-dark_shift, dark_shift, dp);
color = mix(dark_color, color, alpha);
}
frag_color = add_light( frag_color = add_light(
color, color,

View File

@ -17,6 +17,8 @@ SHADER_INFO_KEYS = [
# A structred array caring all of the points/color/lighting/etc. information # A structred array caring all of the points/color/lighting/etc. information
# needed for the shader. # needed for the shader.
"data", "data",
# A dictionary mapping names of uniform variables
"uniforms",
# Filename of vetex shader # Filename of vetex shader
"vert", "vert",
# Filename of geometry shader, if there is one # Filename of geometry shader, if there is one
@ -30,8 +32,12 @@ SHADER_INFO_KEYS = [
"render_primative", "render_primative",
] ]
# Exclude data
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[1:]
def get_shader_info(data=None, def get_shader_info(data=None,
uniforms=None,
vert_file=None, vert_file=None,
geom_file=None, geom_file=None,
frag_file=None, frag_file=None,
@ -44,6 +50,7 @@ def get_shader_info(data=None,
SHADER_INFO_KEYS, SHADER_INFO_KEYS,
[ [
data, data,
uniforms,
vert_file, geom_file, frag_file, vert_file, geom_file, frag_file,
texture_paths or {}, texture_paths or {},
str(render_primative) str(render_primative)
@ -66,7 +73,7 @@ def shader_info_to_id(shader_info):
# files holding its code and texture # files holding its code and texture
tuples = [ tuples = [
(key, shader_info[key]) (key, shader_info[key])
for key in SHADER_INFO_KEYS[1:] # Skip data for key in SHADER_KEYS_FOR_ID
] ]
return json.dumps(tuples) return json.dumps(tuples)
@ -80,7 +87,7 @@ def shader_id_to_info(sid):
def same_shader_type(info1, info2): def same_shader_type(info1, info2):
return all([ return all([
info1[key] == info2[key] info1[key] == info2[key]
for key in SHADER_INFO_KEYS[1:] # Skip data for key in SHADER_KEYS_FOR_ID
]) ])