diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 3caeda69..5649e616 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -464,40 +464,76 @@ class Mobject(object): # Copying and serialization + def stash_mobject_pointers(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + uncopied_attrs = ["parents", "target", "saved_state"] + stash = dict() + for attr in uncopied_attrs: + if hasattr(self, attr): + value = getattr(self, attr) + stash[attr] = value + null_value = [] if isinstance(value, Iterable) else None + setattr(self, attr, null_value) + result = func(self, *args, **kwargs) + self.__dict__.update(stash) + return result + return wrapper + + @stash_mobject_pointers def serialize(self): - pre, self.parents = self.parents, [] - result = pickle.dumps(self) - self.parents = pre - return result + return pickle.dumps(self) def deserialize(self, data: bytes): self.become(pickle.loads(data)) return self - def copy(self): - try: - serial = self.serialize() - return pickle.loads(serial) - except AttributeError: - return copy.deepcopy(self) + @stash_mobject_pointers + def copy(self, deep: bool = False): + if deep: + try: + # Often faster than deepcopy + return pickle.loads(self.serialize()) + except AttributeError: + return copy.deepcopy(self) + + result = copy.copy(self) + + # The line above is only a shallow copy, so the internal + # data which are numpyu arrays or other mobjects still + # need to be further copied. + result.data = dict(self.data) + for key in result.data: + result.data[key] = result.data[key].copy() + + result.uniforms = dict(self.uniforms) + for key in result.uniforms: + if isinstance(result.uniforms[key], np.ndarray): + result.uniforms[key] = result.uniforms[key].copy() + + result.submobjects = [] + result.add(*(sm.copy() for sm in self.submobjects)) + result.match_updaters(self) + + family = self.get_family() + for attr, value in list(self.__dict__.items()): + if isinstance(value, Mobject) and value in family and value is not self: + setattr(result, attr, result.family[self.family.index(value)]) + if isinstance(value, np.ndarray): + setattr(result, attr, value.copy()) + if isinstance(value, ShaderWrapper): + setattr(result, attr, value.copy()) return result def deepcopy(self): - # This used to be different from copy, so is now just here for backward compatibility - return self.copy() + return self.copy(deep=True) def generate_target(self, use_deepcopy: bool = False): - # TODO, remove now pointless use_deepcopy arg - self.target = None # Prevent exponential explosion - self.target = self.copy() + self.target = self.copy(deep=use_deepcopy) return self.target def save_state(self, use_deepcopy: bool = False): - # TODO, remove now pointless use_deepcopy arg - if hasattr(self, "saved_state"): - # Prevent exponential growth of data - self.saved_state = None - self.saved_state = self.copy() + self.saved_state = self.copy(deep=use_deepcopy) return self def restore(self): @@ -540,9 +576,8 @@ class Mobject(object): # Creating new Mobjects from this one def replicate(self, n: int) -> Group: - serial = self.serialize() group_class = self.get_group_class() - return group_class(*(pickle.loads(serial) for _ in range(n))) + return group_class(*(self.copy() for _ in range(n))) def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs) -> Group: """