Merge pull request #1989 from 3b1b/video-work

A few more fixes and tweaks.
This commit is contained in:
Grant Sanderson
2023-02-04 16:57:13 -08:00
committed by GitHub
15 changed files with 298 additions and 194 deletions

View File

@ -174,16 +174,17 @@ class TexTransformExample(Scene):
self.add(lines[0]) self.add(lines[0])
# The animation TransformMatchingStrings will line up parts # The animation TransformMatchingStrings will line up parts
# of the source and target which have matching substring strings. # of the source and target which have matching substring strings.
# Here, giving it a little path_arc makes each part sort of # Here, giving it a little path_arc makes each part rotate into
# rotate into their final positions, which feels appropriate # their final positions, which feels appropriate for the idea of
# for the idea of rearranging an equation # rearranging an equation
self.play( self.play(
TransformMatchingStrings( TransformMatchingStrings(
lines[0].copy(), lines[1], lines[0].copy(), lines[1],
# matched_keys specifies which substring should # matched_keys specifies which substring should
# line up. If it's not specified, the animation # line up. If it's not specified, the animation
# will try its best, but may not quite give the # will align the longest matching substrings.
# intended effect # In this case, the substring "^2 = C^2" would
# trip it up
matched_keys=["A^2", "B^2", "C^2"], matched_keys=["A^2", "B^2", "C^2"],
# When you want a substring from the source # When you want a substring from the source
# to go to a non-equal substring from the target, # to go to a non-equal substring from the target,
@ -206,25 +207,57 @@ class TexTransformExample(Scene):
), ),
) )
self.wait(2) self.wait(2)
# You can also index into Tex mobject (or other StringMobjects)
# by substrings and regular expressions
top_equation = lines[0]
low_equation = lines[3]
self.play(LaggedStartMap(FlashAround, low_equation["C"], lag_ratio=0.5))
self.play(LaggedStartMap(FlashAround, low_equation["B"], lag_ratio=0.5))
self.play(LaggedStartMap(FlashAround, top_equation[re.compile(r"\w\^2")]))
self.play(Indicate(low_equation[R"\sqrt"]))
self.wait()
self.play(LaggedStartMap(FadeOut, lines, shift=2 * RIGHT)) self.play(LaggedStartMap(FadeOut, lines, shift=2 * RIGHT))
# TransformMatchingShapes will try to line up all pieces of a
# source mobject with those of a target, regardless of the
# what Mobject type they are.
source = Text("the morse code", height=1)
target = Text("here come dots", height=1)
saved_source = source.copy()
self.play(Write(source))
self.wait()
kw = dict(run_time=3, path_arc=PI / 2)
self.play(TransformMatchingShapes(source, target, **kw))
self.wait()
self.play(TransformMatchingShapes(target, saved_source, **kw))
self.wait()
class TexIndexing(Scene):
def construct(self):
# You can index into Tex mobject (or other StringMobjects) by substrings
equation = Tex(R"e^{\pi i} = -1", font_size=144)
self.add(equation)
self.play(FlashAround(equation["e"]))
self.wait()
self.play(Indicate(equation[R"\pi"]))
self.wait()
self.play(TransformFromCopy(
equation[R"e^{\pi i}"].copy().set_opacity(0.5),
equation["-1"],
path_arc=-PI / 2,
run_time=3
))
self.play(FadeOut(equation))
# Or regular expressions
equation = Tex("A^2 + B^2 = C^2", font_size=144)
self.play(Write(equation))
for part in equation[re.compile(r"\w\^2")]:
self.play(FlashAround(part))
self.wait()
self.play(FadeOut(equation))
# Indexing by substrings like this may not work when # Indexing by substrings like this may not work when
# the order in which Latex draws symbols does not match # the order in which Latex draws symbols does not match
# the order in which they show up in the string. # the order in which they show up in the string.
# For example, here the infinity is drawn before the sigma # For example, here the infinity is drawn before the sigma
# so we don't get the desired behavior. # so we don't get the desired behavior.
equation = Tex(R"\sum_{n = 1}^\infty \frac{1}{n^2} = \frac{\pi^2}{6}") equation = Tex(R"\sum_{n = 1}^\infty \frac{1}{n^2} = \frac{\pi^2}{6}", font_size=72)
self.play(FadeIn(equation)) self.play(FadeIn(equation))
self.play(equation[R"\infty"].animate.set_color(RED)) # Doesn't hit the infinity self.play(equation[R"\infty"].animate.set_color(RED)) # Doesn't hit the infinity
self.wait() self.wait()
@ -236,27 +269,14 @@ class TexTransformExample(Scene):
equation = Tex( equation = Tex(
R"\sum_{n = 1}^\infty {1 \over n^2} = {\pi^2 \over 6}", R"\sum_{n = 1}^\infty {1 \over n^2} = {\pi^2 \over 6}",
# Explicitly mark "\infty" as a substring you might want to access # Explicitly mark "\infty" as a substring you might want to access
isolate=[R"\infty"] isolate=[R"\infty"],
font_size=72
) )
self.play(FadeIn(equation)) self.play(FadeIn(equation))
self.play(equation[R"\infty"].animate.set_color(RED)) # Got it! self.play(equation[R"\infty"].animate.set_color(RED)) # Got it!
self.wait() self.wait()
self.play(FadeOut(equation)) self.play(FadeOut(equation))
# TransformMatchingShapes will try to line up all pieces of a
# source mobject with those of a target, regardless of the
# what Mobject type they are.
source = Text("the morse code", height=1)
target = Text("here come dots", height=1)
self.play(Write(source))
self.wait()
kw = dict(run_time=3, path_arc=PI / 2)
self.play(TransformMatchingShapes(source, target, **kw))
self.wait()
self.play(TransformMatchingShapes(target, source, **kw))
self.wait()
class UpdatersExample(Scene): class UpdatersExample(Scene):
def construct(self): def construct(self):

