diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index edd824a3..5e0935c0 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -12,9 +12,6 @@ from manimlib.mobject.mobject import Point from manimlib.utils.config_ops import digest_config from manimlib.utils.bezier import interpolate from manimlib.utils.simple_functions import fdiv -from manimlib.utils.shaders import shader_info_to_id -from manimlib.utils.shaders import shader_info_program_id -from manimlib.utils.shaders import shader_info_to_program_code from manimlib.utils.simple_functions import clip from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import rotation_matrix_transpose_from_quaternion @@ -168,7 +165,7 @@ class Camera(object): self.init_textures() self.init_light_source() self.refresh_perspective_uniforms() - self.static_mobjects_to_shader_info_list = {} + self.static_mobject_to_render_group_list = {} def init_frame(self): self.frame = CameraFrame(**self.frame_config) @@ -321,91 +318,92 @@ class Camera(object): self.refresh_perspective_uniforms() for mobject in mobjects: try: - info_list = self.static_mobjects_to_shader_info_list[id(mobject)] + rg_list = self.static_mobject_to_render_group_list[id(mobject)] + release_when_done = False except KeyError: - info_list = mobject.get_shader_info_list() + rg_list = map(self.get_render_group, mobject.get_shader_wrapper_list()) + release_when_done = True - for shader_info in info_list: - self.render(shader_info) + for render_group in rg_list: + self.render(render_group, release_when_done) - def render(self, shader_info): - cached_buffers = "render_group" in shader_info - if cached_buffers: - vbo, ibo, vao, shader = shader_info["render_group"] - else: - vbo, ibo, vao, shader = self.get_render_group(shader_info) + def render(self, render_group, release_when_done=True): + shader_wrapper = render_group["shader_wrapper"] + shader_program = render_group["prog"] + self.set_shader_uniforms(shader_program, shader_wrapper) + self.update_depth_test(shader_wrapper) + render_group["vao"].render(int(shader_wrapper.render_primative)) + if release_when_done: + self.release_render_group(render_group) - self.set_shader_uniforms(shader, shader_info) - - if shader_info["depth_test"]: + def update_depth_test(self, shader_wrapper): + if shader_wrapper.depth_test: self.ctx.enable(moderngl.DEPTH_TEST) else: self.ctx.disable(moderngl.DEPTH_TEST) - vao.render(int(shader_info["render_primative"])) - - if not cached_buffers: - self.release_gl_objects(vbo, ibo, vao) - - def get_render_group(self, shader_info): - shader, vert_format = self.get_shader(shader_info) - # vbo = self.ctx.buffer(shader_info["vert_data"].tobytes()) - vbo = self.ctx.buffer(shader_info["vert_data"]) - - vert_indices = shader_info["vert_indices"] - if vert_indices is None: - ibo = None - else: - ibo = self.ctx.buffer(vert_indices.astype('i4').tobytes()) - - vao = self.ctx.vertex_array( - program=shader, - content=[(vbo, vert_format, *shader_info["attributes"])], - index_buffer=ibo, - ) - return (vbo, ibo, vao, shader) - def set_mobjects_as_static(self, *mobjects): - # Create buffer and array objects holding each mobjects shader data + # Creates buffer and array objects holding each mobjects shader data for mob in mobjects: - info_list = mob.get_shader_info_list() - for info in info_list: - info["render_group"] = self.get_render_group(info) - self.static_mobjects_to_shader_info_list[id(mob)] = info_list + self.static_mobject_to_render_group_list[id(mob)] = [ + self.get_render_group(sw) + for sw in mob.get_shader_wrapper_list() + ] def release_static_mobjects(self): - for mob, info_list in self.static_mobjects_to_shader_info_list.items(): - for info in info_list: - self.release_gl_objects(*info["render_group"][:3]) - self.static_mobjects_to_shader_info_list = {} + for rg_list in self.static_mobject_to_render_group_list.values(): + for render_group in rg_list: + self.release_render_group(render_group) + self.static_mobject_to_render_group_list = {} - def release_gl_objects(self, *objs): - for obj in objs: - if obj: - obj.release() + def get_render_group(self, shader_wrapper): + # Data buffers + vbo = self.ctx.buffer(shader_wrapper.vert_data.tobytes()) + if shader_wrapper.vert_indices is None: + ibo = None + else: + ibo = self.ctx.buffer(shader_wrapper.vert_indices.astype('i4').tobytes()) + + # Program and vertex array + shader_program, vert_format = self.get_shader_program(shader_wrapper) + vao = self.ctx.vertex_array( + program=shader_program, + content=[(vbo, vert_format, *shader_wrapper.vert_attributes)], + index_buffer=ibo, + ) + return { + "vbo": vbo, + "ibo": ibo, + "vao": vao, + "prog": shader_program, + "shader_wrapper": shader_wrapper, + } + + def release_render_group(self, render_group): + for key in ["vbo", "ibo", "vao"]: + if render_group[key] is not None: + render_group[key].release() # Shaders def init_shaders(self): # Initialize with the null id going to None - self.id_to_shader = {"": None} + self.id_to_shader_program = {"": None} - def get_shader(self, shader_info): - sid = shader_info_program_id(shader_info) - if sid not in self.id_to_shader: + def get_shader_program(self, shader_wrapper): + sid = shader_wrapper.get_program_id() + if sid not in self.id_to_shader_program: # Create shader program for the first time, then cache - # in the id_to_shader dictionary - program = self.ctx.program(**shader_info_to_program_code(shader_info)) - vert_format = moderngl.detect_format(program, shader_info["attributes"]) - self.id_to_shader[sid] = (program, vert_format) - program, vert_format = self.id_to_shader[sid] - # self.set_shader_uniforms(program, shader_info) - return program, vert_format + # in the id_to_shader_program dictionary + program = self.ctx.program(**shader_wrapper.get_program_code()) + vert_format = moderngl.detect_format(program, shader_wrapper.vert_attributes) + self.id_to_shader_program[sid] = (program, vert_format) + return self.id_to_shader_program[sid] - def set_shader_uniforms(self, shader, shader_info): - for name, path in shader_info["texture_paths"].items(): + def set_shader_uniforms(self, shader, shader_wrapper): + for name, path in shader_wrapper.texture_paths.items(): tid = self.get_texture_id(path) shader[name].value = tid - for name, value in it.chain(shader_info["uniforms"].items(), self.perspective_uniforms.items()): + for name, value in it.chain(shader_wrapper.uniforms.items(), self.perspective_uniforms.items()): try: shader[name].value = value except KeyError: diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 83e95e2a..2df95afb 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -22,10 +22,7 @@ from manimlib.utils.simple_functions import get_parameters from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix_transpose -from manimlib.utils.shaders import refresh_shader_info_id -from manimlib.utils.shaders import get_shader_info -from manimlib.utils.shaders import shader_info_to_id -from manimlib.utils.shaders import is_valid_shader_info +from manimlib.utils.shaders import ShaderWrapper # TODO: Explain array_attrs @@ -217,6 +214,8 @@ class Mobject(Container): setattr(copy_mobject, attr, value.copy()) if isinstance(value, np.ndarray): setattr(copy_mobject, attr, np.array(value)) + if isinstance(value, ShaderWrapper): + setattr(copy_mobject, attr, value.copy()) return copy_mobject def deepcopy(self): @@ -235,6 +234,7 @@ class Mobject(Container): return self.target # Updating + def init_updaters(self): self.time_based_updaters = [] self.non_time_updaters = [] @@ -1194,7 +1194,7 @@ class Mobject(Container): def wrapper(self): for mob in self.get_family(): func(mob) - mob.refresh_shader_info_template_id() + mob.refresh_shader_wrapper_id() return wrapper @affects_shader_info_id @@ -1221,8 +1221,8 @@ class Mobject(Container): def init_shader_data(self): self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype) self.shader_indices = None - self.shader_info_template = get_shader_info( - attributes=self.shader_data.dtype.names, + self.shader_wrapper = ShaderWrapper( + vert_data=self.shader_data, vert_file=self.vert_shader_file, geom_file=self.geom_shader_file, frag_file=self.frag_shader_file, @@ -1231,8 +1231,8 @@ class Mobject(Container): render_primative=self.render_primative, ) - def refresh_shader_info_template_id(self): - refresh_shader_info_id(self.shader_info_template) + def refresh_shader_wrapper_id(self): + self.shader_wrapper.refresh_id() return self def get_blank_shader_data_array(self, size, name="shader_data"): @@ -1245,41 +1245,30 @@ class Mobject(Container): return new_arr return arr - def get_shader_info_list(self): - shader_infos = it.chain( - [self.get_shader_info()], - *[sm.get_shader_info_list() for sm in self.submobjects] + def get_shader_wrapper(self): + self.shader_wrapper.vert_data = self.get_shader_data() + self.shader_wrapper.vert_indices = self.get_shader_vert_indices() + self.shader_wrapper.uniforms = self.get_shader_uniforms() + self.shader_wrapper.depth_test = self.depth_test + return self.shader_wrapper + + def get_shader_wrapper_list(self): + shader_wrappers = it.chain( + [self.get_shader_wrapper()], + *[sm.get_shader_wrapper_list() for sm in self.submobjects] ) - batches = batch_by_property(shader_infos, shader_info_to_id) + batches = batch_by_property(shader_wrappers, lambda sw: sw.get_id()) result = [] - for info_group, sid in batches: - combined_info = info_group[0] - if not is_valid_shader_info(combined_info): + for wrapper_group, sid in batches: + shader_wrapper = wrapper_group[0] + if not shader_wrapper.is_valid(): continue - data_list = [] - indices_list = [] - num_verts = 0 - for info in info_group: - data_list.append(info["vert_data"]) - if info["vert_indices"] is not None: - indices_list.append(info["vert_indices"] + num_verts) - num_verts += len(info["vert_data"]) - # Combine lists - combined_info["vert_data"] = np.hstack(data_list) - if combined_info["vert_indices"] is not None: - combined_info["vert_indices"] = np.hstack(indices_list) - if len(combined_info["vert_indices"]) > 0: - result.append(combined_info) + shader_wrapper.combine_with(*wrapper_group[1:]) + if len(shader_wrapper.vert_data) > 0: + result.append(shader_wrapper) return result - def get_shader_info(self): - shader_info = dict(self.shader_info_template) - shader_info["vert_data"] = self.get_shader_data() - shader_info["vert_indices"] = self.get_shader_vert_indices() - shader_info["uniforms"] = self.get_shader_uniforms() - return shader_info - def get_shader_uniforms(self): return { "is_fixed_in_frame": float(self.is_fixed_in_frame), diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index dcad3e75..266897c1 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -30,24 +30,14 @@ class ImageMobject(Mobject): def init_points(self): self.points = np.array([UL, DL, UR, DR]) + self.im_coords = np.array([(0, 0), (0, 1), (1, 0), (1, 1)]) size = self.image.size self.set_width(2 * size[0] / size[1], stretch=True) self.set_height(self.height) - self.im_coords = np.array( - [(0, 0), (0, 1), (1, 0), (1, 1)] - ) - def init_colors(self): self.set_opacity(self.opacity) - def get_shader_data(self): - data = self.get_blank_shader_data_array(len(self.points)) - data["point"] = self.points - data["im_coords"] = self.im_coords - data["opacity"] = self.opacity - return data - def set_opacity(self, alpha, family=True): opacity = listify(alpha) diff = 4 - len(opacity) @@ -67,3 +57,10 @@ class ImageMobject(Mobject): self.opacity = interpolate( mobject1.opacity, mobject2.opacity, alpha ) + + def get_shader_data(self): + data = self.get_blank_shader_data_array(len(self.points)) + data["point"] = self.points + data["im_coords"] = self.im_coords + data["opacity"] = self.opacity + return data diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 5203ccb5..5e93b774 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -27,8 +27,7 @@ from manimlib.utils.space_ops import earclip_triangulation from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_unit_normal from manimlib.utils.space_ops import z_to_vector -from manimlib.utils.shaders import refresh_shader_info_id -from manimlib.utils.shaders import get_shader_info +from manimlib.utils.shaders import ShaderWrapper class VMobject(Mobject): @@ -882,70 +881,67 @@ class VMobject(Mobject): def init_shader_data(self): self.fill_data = np.zeros(len(self.points), dtype=self.fill_dtype) self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype) - self.fill_shader_info_template = get_shader_info( - attributes=self.fill_data.dtype.names, + self.fill_shader_wrapper = ShaderWrapper( + vert_data=self.fill_data, + vert_indices=np.zeros(0, dtype='i4'), vert_file=self.fill_vert_shader_file, geom_file=self.fill_geom_shader_file, frag_file=self.fill_frag_shader_file, - depth_test=self.depth_test, render_primative=self.render_primative, ) - self.stroke_shader_info_template = get_shader_info( - attributes=self.stroke_data.dtype.names, + self.stroke_shader_wrapper = ShaderWrapper( + vert_data=self.stroke_data, vert_file=self.stroke_vert_shader_file, geom_file=self.stroke_geom_shader_file, frag_file=self.stroke_frag_shader_file, - depth_test=self.depth_test, render_primative=self.render_primative, ) - def refresh_shader_info_template_id(self): - for template in [self.fill_shader_info_template, self.stroke_shader_info_template]: - refresh_shader_info_id(template) + def refresh_shader_wrapper_id(self): + for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]: + wrapper.refresh_id() return self - def get_shader_info_list(self): - fill_info = dict(self.fill_shader_info_template) - stroke_info = dict(self.stroke_shader_info_template) - fill_info["uniforms"] = self.get_shader_uniforms() - stroke_info["uniforms"] = self.get_stroke_uniforms() - for info in fill_info, stroke_info: - info["depth_test"] = self.depth_test + def get_fill_shader_wrapper(self): + self.fill_shader_wrapper.vert_data = self.get_fill_shader_data() + self.fill_shader_wrapper.vert_indices = self.get_fill_shader_vert_indices() + self.fill_shader_wrapper.uniforms = self.get_shader_uniforms() + self.fill_shader_wrapper.depth_test = self.depth_test + return self.fill_shader_wrapper + def get_stroke_shader_wrapper(self): + self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() + self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms() + self.stroke_shader_wrapper.depth_test = self.depth_test + return self.stroke_shader_wrapper + + def get_shader_wrapper_list(self): # Build up data lists - back_stroke_data = [] - stroke_data = [] - fill_data = [] - fill_vert_indices = [] - num_fill_verts = 0 # Number of fill verts + fill_shader_wrappers = [] + stroke_shader_wrappers = [] + back_stroke_shader_wrappers = [] for submob in self.family_members_with_points(): if submob.has_fill(): - data = submob.get_fill_shader_data() - indices = submob.get_fill_shader_vert_indices() + num_fill_verts - num_fill_verts += len(data) - - fill_data.append(data) - fill_vert_indices.append(indices) + fill_shader_wrappers.append(submob.get_fill_shader_wrapper()) if submob.has_stroke(): - data = submob.get_stroke_shader_data() + ssw = submob.get_stroke_shader_wrapper() if submob.draw_stroke_behind_fill: - back_stroke_data.append(data) + back_stroke_shader_wrappers.append(ssw) else: - stroke_data.append(data) + stroke_shader_wrappers.append(ssw) # Combine data lists + wrapper_lists = [ + back_stroke_shader_wrappers, + fill_shader_wrappers, + stroke_shader_wrappers + ] result = [] - if back_stroke_data: - back_stroke_info = dict(stroke_info) # Copy - back_stroke_info["vert_data"] = np.hstack(back_stroke_data) - result.append(back_stroke_info) - if fill_data: - fill_info["vert_data"] = np.hstack(fill_data) - fill_info["vert_indices"] = np.hstack(fill_vert_indices) - result.append(fill_info) - if stroke_data: - stroke_info["vert_data"] = np.hstack(stroke_data) - result.append(stroke_info) + for wlist in wrapper_lists: + if wlist: + wrapper = wlist[0] + wrapper.combine_with(*wlist[1:]) + result.append(wrapper) return result def get_stroke_uniforms(self): @@ -1013,7 +1009,7 @@ class VMobject(Mobject): return self.saved_triangulation if len(self.points) <= 1: - return [] + return np.zeros(0, dtype='i4') # Rotate points such that unit normal vector is OUT # TODO, 99% of the time this does nothing. Do a check for that? diff --git a/manimlib/utils/shaders.py b/manimlib/utils/shaders.py index 531af64b..ea261ae1 100644 --- a/manimlib/utils/shaders.py +++ b/manimlib/utils/shaders.py @@ -2,6 +2,8 @@ import os import warnings import re import moderngl +import numpy as np +import copy from manimlib.constants import SHADER_DIR @@ -12,111 +14,97 @@ from manimlib.constants import SHADER_DIR # to that shader -# TODO, this should all be treated as an object -# This object a shader program instead of the vert, -# geom and frag file names, and it should cache those -# programs in the way currently handled by Camera -# It should replace the Camera.get_shader method with -# its own get_shader_program method, which will take -# in the camera's perspective_uniforms. +class ShaderWrapper(object): + def __init__(self, + vert_data=None, + vert_indices=None, + vert_file=None, + geom_file=None, + frag_file=None, + uniforms=None, # A dictionary mapping names of uniform variables + texture_paths=None, # A dictionary mapping names to filepaths for textures. + depth_test=False, + render_primative=moderngl.TRIANGLE_STRIP, + ): + self.vert_data = vert_data + self.vert_indices = vert_indices + self.vert_attributes = vert_data.dtype.names + self.vert_file = vert_file + self.geom_file = geom_file + self.frag_file = frag_file + self.uniforms = uniforms or dict() + self.texture_paths = texture_paths or dict() + self.depth_test = depth_test + self.render_primative = str(render_primative) + self.id = self.create_id() + self.program_id = self.create_program_id() + def copy(self): + result = copy.copy(self) + result.vert_data = np.array(self.vert_data) + if result.vert_indices is not None: + result.vert_indices = np.array(self.vert_indices) + if self.uniforms: + result.uniforms = dict(self.uniforms) + if self.texture_paths: + result.texture_paths = dict(self.texture_paths) + return result -SHADER_INFO_KEYS = [ - # Vertex data for the shader (as structured array) - "vert_data", - # Index data (if applicable) for the shader - "index_data", - # List of variable names corresponding to inputs of vertex shader - "attributes", - # Filename of vetex shader - "vert", - # Filename of geometry shader, if there is one - "geom", - # Filename of fragment shader - "frag", - # A dictionary mapping names of uniform variables - "uniforms", - # A dictionary mapping names (as they show up in) - # the shader to filepaths for textures. - "texture_paths", - # Whether or not to apply depth test - "depth_test", - # E.g. moderngl.TRIANGLE_STRIP - "render_primative", -] + def is_valid(self): + return all([ + self.vert_data is not None, + self.vert_file, + self.frag_file, + ]) -# Exclude data -SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[3:] + def get_id(self): + return self.id + def get_program_id(self): + return self.program_id -def get_shader_info(vert_data=None, - vert_indices=None, - attributes=None, - vert_file=None, - geom_file=None, - frag_file=None, - uniforms=None, - texture_paths=None, - depth_test=False, - render_primative=moderngl.TRIANGLE_STRIP, - ): - result = { - "vert_data": vert_data, - "vert_indices": vert_indices, - "attributes": attributes, - "vert": vert_file, - "geom": geom_file, - "frag": frag_file, - "uniforms": uniforms or dict(), - "texture_paths": texture_paths or dict(), - "depth_test": depth_test, - "render_primative": str(render_primative), - } - result["id"] = create_shader_info_id(result) - result["prog_id"] = create_shader_info_program_id(result) - return result + def create_id(self): + # A unique id for a shader + return "|".join(map(str, [ + self.vert_file, + self.geom_file, + self.frag_file, + self.uniforms, + self.texture_paths, + self.depth_test, + self.render_primative, + ])) + def refresh_id(self): + self.id = self.create_id() -def is_valid_shader_info(shader_info): - vert_data = shader_info["vert_data"] - return all([ - vert_data is not None, - shader_info["vert"], - shader_info["frag"], - ]) + def create_program_id(self): + return "|".join(map(str, [self.vert_file, self.geom_file, self.frag_file])) + def get_program_code(self): + return { + "vertex_shader": get_shader_code_from_file(self.vert_file), + "geometry_shader": get_shader_code_from_file(self.geom_file), + "fragment_shader": get_shader_code_from_file(self.frag_file), + } -def shader_info_to_id(shader_info): - return shader_info["id"] - - -def shader_info_program_id(shader_info): - return shader_info["prog_id"] - - -def create_shader_info_id(shader_info): - # A unique id for a shader - return "|".join([str(shader_info[key]) for key in SHADER_KEYS_FOR_ID]) - - -def refresh_shader_info_id(shader_info): - shader_info["id"] = create_shader_info_id(shader_info) - - -def create_shader_info_program_id(shader_info): - return "|".join([str(shader_info[key]) for key in ["vert", "geom", "frag"]]) - - -def same_shader_type(info1, info2): - return info1["id"] == info2["id"] - - -def shader_info_to_program_code(shader_info): - return { - "vertex_shader": get_shader_code_from_file(shader_info["vert"]), - "geometry_shader": get_shader_code_from_file(shader_info["geom"]), - "fragment_shader": get_shader_code_from_file(shader_info["frag"]), - } + def combine_with(self, *shader_wrappers): + # Assume they are of the same type + if len(shader_wrappers) == 0: + return + if self.vert_indices is not None: + num_verts = len(self.vert_data) + indices_list = [self.vert_indices] + data_list = [self.vert_data] + for sw in shader_wrappers: + indices_list.append(sw.vert_indices + num_verts) + data_list.append(sw.vert_data) + num_verts += len(sw.vert_data) + self.vert_indices = np.hstack(indices_list) + self.vert_data = np.hstack(data_list) + else: + self.vert_data = np.hstack([self.vert_data, *[sw.vert_data for sw in shader_wrappers]]) + return self def get_shader_code_from_file(filename):