From bab809b9a63ce6a36128c76cc6eddcf7f7684501 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 8 Jun 2020 15:57:12 -0700 Subject: [PATCH] Allow for passing shader uniforms from mobjects --- manimlib/camera/camera.py | 3 ++- manimlib/mobject/mobject.py | 4 ++++ manimlib/mobject/types/surface.py | 7 +++++++ manimlib/mobject/types/vectorized_mobject.py | 12 +++++++++++ manimlib/shaders/textured_surface_frag.glsl | 21 ++++++++++++-------- manimlib/utils/shaders.py | 11 ++++++++-- 6 files changed, 47 insertions(+), 11 deletions(-) diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 3e3aaf3d..6e7bdfff 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -371,13 +371,14 @@ class Camera(object): for name, path in info["texture_paths"].items(): tid = self.get_texture_id(path) shader[name].value = tid + for name, value in info["uniforms"].items(): + shader[name].value = value self.set_shader_uniforms(shader) self.id_to_shader[sid] = shader return self.id_to_shader[sid] def set_shader_uniforms(self, shader): - # TODO, think about how uniforms come from mobjects as well. if shader is None: return diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 1209a930..aa3c85ac 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1212,6 +1212,7 @@ class Mobject(Container): def get_shader_info(self): return get_shader_info( data=self.get_shader_data(), + uniforms=self.get_shader_uniforms(), vert_file=self.vert_shader_file, geom_file=self.geom_shader_file, frag_file=self.frag_shader_file, @@ -1219,6 +1220,9 @@ class Mobject(Container): texture_paths=self.texture_paths, ) + def get_shader_uniforms(self): + return {} + def get_shader_data(self): # Typically to be implemented by subclasses # Must return a structured numpy array diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index f90e1ca4..5c4d3304 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -170,8 +170,12 @@ class TexturedSurface(ParametricSurface): def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs): if not isinstance(uv_surface, ParametricSurface): raise Exception("uv_surface must be of type ParametricSurface") + # Set texture information if dark_image_file is None: dark_image_file = image_file + self.num_textures = 1 + else: + self.num_textures = 2 self.texture_paths = { "LightTexture": get_full_raster_image_path(image_file), "DarkTexture": get_full_raster_image_path(dark_image_file), @@ -209,6 +213,9 @@ class TexturedSurface(ParametricSurface): sm.set_opacity(opacity, family) return self + def get_shader_uniforms(self): + return {"num_textures": self.num_textures} + def fill_in_shader_color_info(self, data): data["im_coords"] = self.get_triangle_ready_array(self.im_coords) data["opacity"] = self.opacity diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index bdd4e83f..4903d8f6 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -824,12 +824,14 @@ class VMobject(Mobject): return self.saved_shader_info_list stroke_info = get_shader_info( + uniforms=self.get_stroke_uniforms(), vert_file=self.stroke_vert_shader_file, geom_file=self.stroke_geom_shader_file, frag_file=self.stroke_frag_shader_file, render_primative=self.render_primative, ) fill_info = get_shader_info( + uniforms={}, vert_file=self.fill_vert_shader_file, geom_file=self.fill_geom_shader_file, frag_file=self.fill_frag_shader_file, @@ -868,6 +870,16 @@ class VMobject(Mobject): result.append(stroke_info) return result + def get_stroke_uniforms(self): + joint_type_to_code = { + "auto": 0, + "round": 1, + "bevel": 2, + "miter": 3, + } + # return {"joint_type": joint_type_to_code[self.joint_type]} + return {} # TODO + def get_stroke_shader_data(self): joint_type_to_code = { "auto": 0, diff --git a/manimlib/shaders/textured_surface_frag.glsl b/manimlib/shaders/textured_surface_frag.glsl index 888a34bd..30e8a32d 100644 --- a/manimlib/shaders/textured_surface_frag.glsl +++ b/manimlib/shaders/textured_surface_frag.glsl @@ -2,6 +2,7 @@ uniform sampler2D LightTexture; uniform sampler2D DarkTexture; +uniform float num_textures; uniform vec3 light_source_position; in vec3 xyz_coords; @@ -15,15 +16,19 @@ out vec4 frag_color; #INSERT add_light.glsl +const float dark_shift = 0.2; + void main() { - vec4 light_color = texture(LightTexture, v_im_coords); - vec4 dark_color = texture(DarkTexture, v_im_coords); - float dp = dot( - normalize(light_source_position - xyz_coords), - normalize(v_normal) - ); - float alpha = smoothstep(-0.1, 0.1, dp); - vec4 color = mix(dark_color, light_color, alpha); + vec4 color = texture(LightTexture, v_im_coords); + if(num_textures == 2.0){ + vec4 dark_color = texture(DarkTexture, v_im_coords); + float dp = dot( + normalize(light_source_position - xyz_coords), + normalize(v_normal) + ); + float alpha = smoothstep(-dark_shift, dark_shift, dp); + color = mix(dark_color, color, alpha); + } frag_color = add_light( color, diff --git a/manimlib/utils/shaders.py b/manimlib/utils/shaders.py index 2ebccb79..2715bcc0 100644 --- a/manimlib/utils/shaders.py +++ b/manimlib/utils/shaders.py @@ -17,6 +17,8 @@ SHADER_INFO_KEYS = [ # A structred array caring all of the points/color/lighting/etc. information # needed for the shader. "data", + # A dictionary mapping names of uniform variables + "uniforms", # Filename of vetex shader "vert", # Filename of geometry shader, if there is one @@ -30,8 +32,12 @@ SHADER_INFO_KEYS = [ "render_primative", ] +# Exclude data +SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[1:] + def get_shader_info(data=None, + uniforms=None, vert_file=None, geom_file=None, frag_file=None, @@ -44,6 +50,7 @@ def get_shader_info(data=None, SHADER_INFO_KEYS, [ data, + uniforms, vert_file, geom_file, frag_file, texture_paths or {}, str(render_primative) @@ -66,7 +73,7 @@ def shader_info_to_id(shader_info): # files holding its code and texture tuples = [ (key, shader_info[key]) - for key in SHADER_INFO_KEYS[1:] # Skip data + for key in SHADER_KEYS_FOR_ID ] return json.dumps(tuples) @@ -80,7 +87,7 @@ def shader_id_to_info(sid): def same_shader_type(info1, info2): return all([ info1[key] == info2[key] - for key in SHADER_INFO_KEYS[1:] # Skip data + for key in SHADER_KEYS_FOR_ID ])