View File

@ -165,7 +165,7 @@ class LaggedStart(AnimationGroup):
class LaggedStartMap(LaggedStart): class LaggedStartMap(LaggedStart):
def __init__( def __init__(
self, self,
AnimationClass: type, anim_func: Callable[[Mobject], Animation],
group: Mobject, group: Mobject,
arg_creator: Callable[[Mobject], tuple] | None = None, arg_creator: Callable[[Mobject], tuple] | None = None,
run_time: float = 2.0, run_time: float = 2.0,
@ -175,7 +175,7 @@ class LaggedStartMap(LaggedStart):
anim_kwargs = dict(kwargs) anim_kwargs = dict(kwargs)
anim_kwargs.pop("lag_ratio", None) anim_kwargs.pop("lag_ratio", None)
super().__init__( super().__init__(
*(AnimationClass(submob, **anim_kwargs) for submob in group), *(anim_func(submob, **anim_kwargs) for submob in group),
run_time=run_time, run_time=run_time,
lag_ratio=lag_ratio, lag_ratio=lag_ratio,
) )

View File

@ -74,8 +74,6 @@ class Transform(Animation):
def finish(self) -> None: def finish(self) -> None:
super().finish() super().finish()
self.mobject.unlock_data() self.mobject.unlock_data()
if self.target_mobject is not None and self.rate_func(1) == 1:
self.mobject.become(self.target_mobject)
def create_target(self) -> Mobject: def create_target(self) -> Mobject:
# Has no meaningful effect here, but may be useful # Has no meaningful effect here, but may be useful

View File

@ -1,19 +1,15 @@
from __future__ import annotations from __future__ import annotations
import itertools as it import itertools as it
from difflib import SequenceMatcher
import numpy as np
from manimlib.animation.composition import AnimationGroup from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import Transform from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.string_mobject import StringMobject
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -131,28 +127,64 @@ class TransformMatchingStrings(TransformMatchingParts):
target: StringMobject, target: StringMobject,
matched_keys: Iterable[str] = [], matched_keys: Iterable[str] = [],
key_map: dict[str, str] = dict(), key_map: dict[str, str] = dict(),
matched_pairs: Iterable[tuple[Mobject, Mobject]] = [], matched_pairs: Iterable[tuple[VMobject, VMobject]] = [],
**kwargs, **kwargs,
): ):
matched_pairs = list(matched_pairs) + [ matched_pairs = [
*[(source[key], target[key]) for key in matched_keys], *matched_pairs,
*[(source[key1], target[key2]) for key1, key2 in key_map.items()], *self.matching_blocks(source, target, matched_keys, key_map),
*[
(source[substr], target[substr])
for substr in [
*source.get_specified_substrings(),
*target.get_specified_substrings(),
*source.get_symbol_substrings(),
*target.get_symbol_substrings(),
]
]
] ]
super().__init__( super().__init__(
source, target, source, target,
matched_pairs=matched_pairs, matched_pairs=matched_pairs,
**kwargs, **kwargs,
) )
def matching_blocks(
self,
source: StringMobject,
target: StringMobject,
matched_keys: Iterable[str],
key_map: dict[str, str]
) -> list[tuple[VMobject, VMobject]]:
syms1 = source.get_symbol_substrings()
syms2 = target.get_symbol_substrings()
counts1 = list(map(source.substr_to_path_count, syms1))
counts2 = list(map(target.substr_to_path_count, syms2))
# Start with user specified matches
blocks = [(source[key], target[key]) for key in matched_keys]
blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()]
# Nullify any intersections with those matches in the two symbol lists
for sub_source, sub_target in blocks:
for i in range(len(syms1)):
if source[i] in sub_source.family_members_with_points():
syms1[i] = "Null1"
for j in range(len(syms2)):
if target[j] in sub_target.family_members_with_points():
syms2[j] = "Null2"
# Group together longest matching substrings
while True:
matcher = SequenceMatcher(None, syms1, syms2)
match = matcher.find_longest_match(0, len(syms1), 0, len(syms2))
if match.size == 0:
break
i1 = sum(counts1[:match.a])
i2 = sum(counts2[:match.b])
size = sum(counts1[match.a:match.a + match.size])
blocks.append((source[i1:i1 + size], target[i2:i2 + size]))
for i in range(match.size):
syms1[match.a + i] = "Null1"
syms2[match.b + i] = "Null2"
return blocks
class TransformMatchingTex(TransformMatchingStrings): class TransformMatchingTex(TransformMatchingStrings):
"""Alias for TransformMatchingStrings""" """Alias for TransformMatchingStrings"""

