Removes the need to be able to recover shader_info from shader_id

This commit is contained in:
Grant Sanderson
2020-06-14 19:01:04 -07:00
parent 7a152fed1c
commit 9d772496dd
4 changed files with 41 additions and 48 deletions

View File

@ -14,8 +14,7 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import batch_by_property
from manimlib.utils.simple_functions import fdiv from manimlib.utils.simple_functions import fdiv
from manimlib.utils.shaders import shader_info_to_id from manimlib.utils.shaders import shader_info_to_id
from manimlib.utils.shaders import shader_id_to_info from manimlib.utils.shaders import shader_info_to_program_code
from manimlib.utils.shaders import get_shader_code_from_file
from manimlib.utils.simple_functions import clip from manimlib.utils.simple_functions import clip
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import rotation_matrix_transpose_from_quaternion from manimlib.utils.space_ops import rotation_matrix_transpose_from_quaternion
@ -355,38 +354,37 @@ class Camera(object):
def get_shader(self, shader_info): def get_shader(self, shader_info):
sid = shader_info_to_id(shader_info) sid = shader_info_to_id(shader_info)
if sid not in self.id_to_shader: if sid not in self.id_to_shader:
info = shader_id_to_info(sid) # Create shader program for the first time, then cache
shader = self.ctx.program( # in the id_to_shader dictionary
vertex_shader=get_shader_code_from_file(info["vert"]), shader = self.ctx.program(**shader_info_to_program_code(sid))
geometry_shader=get_shader_code_from_file(info["geom"]), self.set_shader_uniforms(shader)
fragment_shader=get_shader_code_from_file(info["frag"]), for name, path in shader_info["texture_paths"].items():
) tid = self.get_texture_id(path)
if info["texture_paths"]: shader[name].value = tid
for name, path in info["texture_paths"].items(): for name, value in shader_info["uniforms"].items():
tid = self.get_texture_id(path) shader[name].value = value
shader[name].value = tid
self.set_shader_uniforms(shader, sid)
self.id_to_shader[sid] = shader self.id_to_shader[sid] = shader
return self.id_to_shader[sid] return self.id_to_shader[sid]
def set_shader_uniforms(self, shader, sid): def set_shader_uniforms(self, shader):
if shader is None: if shader is None:
return return
pw, ph = self.get_pixel_shape() pw, ph = self.get_pixel_shape()
fw, fh = self.frame.get_shape() fw, fh = self.frame.get_shape()
# TODO, this should probably be a mobject uniform, with
# the camera taking care of the conversion factor
anti_alias_width = self.anti_alias_width / (ph / fh) anti_alias_width = self.anti_alias_width / (ph / fh)
transform = self.frame.get_inverse_camera_position_matrix() transform = self.frame.get_inverse_camera_position_matrix()
light = self.light_source.get_location() light = self.light_source.get_location()
transformed_light = np.dot(transform, [*light, 1])[:3] transformed_light = np.dot(transform, [*light, 1])[:3]
mapping = dict() mapping = {
mapping['to_screen_space'] = tuple(transform.T.flatten()) 'to_screen_space': tuple(transform.T.flatten()),
mapping['frame_shape'] = self.frame.get_shape() 'frame_shape': self.frame.get_shape(),
mapping['focal_distance'] = self.frame.get_focal_distance() 'focal_distance': self.frame.get_focal_distance(),
mapping['anti_alias_width'] = anti_alias_width 'anti_alias_width': anti_alias_width,
mapping['light_source_position'] = tuple(transformed_light) 'light_source_position': tuple(transformed_light),
# Potentially overwrite with whatever came from the mobject }
mapping.update(shader_id_to_info(sid)["uniforms"])
for key, value in mapping.items(): for key, value in mapping.items():
try: try:
@ -396,7 +394,7 @@ class Camera(object):
def refresh_shader_uniforms(self): def refresh_shader_uniforms(self):
for sid, shader in self.id_to_shader.items(): for sid, shader in self.id_to_shader.items():
self.set_shader_uniforms(shader, sid) self.set_shader_uniforms(shader)
def init_textures(self): def init_textures(self):
self.path_to_texture_id = {} self.path_to_texture_id = {}

View File

@ -24,7 +24,6 @@ from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose from manimlib.utils.space_ops import rotation_matrix_transpose
from manimlib.utils.shaders import get_shader_info from manimlib.utils.shaders import get_shader_info
from manimlib.utils.shaders import shader_info_to_id from manimlib.utils.shaders import shader_info_to_id
from manimlib.utils.shaders import shader_id_to_info
from manimlib.utils.shaders import is_valid_shader_info from manimlib.utils.shaders import is_valid_shader_info
@ -1229,7 +1228,7 @@ class Mobject(Container):
result = [] result = []
for info_group, sid in batches: for info_group, sid in batches:
shader_info = shader_id_to_info(sid) shader_info = info_group[0]
shader_info["data"] = np.hstack([info["data"] for info in info_group]) shader_info["data"] = np.hstack([info["data"] for info in info_group])
if is_valid_shader_info(shader_info): if is_valid_shader_info(shader_info):
result.append(shader_info) result.append(shader_info)
@ -1238,10 +1237,10 @@ class Mobject(Container):
def get_shader_info(self): def get_shader_info(self):
return get_shader_info( return get_shader_info(
data=self.get_shader_data(), data=self.get_shader_data(),
uniforms=self.get_shader_uniforms(),
vert_file=self.vert_shader_file, vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file, geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file, frag_file=self.frag_shader_file,
uniforms=self.get_shader_uniforms(),
texture_paths=self.texture_paths, texture_paths=self.texture_paths,
depth_test=self.depth_test, depth_test=self.depth_test,
render_primative=self.render_primative, render_primative=self.render_primative,

View File

@ -848,18 +848,18 @@ class VMobject(Mobject):
return self.saved_shader_info_list return self.saved_shader_info_list
stroke_info = get_shader_info( stroke_info = get_shader_info(
uniforms=self.get_stroke_uniforms(),
vert_file=self.stroke_vert_shader_file, vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file, geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file, frag_file=self.stroke_frag_shader_file,
uniforms=self.get_stroke_uniforms(),
depth_test=self.depth_test, depth_test=self.depth_test,
render_primative=self.render_primative, render_primative=self.render_primative,
) )
fill_info = get_shader_info( fill_info = get_shader_info(
uniforms=self.get_shader_uniforms(),
vert_file=self.fill_vert_shader_file, vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file, geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file, frag_file=self.fill_frag_shader_file,
uniforms=self.get_shader_uniforms(),
depth_test=self.depth_test, depth_test=self.depth_test,
render_primative=self.render_primative, render_primative=self.render_primative,
) )
@ -904,7 +904,7 @@ class VMobject(Mobject):
"miter": 3, "miter": 3,
} }
result = super().get_shader_uniforms() result = super().get_shader_uniforms()
result["join_type"] = j_map[self.joint_type] result["joint_type"] = j_map[self.joint_type]
return result return result
def get_stroke_shader_data(self): def get_stroke_shader_data(self):

View File

@ -2,7 +2,6 @@ import os
import warnings import warnings
import re import re
import moderngl import moderngl
import json
from manimlib.constants import SHADER_DIR from manimlib.constants import SHADER_DIR
@ -17,14 +16,14 @@ SHADER_INFO_KEYS = [
# A structred array caring all of the points/color/lighting/etc. information # A structred array caring all of the points/color/lighting/etc. information
# needed for the shader. # needed for the shader.
"data", "data",
# A dictionary mapping names of uniform variables
"uniforms",
# Filename of vetex shader # Filename of vetex shader
"vert", "vert",
# Filename of geometry shader, if there is one # Filename of geometry shader, if there is one
"geom", "geom",
# Filename of fragment shader # Filename of fragment shader
"frag", "frag",
# A dictionary mapping names of uniform variables
"uniforms",
# A dictionary mapping names (as they show up in) # A dictionary mapping names (as they show up in)
# the shader to filepaths for textures. # the shader to filepaths for textures.
"texture_paths", "texture_paths",
@ -39,10 +38,10 @@ SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[1:]
def get_shader_info(data=None, def get_shader_info(data=None,
uniforms=None,
vert_file=None, vert_file=None,
geom_file=None, geom_file=None,
frag_file=None, frag_file=None,
uniforms=None,
texture_paths=None, texture_paths=None,
depth_test=False, depth_test=False,
render_primative=moderngl.TRIANGLE_STRIP, render_primative=moderngl.TRIANGLE_STRIP,
@ -53,11 +52,11 @@ def get_shader_info(data=None,
SHADER_INFO_KEYS, SHADER_INFO_KEYS,
[ [
data, data,
uniforms,
vert_file, vert_file,
geom_file, geom_file,
frag_file, frag_file,
texture_paths or {}, uniforms or dict(),
texture_paths or dict(),
depth_test, depth_test,
str(render_primative) str(render_primative)
] ]
@ -75,19 +74,8 @@ def is_valid_shader_info(shader_info):
def shader_info_to_id(shader_info): def shader_info_to_id(shader_info):
# A unique id for a shader based on the # A unique id for a shader
# files holding its code and texture return "|".join([str(shader_info[key]) for key in SHADER_KEYS_FOR_ID])
tuples = [
(key, shader_info[key])
for key in SHADER_KEYS_FOR_ID
]
return json.dumps(tuples)
def shader_id_to_info(sid):
result = dict(json.loads(sid))
result["data"] = None
return result
def same_shader_type(info1, info2): def same_shader_type(info1, info2):
@ -97,6 +85,14 @@ def same_shader_type(info1, info2):
]) ])
def shader_info_to_program_code(shader_info):
return {
"vertex_shader": get_shader_code_from_file(shader_info["vert"]),
"geometry_shader": get_shader_code_from_file(shader_info["geom"]),
"fragment_shader": get_shader_code_from_file(shader_info["frag"]),
}
def get_shader_code_from_file(filename): def get_shader_code_from_file(filename):
if not filename: if not filename:
return None return None