From 80729c0cb8e719d3aea242518564c2022d61fb27 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Wed, 25 Jan 2023 10:37:12 -0800 Subject: [PATCH] Only initialize ShaderWrappers as needed --- manimlib/mobject/mobject.py | 10 +++++++--- manimlib/mobject/types/vectorized_mobject.py | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 874c81a1..86397bf1 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -101,6 +101,7 @@ class Mobject(object): self.saved_state = None self.target = None self.bounding_box: Vect3Array = np.zeros((3, 3)) + self._shaders_initialized: bool = False self.init_data() self._data_defaults = np.ones(1, dtype=self.data.dtype) @@ -109,7 +110,6 @@ class Mobject(object): self.init_event_listners() self.init_points() self.init_colors() - self.init_shader_data() if self.depth_test: self.apply_depth_test() @@ -1843,7 +1843,6 @@ class Mobject(object): # For shader data def init_shader_data(self): - # TODO, only call this when needed? self.shader_indices = np.zeros(0) self.shader_wrapper = ShaderWrapper( vert_data=self.data, @@ -1854,10 +1853,15 @@ class Mobject(object): ) def refresh_shader_wrapper_id(self): - self.shader_wrapper.refresh_id() + if self._shaders_initialized: + self.shader_wrapper.refresh_id() return self def get_shader_wrapper(self) -> ShaderWrapper: + if not self._shaders_initialized: + self.init_shader_data() + self._shaders_initialized = True + self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.uniforms = self.get_uniforms() diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index aba18c60..8b1f5a58 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1196,6 +1196,10 @@ class VMobject(Mobject): return self def get_shader_wrapper_list(self) -> list[ShaderWrapper]: + if not self._shaders_initialized: + self.init_shader_data() + self._shaders_initialized = True + family = self.family_members_with_points() if not family: return []