diff --git a/manimlib/animation/animation.py b/manimlib/animation/animation.py index 310577c8..12964670 100644 --- a/manimlib/animation/animation.py +++ b/manimlib/animation/animation.py @@ -1,6 +1,6 @@ from copy import deepcopy -from manimlib.mobject.mobject import Mobject +from manimlib.mobject.mobject import Mobject, _AnimationBuilder from manimlib.utils.config_ops import digest_config from manimlib.utils.rate_functions import smooth from manimlib.utils.simple_functions import clip @@ -159,3 +159,13 @@ class Animation(object): def is_remover(self): return self.remover + + +def prepare_animation(anim): + if isinstance(anim, _AnimationBuilder): + return anim.build() + + if isinstance(anim, Animation): + return anim + + raise TypeError(f"Object {anim} cannot be converted to an animation") diff --git a/manimlib/animation/composition.py b/manimlib/animation/composition.py index c120f60b..ba175dce 100644 --- a/manimlib/animation/composition.py +++ b/manimlib/animation/composition.py @@ -1,6 +1,6 @@ import numpy as np -from manimlib.animation.animation import Animation +from manimlib.animation.animation import Animation, prepare_animation from manimlib.mobject.mobject import Group from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate @@ -29,7 +29,7 @@ class AnimationGroup(Animation): def __init__(self, *animations, **kwargs): digest_config(self, kwargs) - self.animations = animations + self.animations = [prepare_animation(anim) for anim in animations] if self.group is None: self.group = Group(*remove_list_redundancies( [anim.mobject for anim in animations] diff --git a/manimlib/animation/transform.py b/manimlib/animation/transform.py index 257feb17..d76047ff 100644 --- a/manimlib/animation/transform.py +++ b/manimlib/animation/transform.py @@ -149,6 +149,12 @@ class MoveToTarget(Transform): ) +class _MethodAnimation(MoveToTarget): + def __init__(self, mobject, methods): + self.methods = methods + super().__init__(mobject) + + class ApplyMethod(Transform): def __init__(self, method, *args, **kwargs): """ diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 9aede1af..7cc437c4 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -80,6 +80,11 @@ class Mobject(object): if self.depth_test: self.apply_depth_test() + @property + def animate(self): + # Borrowed from https://github.com/ManimCommunity/manim/ + return _AnimationBuilder(self) + def __str__(self): return self.__class__.__name__ @@ -1571,3 +1576,51 @@ class Point(Mobject): def set_location(self, new_loc): self.set_points(np.array(new_loc, ndmin=2, dtype=float)) + + +class _AnimationBuilder: + def __init__(self, mobject): + self.mobject = mobject + self.overridden_animation = None + self.mobject.generate_target() + self.is_chaining = False + self.methods = [] + + def __getattr__(self, method_name): + method = getattr(self.mobject.target, method_name) + self.methods.append(method) + has_overridden_animation = hasattr(method, "_override_animate") + + if (self.is_chaining and has_overridden_animation) or self.overridden_animation: + raise NotImplementedError( + "Method chaining is currently not supported for " + "overridden animations" + ) + + def update_target(*method_args, **method_kwargs): + if has_overridden_animation: + self.overridden_animation = method._override_animate( + self.mobject, *method_args, **method_kwargs + ) + else: + method(*method_args, **method_kwargs) + return self + + self.is_chaining = True + return update_target + + def build(self): + from manimlib.animation.transform import _MethodAnimation + + if self.overridden_animation: + return self.overridden_animation + + return _MethodAnimation(self.mobject, self.methods) + + +def override_animate(method): + def decorator(animation_method): + method._override_animate = animation_method + return animation_method + + return decorator diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index f6e2da31..ac3915fa 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -10,7 +10,7 @@ import numpy as np import time from IPython.terminal.embed import InteractiveShellEmbed -from manimlib.animation.animation import Animation +from manimlib.animation.animation import prepare_animation from manimlib.animation.transform import MoveToTarget from manimlib.mobject.mobject import Point from manimlib.camera.camera import Camera @@ -348,10 +348,7 @@ class Scene(object): state["method_args"] = [] for arg in args: - if isinstance(arg, Animation): - compile_method(state) - animations.append(arg) - elif inspect.ismethod(arg): + if inspect.ismethod(arg): compile_method(state) state["curr_method"] = arg elif state["curr_method"] is not None: @@ -362,7 +359,13 @@ class Scene(object): you meant to pass in as a Scene.play argument """) else: - raise Exception("Invalid play arguments") + try: + anim = prepare_animation(arg) + except TypeError: + raise TypeError(f"Unexpected argument {arg} passed to Scene.play()") + + compile_method(state) + animations.append(anim) compile_method(state) for animation in animations: