From 9d772496ddcf082ed13b38a0cb1ca9ce28e460cc Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sun, 14 Jun 2020 19:01:04 -0700 Subject: [PATCH] Removes the need to be able to recover shader_info from shader_id --- manimlib/camera/camera.py | 44 ++++++++++---------- manimlib/mobject/mobject.py | 5 +-- manimlib/mobject/types/vectorized_mobject.py | 6 +-- manimlib/utils/shaders.py | 34 +++++++-------- 4 files changed, 41 insertions(+), 48 deletions(-) diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index c440ba0a..3f099784 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -14,8 +14,7 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.iterables import batch_by_property from manimlib.utils.simple_functions import fdiv from manimlib.utils.shaders import shader_info_to_id -from manimlib.utils.shaders import shader_id_to_info -from manimlib.utils.shaders import get_shader_code_from_file +from manimlib.utils.shaders import shader_info_to_program_code from manimlib.utils.simple_functions import clip from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import rotation_matrix_transpose_from_quaternion @@ -355,38 +354,37 @@ class Camera(object): def get_shader(self, shader_info): sid = shader_info_to_id(shader_info) if sid not in self.id_to_shader: - info = shader_id_to_info(sid) - shader = self.ctx.program( - vertex_shader=get_shader_code_from_file(info["vert"]), - geometry_shader=get_shader_code_from_file(info["geom"]), - fragment_shader=get_shader_code_from_file(info["frag"]), - ) - if info["texture_paths"]: - for name, path in info["texture_paths"].items(): - tid = self.get_texture_id(path) - shader[name].value = tid - self.set_shader_uniforms(shader, sid) + # Create shader program for the first time, then cache + # in the id_to_shader dictionary + shader = self.ctx.program(**shader_info_to_program_code(sid)) + self.set_shader_uniforms(shader) + for name, path in shader_info["texture_paths"].items(): + tid = self.get_texture_id(path) + shader[name].value = tid + for name, value in shader_info["uniforms"].items(): + shader[name].value = value self.id_to_shader[sid] = shader return self.id_to_shader[sid] - def set_shader_uniforms(self, shader, sid): + def set_shader_uniforms(self, shader): if shader is None: return pw, ph = self.get_pixel_shape() fw, fh = self.frame.get_shape() + # TODO, this should probably be a mobject uniform, with + # the camera taking care of the conversion factor anti_alias_width = self.anti_alias_width / (ph / fh) transform = self.frame.get_inverse_camera_position_matrix() light = self.light_source.get_location() transformed_light = np.dot(transform, [*light, 1])[:3] - mapping = dict() - mapping['to_screen_space'] = tuple(transform.T.flatten()) - mapping['frame_shape'] = self.frame.get_shape() - mapping['focal_distance'] = self.frame.get_focal_distance() - mapping['anti_alias_width'] = anti_alias_width - mapping['light_source_position'] = tuple(transformed_light) - # Potentially overwrite with whatever came from the mobject - mapping.update(shader_id_to_info(sid)["uniforms"]) + mapping = { + 'to_screen_space': tuple(transform.T.flatten()), + 'frame_shape': self.frame.get_shape(), + 'focal_distance': self.frame.get_focal_distance(), + 'anti_alias_width': anti_alias_width, + 'light_source_position': tuple(transformed_light), + } for key, value in mapping.items(): try: @@ -396,7 +394,7 @@ class Camera(object): def refresh_shader_uniforms(self): for sid, shader in self.id_to_shader.items(): - self.set_shader_uniforms(shader, sid) + self.set_shader_uniforms(shader) def init_textures(self): self.path_to_texture_id = {} diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 3a98d483..981cfbea 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -24,7 +24,6 @@ from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix_transpose from manimlib.utils.shaders import get_shader_info from manimlib.utils.shaders import shader_info_to_id -from manimlib.utils.shaders import shader_id_to_info from manimlib.utils.shaders import is_valid_shader_info @@ -1229,7 +1228,7 @@ class Mobject(Container): result = [] for info_group, sid in batches: - shader_info = shader_id_to_info(sid) + shader_info = info_group[0] shader_info["data"] = np.hstack([info["data"] for info in info_group]) if is_valid_shader_info(shader_info): result.append(shader_info) @@ -1238,10 +1237,10 @@ class Mobject(Container): def get_shader_info(self): return get_shader_info( data=self.get_shader_data(), - uniforms=self.get_shader_uniforms(), vert_file=self.vert_shader_file, geom_file=self.geom_shader_file, frag_file=self.frag_shader_file, + uniforms=self.get_shader_uniforms(), texture_paths=self.texture_paths, depth_test=self.depth_test, render_primative=self.render_primative, diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index b5ef51df..e0807813 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -848,18 +848,18 @@ class VMobject(Mobject): return self.saved_shader_info_list stroke_info = get_shader_info( - uniforms=self.get_stroke_uniforms(), vert_file=self.stroke_vert_shader_file, geom_file=self.stroke_geom_shader_file, frag_file=self.stroke_frag_shader_file, + uniforms=self.get_stroke_uniforms(), depth_test=self.depth_test, render_primative=self.render_primative, ) fill_info = get_shader_info( - uniforms=self.get_shader_uniforms(), vert_file=self.fill_vert_shader_file, geom_file=self.fill_geom_shader_file, frag_file=self.fill_frag_shader_file, + uniforms=self.get_shader_uniforms(), depth_test=self.depth_test, render_primative=self.render_primative, ) @@ -904,7 +904,7 @@ class VMobject(Mobject): "miter": 3, } result = super().get_shader_uniforms() - result["join_type"] = j_map[self.joint_type] + result["joint_type"] = j_map[self.joint_type] return result def get_stroke_shader_data(self): diff --git a/manimlib/utils/shaders.py b/manimlib/utils/shaders.py index 946c569f..8a656069 100644 --- a/manimlib/utils/shaders.py +++ b/manimlib/utils/shaders.py @@ -2,7 +2,6 @@ import os import warnings import re import moderngl -import json from manimlib.constants import SHADER_DIR @@ -17,14 +16,14 @@ SHADER_INFO_KEYS = [ # A structred array caring all of the points/color/lighting/etc. information # needed for the shader. "data", - # A dictionary mapping names of uniform variables - "uniforms", # Filename of vetex shader "vert", # Filename of geometry shader, if there is one "geom", # Filename of fragment shader "frag", + # A dictionary mapping names of uniform variables + "uniforms", # A dictionary mapping names (as they show up in) # the shader to filepaths for textures. "texture_paths", @@ -39,10 +38,10 @@ SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[1:] def get_shader_info(data=None, - uniforms=None, vert_file=None, geom_file=None, frag_file=None, + uniforms=None, texture_paths=None, depth_test=False, render_primative=moderngl.TRIANGLE_STRIP, @@ -53,11 +52,11 @@ def get_shader_info(data=None, SHADER_INFO_KEYS, [ data, - uniforms, vert_file, geom_file, frag_file, - texture_paths or {}, + uniforms or dict(), + texture_paths or dict(), depth_test, str(render_primative) ] @@ -75,19 +74,8 @@ def is_valid_shader_info(shader_info): def shader_info_to_id(shader_info): - # A unique id for a shader based on the - # files holding its code and texture - tuples = [ - (key, shader_info[key]) - for key in SHADER_KEYS_FOR_ID - ] - return json.dumps(tuples) - - -def shader_id_to_info(sid): - result = dict(json.loads(sid)) - result["data"] = None - return result + # A unique id for a shader + return "|".join([str(shader_info[key]) for key in SHADER_KEYS_FOR_ID]) def same_shader_type(info1, info2): @@ -97,6 +85,14 @@ def same_shader_type(info1, info2): ]) +def shader_info_to_program_code(shader_info): + return { + "vertex_shader": get_shader_code_from_file(shader_info["vert"]), + "geometry_shader": get_shader_code_from_file(shader_info["geom"]), + "fragment_shader": get_shader_code_from_file(shader_info["frag"]), + } + + def get_shader_code_from_file(filename): if not filename: return None