View File

@ -41,11 +41,6 @@ class CameraFrame(Mobject):
self.set_height(frame_shape[1], stretch=True) self.set_height(frame_shape[1], stretch=True)
self.move_to(center_point) self.move_to(center_point)
def note_changed_data(self, recurse_up: bool = True):
super().note_changed_data(recurse_up)
self.get_view_matrix(refresh=True)
self.get_implied_camera_location(refresh=True)
def set_orientation(self, rotation: Rotation): def set_orientation(self, rotation: Rotation):
self.uniforms["orientation"][:] = rotation.as_quat() self.uniforms["orientation"][:] = rotation.as_quat()
return self return self
@ -89,7 +84,7 @@ class CameraFrame(Mobject):
Returns a 4x4 for the affine transformation mapping a point Returns a 4x4 for the affine transformation mapping a point
into the camera's internal coordinate system into the camera's internal coordinate system
""" """
if refresh: if self._data_has_changed:
shift = np.identity(4) shift = np.identity(4)
rotation = np.identity(4) rotation = np.identity(4)
scale_mat = np.identity(4) scale_mat = np.identity(4)
@ -169,10 +164,12 @@ class CameraFrame(Mobject):
self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2]) self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2])
return self return self
@Mobject.affects_data
def set_focal_distance(self, focal_distance: float): def set_focal_distance(self, focal_distance: float):
self.uniforms["fovy"] = 2 * math.atan(0.5 * self.get_height() / focal_distance) self.uniforms["fovy"] = 2 * math.atan(0.5 * self.get_height() / focal_distance)
return self return self
@Mobject.affects_data
def set_field_of_view(self, field_of_view: float): def set_field_of_view(self, field_of_view: float):
self.uniforms["fovy"] = field_of_view self.uniforms["fovy"] = field_of_view
return self return self
@ -202,8 +199,8 @@ class CameraFrame(Mobject):
def get_field_of_view(self) -> float: def get_field_of_view(self) -> float:
return self.uniforms["fovy"] return self.uniforms["fovy"]
def get_implied_camera_location(self, refresh=False) -> np.ndarray: def get_implied_camera_location(self) -> np.ndarray:
if refresh: if self._data_has_changed:
to_camera = self.get_inverse_camera_rotation_matrix()[2] to_camera = self.get_inverse_camera_rotation_matrix()[2]
dist = self.get_focal_distance() dist = self.get_focal_distance()
self.camera_location = self.get_center() + dist * to_camera self.camera_location = self.get_center() + dist * to_camera

View File

@ -93,6 +93,14 @@ def parse_cli():
action="store_true", action="store_true",
help="Render to a movie file with an alpha channel", help="Render to a movie file with an alpha channel",
) )
parser.add_argument(
"--vcodec",
help="Video codec to use with ffmpeg",
)
parser.add_argument(
"--pix_fmt",
help="Pixel format to use for the output of ffmpeg, defaults to `yuv420p`",
)
parser.add_argument( parser.add_argument(
"-q", "--quiet", "-q", "--quiet",
action="store_true", action="store_true",
@ -160,6 +168,12 @@ def parse_cli():
action="store_true", action="store_true",
help="Show progress bar for each animation", help="Show progress bar for each animation",
) )
parser.add_argument(
"--prerun",
action="store_true",
help="Calculate total framecount, to display in a progress bar, by doing " + \
"an initial run of the scene which skips animations."
)
parser.add_argument( parser.add_argument(
"--video_dir", "--video_dir",
help="Directory to write video", help="Directory to write video",
@ -386,7 +400,7 @@ def get_output_directory(args: Namespace, custom_config: dict) -> str:
def get_file_writer_config(args: Namespace, custom_config: dict) -> dict: def get_file_writer_config(args: Namespace, custom_config: dict) -> dict:
return { result = {
"write_to_movie": not args.skip_animations and args.write_file, "write_to_movie": not args.skip_animations and args.write_file,
"break_into_partial_movies": custom_config["break_into_partial_movies"], "break_into_partial_movies": custom_config["break_into_partial_movies"],
"save_last_frame": args.skip_animations and args.write_file, "save_last_frame": args.skip_animations and args.write_file,
@ -402,6 +416,18 @@ def get_file_writer_config(args: Namespace, custom_config: dict) -> dict:
"quiet": args.quiet, "quiet": args.quiet,
} }
if args.vcodec:
result["video_codec"] = args.vcodec
elif args.transparent:
result["video_codec"] = 'prores_ks'
elif args.gif:
result["video_codec"] = ''
if args.pix_fmt:
result["pix_fmt"] = args.pix_fmt
return result
def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict: def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict:
# Default to making window half the screen size # Default to making window half the screen size
@ -489,6 +515,7 @@ def get_configuration(args: Namespace) -> dict:
"presenter_mode": args.presenter_mode, "presenter_mode": args.presenter_mode,
"leave_progress_bars": args.leave_progress_bars, "leave_progress_bars": args.leave_progress_bars,
"show_animation_progress": args.show_animation_progress, "show_animation_progress": args.show_animation_progress,
"prerun": args.prerun,
"embed_exception_mode": custom_config["embed_exception_mode"], "embed_exception_mode": custom_config["embed_exception_mode"],
"embed_error_sound": custom_config["embed_error_sound"], "embed_error_sound": custom_config["embed_error_sound"],
} }

