Updates to copying based on pickle serializing

This commit is contained in:
Grant Sanderson
2022-04-21 14:32:27 -07:00
parent c04615c4e9
commit fe3e10acd2

View File

@ -462,32 +462,21 @@ class Mobject(object):
self.assemble_family()
return self
# Creating new Mobjects from this one
# Copying and serialization
def replicate(self, n: int) -> Group:
return self.get_group_class()(
*(self.copy() for x in range(n))
)
def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs):
"""
Returns a new mobject containing multiple copies of this one
arranged in a grid
"""
grid = self.replicate(n_rows * n_cols)
grid.arrange_in_grid(n_rows, n_cols, **kwargs)
if height is not None:
grid.set_height(height)
return grid
# Copying
def serialize(self):
pre, self.parents = self.parents, []
result = pickle.dumps(self)
self.parents = pre
return result
def copy(self):
self.parents = []
try:
return pickle.loads(pickle.dumps(self))
serial = self.serialize()
return pickle.loads(serial)
except AttributeError:
return copy.deepcopy(self)
return result
def deepcopy(self):
# This used to be different from copy, so is now just here for backward compatibility
@ -513,7 +502,7 @@ class Mobject(object):
self.become(self.saved_state)
return self
def save_to_file(self, file_path):
def save_to_file(self, file_path: str):
if not file_path.endswith(".mob"):
file_path += ".mob"
if os.path.exists(file_path):
@ -521,7 +510,7 @@ class Mobject(object):
if cont != "y":
return
with open(file_path, "wb") as fp:
pickle.dump(self, fp)
fp.write(self.serialize())
log.info(f"Saved mobject to {file_path}")
return self
@ -534,6 +523,41 @@ class Mobject(object):
mobject = pickle.load(fp)
return mobject
def become(self, mobject: Mobject):
"""
Edit all data and submobjects to be idential
to another mobject
"""
self.align_family(mobject)
for sm1, sm2 in zip(self.get_family(), mobject.get_family()):
sm1.set_data(sm2.data)
sm1.set_uniforms(sm2.uniforms)
sm1.shader_folder = sm2.shader_folder
sm1.texture_paths = sm2.texture_paths
sm1.depth_test = sm2.depth_test
sm1.render_primitive = sm2.render_primitive
self.refresh_shader_wrapper_id()
self.refresh_bounding_box(recurse_down=True)
return self
# Creating new Mobjects from this one
def replicate(self, n: int) -> Group:
serial = self.serialize()
group_class = self.get_group_class()
return group_class(*(pickle.loads(serial) for _ in range(n)))
def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs) -> Group:
"""
Returns a new mobject containing multiple copies of this one
arranged in a grid
"""
grid = self.replicate(n_rows * n_cols)
grid.arrange_in_grid(n_rows, n_cols, **kwargs)
if height is not None:
grid.set_height(height)
return grid
# Updating
def init_updaters(self):
@ -1521,18 +1545,6 @@ class Mobject(object):
"""
pass # To implement in subclass
def become(self, mobject: Mobject):
"""
Edit all data and submobjects to be idential
to another mobject
"""
self.align_family(mobject)
for sm1, sm2 in zip(self.get_family(), mobject.get_family()):
sm1.set_data(sm2.data)
sm1.set_uniforms(sm2.uniforms)
self.refresh_bounding_box(recurse_down=True)
return self
# Locking data
def lock_data(self, keys: Iterable[str]):