View File

@ -79,40 +79,35 @@ def compute_total_frames(scene_class, scene_config):
return int(total_time * scene_config["camera_config"]["fps"]) return int(total_time * scene_config["camera_config"]["fps"])
def get_scenes_to_render(scene_classes, scene_config, config): def scene_from_class(scene_class, scene_config, config):
if config["write_all"]: fw_config = scene_config["file_writer_config"]
return [sc(**scene_config) for sc in scene_classes] if fw_config["write_to_movie"] and config["prerun"]:
fw_config["total_frames"] = compute_total_frames(scene_class, scene_config)
return scene_class(**scene_config)
result = []
for scene_name in config["scene_names"]: def get_scenes_to_render(all_scene_classes, scene_config, config):
found = False if config["write_all"]:
for scene_class in scene_classes: return [sc(**scene_config) for sc in all_scene_classes]
if scene_class.__name__ == scene_name:
fw_config = scene_config["file_writer_config"] names_to_classes = {sc.__name__ : sc for sc in all_scene_classes}
if fw_config["write_to_movie"]: scene_names = config["scene_names"]
fw_config["total_frames"] = compute_total_frames(scene_class, scene_config)
scene = scene_class(**scene_config) for name in set.difference(set(scene_names), names_to_classes):
result.append(scene) log.error(f"No scene named {name} found")
found = True scene_names.remove(name)
break
if not found and (scene_name != ""): if scene_names:
log.error(f"No scene named {scene_name} found") classes_to_run = [names_to_classes[name] for name in scene_names]
if result: elif len(all_scene_classes) == 1:
return result classes_to_run = [all_scene_classes[0]]
# another case
result=[]
if len(scene_classes) == 1:
scene_classes = [scene_classes[0]]
else: else:
scene_classes = prompt_user_for_choice(scene_classes) classes_to_run = prompt_user_for_choice(all_scene_classes)
for scene_class in scene_classes:
fw_config = scene_config["file_writer_config"] return [
if fw_config["write_to_movie"]: scene_from_class(scene_class, scene_config, config)
fw_config["total_frames"] = compute_total_frames(scene_class, scene_config) for scene_class in classes_to_run
scene = scene_class(**scene_config) ]
result.append(scene)
return result
def get_scene_classes_from_module(module): def get_scene_classes_from_module(module):

View File

@ -105,6 +105,7 @@ class Mobject(object):
self.bounding_box: Vect3Array = np.zeros((3, 3)) self.bounding_box: Vect3Array = np.zeros((3, 3))
self._shaders_initialized: bool = False self._shaders_initialized: bool = False
self._data_has_changed: bool = True self._data_has_changed: bool = True
self.shader_code_replacements: dict[str, str] = dict()
self.init_data() self.init_data()
self._data_defaults = np.ones(1, dtype=self.data.dtype) self._data_defaults = np.ones(1, dtype=self.data.dtype)
@ -738,7 +739,7 @@ class Mobject(object):
) )
if len(points1) != len(points2): if len(points1) != len(points2):
return False return False
return bool(np.isclose(points1, points2).all()) return bool(np.isclose(points1, points2, atol=self.get_width() * 1e-2).all())
# Creating new Mobjects from this one # Creating new Mobjects from this one
@ -1895,12 +1896,12 @@ class Mobject(object):
# Shader code manipulation # Shader code manipulation
@affects_data
def replace_shader_code(self, old: str, new: str) -> Self: def replace_shader_code(self, old: str, new: str) -> Self:
# TODO, will this work with VMobject structure, given self.shader_code_replacements[old] = new
# that it does not simpler return shader_wrappers of self._shaders_initialized = False
# family? for mob in self.get_ancestors():
for wrapper in self.get_shader_wrapper_list(): mob._shaders_initialized = False
wrapper.replace_code(old, new)
return self return self
def set_color_by_code(self, glsl_code: str) -> Self: def set_color_by_code(self, glsl_code: str) -> Self:
@ -1967,8 +1968,10 @@ class Mobject(object):
self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.update_program_uniforms(self.get_uniforms()) self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test self.shader_wrapper.depth_test = self.depth_test
for old, new in self.shader_code_replacements.items():
self.shader_wrapper.replace_code(old, new)
return self.shader_wrapper return self.shader_wrapper
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
@ -2004,9 +2007,7 @@ class Mobject(object):
shader_wrapper.generate_vao() shader_wrapper.generate_vao()
self._data_has_changed = False self._data_has_changed = False
for shader_wrapper in self.shader_wrappers: for shader_wrapper in self.shader_wrappers:
shader_wrapper.depth_test = self.depth_test shader_wrapper.update_program_uniforms(camera_uniforms)
shader_wrapper.update_program_uniforms(self.get_uniforms())
shader_wrapper.update_program_uniforms(camera_uniforms, universal=True)
shader_wrapper.pre_render() shader_wrapper.pre_render()
shader_wrapper.render() shader_wrapper.render()

View File

@ -166,20 +166,12 @@ class VMobject(Mobject):
def set_rgba_array( def set_rgba_array(
self, self,
rgba_array: Vect4Array, rgba_array: Vect4Array,
name: str | None = None, name: str = "stroke_rgba",
recurse: bool = False recurse: bool = False
) -> Self: ) -> Self:
if name is None: super().set_rgba_array(rgba_array, name, recurse)
names = ["fill_rgba", "stroke_rgba"] self.note_changed_fill()
else: self.note_changed_stroke()
names = [name]
for name in names:
super().set_rgba_array(rgba_array, name, recurse)
if name == "fill_rgba":
self.note_changed_fill()
elif name == "stroke_rgba":
self.note_changed_stroke()
return self return self
def set_fill( def set_fill(
@ -1262,11 +1254,10 @@ class VMobject(Mobject):
def set_animating_status(self, is_animating: bool, recurse: bool = True): def set_animating_status(self, is_animating: bool, recurse: bool = True):
super().set_animating_status(is_animating, recurse) super().set_animating_status(is_animating, recurse)
if is_animating: for submob in self.get_family(recurse):
for submob in self.get_family(recurse): submob.get_joint_products(refresh=True)
submob.get_joint_products(refresh=True) if not submob._use_winding_fill:
if not submob._use_winding_fill: submob.get_triangulation()
submob.get_triangulation()
return self return self
# For shaders # For shaders
@ -1284,14 +1275,14 @@ class VMobject(Mobject):
self.fill_shader_wrapper = FillShaderWrapper( self.fill_shader_wrapper = FillShaderWrapper(
ctx=ctx, ctx=ctx,
vert_data=fill_data, vert_data=fill_data,
uniforms=self.uniforms, mobject_uniforms=self.uniforms,
shader_folder=self.fill_shader_folder, shader_folder=self.fill_shader_folder,
render_primitive=self.fill_render_primitive, render_primitive=self.fill_render_primitive,
) )
self.stroke_shader_wrapper = ShaderWrapper( self.stroke_shader_wrapper = ShaderWrapper(
ctx=ctx, ctx=ctx,
vert_data=stroke_data, vert_data=stroke_data,
uniforms=self.uniforms, mobject_uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder, shader_folder=self.stroke_shader_folder,
render_primitive=self.stroke_render_primitive, render_primitive=self.stroke_render_primitive,
) )
@ -1301,6 +1292,11 @@ class VMobject(Mobject):
self.fill_shader_wrapper, self.fill_shader_wrapper,
self.stroke_shader_wrapper, self.stroke_shader_wrapper,
] ]
for sw in self.shader_wrappers:
family = self.family_members_with_points()
rep = family[0] if family else self
for old, new in rep.shader_code_replacements.items():
sw.replace_code(old, new)
def refresh_shader_wrapper_id(self) -> Self: def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized: if not self._shaders_initialized:
@ -1309,11 +1305,6 @@ class VMobject(Mobject):
wrapper.refresh_id() wrapper.refresh_id()
return self return self
def get_uniforms(self):
# TODO, account for submob uniforms separately?
self.uniforms.update(self.family_members_with_points()[0].uniforms)
return self.uniforms
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
if not self._shaders_initialized: if not self._shaders_initialized:
self.init_shader_data(ctx) self.init_shader_data(ctx)
@ -1325,32 +1316,25 @@ class VMobject(Mobject):
fill_names = self.fill_data_names fill_names = self.fill_data_names
stroke_names = self.stroke_data_names stroke_names = self.stroke_data_names
# Build up data lists fill_family = (sm for sm in family if sm._has_fill)
stroke_family = (sm for sm in family if sm._has_stroke)
# Build up fill data lists
fill_datas = [] fill_datas = []
fill_indices = [] fill_indices = []
fill_border_datas = [] fill_border_datas = []
stroke_datas = [] for submob in fill_family:
back_stroke_datas = []
for submob in family:
submob.get_joint_products()
indices = submob.get_outer_vert_indices() indices = submob.get_outer_vert_indices()
has_fill = submob._has_fill if submob._use_winding_fill:
has_stroke = submob._has_stroke
back_stroke = has_stroke and submob.stroke_behind
front_stroke = has_stroke and not submob.stroke_behind
if back_stroke:
back_stroke_datas.append(submob.data[stroke_names][indices])
if front_stroke:
stroke_datas.append(submob.data[stroke_names][indices])
if has_fill and submob._use_winding_fill:
data = submob.data[fill_names] data = submob.data[fill_names]
data["base_point"][:] = data["point"][0] data["base_point"][:] = data["point"][0]
fill_datas.append(data[indices]) fill_datas.append(data[indices])
if has_fill and not submob._use_winding_fill: else:
fill_datas.append(submob.data[fill_names]) fill_datas.append(submob.data[fill_names])
fill_indices.append(submob.get_triangulation()) fill_indices.append(submob.get_triangulation())
if has_fill and not front_stroke: if (not submob._has_stroke) or submob.stroke_behind:
# Add fill border # Add fill border
submob.get_joint_products()
names = list(stroke_names) names = list(stroke_names)
names[names.index('stroke_rgba')] = 'fill_rgba' names[names.index('stroke_rgba')] = 'fill_rgba'
names[names.index('stroke_width')] = 'fill_border_width' names[names.index('stroke_width')] = 'fill_border_width'
@ -1359,11 +1343,26 @@ class VMobject(Mobject):
) )
fill_border_datas.append(border_stroke_data[indices]) fill_border_datas.append(border_stroke_data[indices])
# Build up stroke data lists
stroke_datas = []
back_stroke_datas = []
for submob in stroke_family:
submob.get_joint_products()
indices = submob.get_outer_vert_indices()
if submob.stroke_behind:
back_stroke_datas.append(submob.data[stroke_names][indices])
else:
stroke_datas.append(submob.data[stroke_names][indices])
shader_wrappers = [ shader_wrappers = [
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]), self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None), self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
self.stroke_shader_wrapper.read_in(stroke_datas), self.stroke_shader_wrapper.read_in(stroke_datas),
] ]
for sw in shader_wrappers:
rep = family[0] # Representative family member
sw.bind_to_mobject_uniforms(rep.get_uniforms())
sw.depth_test = rep.depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
@ -1371,6 +1370,8 @@ class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs): def __init__(self, *vmobjects: VMobject, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*vmobjects) self.add(*vmobjects)
if vmobjects:
self.uniforms.update(vmobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self: def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject)) assert(isinstance(other, VMobject))

View File

@ -13,6 +13,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.bezier import inverse_interpolate
from manimlib.utils.color import get_colormap_list from manimlib.utils.color import get_colormap_list
from manimlib.utils.color import rgb_to_color
from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import linear
from manimlib.utils.simple_functions import sigmoid from manimlib.utils.simple_functions import sigmoid
@ -173,7 +174,10 @@ class VectorField(VGroup):
**vector_config **vector_config
) )
vect.shift(_input - origin) vect.shift(_input - origin)
vect.set_rgba_array([[*self.value_to_rgb(norm), self.opacity]]) vect.set_color(
rgb_to_color(self.value_to_rgb(norm)),
opacity=self.opacity,
)
return vect return vect

View File

@ -387,7 +387,10 @@ class Scene(object):
same type are grouped together, so this function creates same type are grouped together, so this function creates
Groups of all clusters of adjacent Mobjects in the scene Groups of all clusters of adjacent Mobjects in the scene
""" """
batches = batch_by_property(self.mobjects, lambda m: str(type(m))) batches = batch_by_property(
self.mobjects,
lambda m: str(type(m)) + str(m.get_uniforms())
)
for group in self.render_groups: for group in self.render_groups:
group.clear() group.clear()
@ -554,6 +557,7 @@ class Scene(object):
leave=self.leave_progress_bars, leave=self.leave_progress_bars,
ascii=True if platform.system() == 'Windows' else None, ascii=True if platform.system() == 'Windows' else None,
desc=desc, desc=desc,
bar_format="{l_bar} {n_fmt:3}/{total_fmt:3} {rate_fmt}{postfix}",
) )
else: else:
return times return times
@ -723,6 +727,7 @@ class Scene(object):
def get_state(self) -> SceneState: def get_state(self) -> SceneState:
return SceneState(self) return SceneState(self)
@affects_mobject_list
def restore_state(self, scene_state: SceneState): def restore_state(self, scene_state: SceneState):
scene_state.restore_scene(self) scene_state.restore_scene(self)

View File

@ -47,6 +47,8 @@ class SceneFileWriter(object):
quiet: bool = False, quiet: bool = False,
total_frames: int = 0, total_frames: int = 0,
progress_description_len: int = 40, progress_description_len: int = 40,
video_codec: str = "libx264",
pixel_format: str = "yuv420p",
): ):
self.scene: Scene = scene self.scene: Scene = scene
self.write_to_movie = write_to_movie self.write_to_movie = write_to_movie
@ -63,6 +65,8 @@ class SceneFileWriter(object):
self.quiet = quiet self.quiet = quiet
self.total_frames = total_frames self.total_frames = total_frames
self.progress_description_len = progress_description_len self.progress_description_len = progress_description_len
self.video_codec = video_codec
self.pixel_format = pixel_format
# State during file writing # State during file writing
self.writing_process: sp.Popen | None = None self.writing_process: sp.Popen | None = None
@ -262,32 +266,26 @@ class SceneFileWriter(object):
'-an', # Tells FFMPEG not to expect any audio '-an', # Tells FFMPEG not to expect any audio
'-loglevel', 'error', '-loglevel', 'error',
] ]
if self.movie_file_extension == ".mov": if self.video_codec:
# This is if the background of the exported command += ['-vcodec', self.video_codec]
# video should be transparent. if self.pixel_format:
command += [ command += ['-pix_fmt', self.pixel_format]
'-vcodec', 'prores_ks',
]
elif self.movie_file_extension == ".gif":
command += []
else:
command += [
'-vcodec', 'libx264',
'-pix_fmt', 'yuv420p',
]
command += [self.temp_file_path] command += [self.temp_file_path]
self.writing_process = sp.Popen(command, stdin=sp.PIPE) self.writing_process = sp.Popen(command, stdin=sp.PIPE)
if self.total_frames > 0 and not self.quiet: if not self.quiet:
self.progress_display = ProgressDisplay( self.progress_display = ProgressDisplay(
range(self.total_frames), range(self.total_frames),
# bar_format="{l_bar}{bar}|{n_fmt}/{total_fmt}",
leave=False, leave=False,
ascii=True if platform.system() == 'Windows' else None, ascii=True if platform.system() == 'Windows' else None,
dynamic_ncols=True, dynamic_ncols=True,
) )
self.set_progress_display_description() self.set_progress_display_description()
def use_fast_encoding(self):
self.video_codec = "libx264rgb"
self.pixel_format = "rgb32"
def begin_insert(self): def begin_insert(self):
# Begin writing process # Begin writing process
self.write_to_movie = True self.write_to_movie = True

View File

@ -15,6 +15,7 @@ from manimlib.utils.shaders import image_path_to_texture
from manimlib.utils.shaders import get_texture_id from manimlib.utils.shaders import get_texture_id
from manimlib.utils.shaders import get_fill_canvas from manimlib.utils.shaders import get_fill_canvas
from manimlib.utils.shaders import release_texture from manimlib.utils.shaders import release_texture
from manimlib.utils.shaders import set_program_uniform
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -37,7 +38,7 @@ class ShaderWrapper(object):
vert_data: np.ndarray, vert_data: np.ndarray,
vert_indices: Optional[np.ndarray] = None, vert_indices: Optional[np.ndarray] = None,
shader_folder: Optional[str] = None, shader_folder: Optional[str] = None,
uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False, depth_test: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP, render_primitive: int = moderngl.TRIANGLE_STRIP,
@ -47,13 +48,14 @@ class ShaderWrapper(object):
self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
self.vert_attributes = vert_data.dtype.names self.vert_attributes = vert_data.dtype.names
self.shader_folder = shader_folder self.shader_folder = shader_folder
self.uniforms: UniformDict = dict()
self.depth_test = depth_test self.depth_test = depth_test
self.render_primitive = render_primitive self.render_primitive = render_primitive
self.program_uniform_mirror: UniformDict = dict()
self.bind_to_mobject_uniforms(mobject_uniforms or dict())
self.init_program_code() self.init_program_code()
self.init_program() self.init_program()
self.update_program_uniforms(uniforms or dict())
if texture_paths is not None: if texture_paths is not None:
self.init_textures(texture_paths) self.init_textures(texture_paths)
self.init_vao() self.init_vao()
@ -91,14 +93,17 @@ class ShaderWrapper(object):
self.ibo = None self.ibo = None
self.vao = None self.vao = None
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
self.mobject_uniforms = mobject_uniforms
def __eq__(self, shader_wrapper: ShaderWrapper): def __eq__(self, shader_wrapper: ShaderWrapper):
return all(( return all((
np.all(self.vert_data == shader_wrapper.vert_data), np.all(self.vert_data == shader_wrapper.vert_data),
np.all(self.vert_indices == shader_wrapper.vert_indices), np.all(self.vert_indices == shader_wrapper.vert_indices),
self.shader_folder == shader_wrapper.shader_folder, self.shader_folder == shader_wrapper.shader_folder,
all( all(
self.uniforms[key] == shader_wrapper.uniforms[key] self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key]
for key in self.uniforms for key in self.mobject_uniforms
), ),
self.depth_test == shader_wrapper.depth_test, self.depth_test == shader_wrapper.depth_test,
self.render_primitive == shader_wrapper.render_primitive, self.render_primitive == shader_wrapper.render_primitive,
@ -122,31 +127,25 @@ class ShaderWrapper(object):
def get_id(self) -> str: def get_id(self) -> str:
return self.id return self.id
def get_program_id(self) -> int:
return self.program_id
def create_id(self) -> str: def create_id(self) -> str:
# A unique id for a shader # A unique id for a shader
program_id = hash("".join(
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
))
return "|".join(map(str, [ return "|".join(map(str, [
self.program_id, program_id,
self.uniforms, self.mobject_uniforms,
self.depth_test, self.depth_test,
self.render_primitive, self.render_primitive,
])) ]))
def refresh_id(self) -> None: def refresh_id(self) -> None:
self.program_id = self.create_program_id()
self.id = self.create_id() self.id = self.create_id()
def create_program_id(self) -> int:
return hash("".join((
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
)))
def replace_code(self, old: str, new: str) -> None: def replace_code(self, old: str, new: str) -> None:
code_map = self.program_code code_map = self.program_code
for (name, code) in code_map.items(): for name in code_map:
if code_map[name] is None: if code_map[name] is None:
continue continue
code_map[name] = re.sub(old, new, code_map[name]) code_map[name] = re.sub(old, new, code_map[name])
@ -155,9 +154,9 @@ class ShaderWrapper(object):
# Changing context # Changing context
def use_clip_plane(self): def use_clip_plane(self):
if "clip_plane" not in self.uniforms: if "clip_plane" not in self.mobject_uniforms:
return False return False
return any(self.uniforms["clip_plane"]) return any(self.mobject_uniforms["clip_plane"])
def set_ctx_depth_test(self, enable: bool = True) -> None: def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable: if enable:
@ -222,18 +221,11 @@ class ShaderWrapper(object):
assert(self.vao is not None) assert(self.vao is not None)
self.vao.render() self.vao.render()
def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False): def update_program_uniforms(self, camera_uniforms: UniformDict):
if self.program is None: if self.program is None:
return return
for name, value in uniforms.items(): for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()):
if name not in self.program: set_program_uniform(self.program, name, value)
continue
if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value)
if universal and self.uniforms.get(name, None) == value:
continue
self.program[name].value = value
self.uniforms[name] = value
def get_vertex_buffer_object(self, refresh: bool = True): def get_vertex_buffer_object(self, refresh: bool = True):
if refresh: if refresh:

View File

@ -9,7 +9,6 @@ import numpy as np
from manimlib.config import parse_cli from manimlib.config import parse_cli
from manimlib.config import get_configuration from manimlib.config import get_configuration
from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_shader_dir from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file from manimlib.utils.file_ops import find_file
@ -17,11 +16,13 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence, Optional, Tuple from typing import Sequence, Optional, Tuple
from manimlib.typing import UniformDict
from moderngl.vertex_array import VertexArray from moderngl.vertex_array import VertexArray
from moderngl.framebuffer import Framebuffer from moderngl.framebuffer import Framebuffer
ID_TO_TEXTURE: dict[int, moderngl.Texture] = dict() ID_TO_TEXTURE: dict[int, moderngl.Texture] = dict()
PROGRAM_UNIFORM_MIRRORS: dict[int, dict[str, float | tuple]] = dict()
@lru_cache() @lru_cache()
@ -63,6 +64,38 @@ def get_shader_program(
) )
def set_program_uniform(
program: moderngl.Program,
name: str,
value: float | tuple | np.ndarray
) -> bool:
"""
Sets a program uniform, and also keeps track of a dictionary
of previously set uniforms for that program so that it
doesn't needlessly reset it, requiring an exchange with gpu
memory, if it sees the same value again.
Returns True if changed the program, False if it left it as is.
"""
pid = id(program)
if pid not in PROGRAM_UNIFORM_MIRRORS:
PROGRAM_UNIFORM_MIRRORS[pid] = dict()
uniform_mirror = PROGRAM_UNIFORM_MIRRORS[pid]
if type(value) is np.ndarray and value.ndim > 0:
value = tuple(value)
if uniform_mirror.get(name, None) == value:
return False
try:
program[name].value = value
except KeyError:
return False
uniform_mirror[name] = value
return True
@lru_cache() @lru_cache()
def get_shader_code_from_file(filename: str) -> str | None: def get_shader_code_from_file(filename: str) -> str | None:
if not filename: if not filename:

View File

@ -29,6 +29,7 @@ class Window(PygletWindow):
size: tuple[int, int] = (1280, 720), size: tuple[int, int] = (1280, 720),
samples = 0 samples = 0
): ):
scene.window = self
super().__init__(size=size, samples=samples) super().__init__(size=size, samples=samples)
self.default_size = size self.default_size = size