Vectorize all the things

This commit is contained in:
Grant Sanderson
2016-04-17 00:31:38 -07:00
parent bd3783586a
commit 0d4e928b6e
12 changed files with 406 additions and 362 deletions

View File

@ -23,9 +23,9 @@ class Mobject(object):
"dim" : 3,
"target" : None
}
def __init__(self, *sub_mobjects, **kwargs):
def __init__(self, *submobjects, **kwargs):
digest_config(self, kwargs)
self.sub_mobjects = list(sub_mobjects)
self.submobjects = list(submobjects)
self.color = Color(self.color)
if self.name is None:
self.name = self.__class__.__name__
@ -48,7 +48,7 @@ class Mobject(object):
pass
def add(self, *mobjects):
self.sub_mobjects = list_update(self.sub_mobjects, mobjects)
self.submobjects = list_update(self.submobjects, mobjects)
return self
def get_array_attrs(self):
@ -57,13 +57,13 @@ class Mobject(object):
def digest_mobject_attrs(self):
"""
Ensures all attributes which are mobjects are included
in the sub_mobjects list.
in the submobjects list.
"""
mobject_attrs = filter(
lambda x : isinstance(x, Mobject),
self.__dict__.values()
)
self.sub_mobjects = list_update(self.sub_mobjects, mobject_attrs)
self.submobjects = list_update(self.submobjects, mobject_attrs)
return self
def apply_over_attr_arrays(self, func):
@ -229,7 +229,7 @@ class Mobject(object):
return self.scale(height/self.get_height())
def replace(self, mobject, stretch = False):
if not mobject.get_num_points() and not mobject.sub_mobjects:
if not mobject.get_num_points() and not mobject.submobjects:
raise Warning("Attempting to replace mobject with no points")
return self
if stretch:
@ -302,7 +302,7 @@ class Mobject(object):
values = []
values += [
mob.reduce_across_dimension(points_func, reduce_func, dim)
for mob in self.sub_mobjects
for mob in self.submobjects
]
try:
return reduce_func(values)
@ -388,10 +388,10 @@ class Mobject(object):
def split(self):
result = [self] if len(self.points) > 0 else []
return result + self.sub_mobjects
return result + self.submobjects
def submobject_family(self):
sub_families = map(Mobject.submobject_family, self.sub_mobjects)
sub_families = map(Mobject.submobject_family, self.submobjects)
all_mobjects = [self] + reduce(op.add, sub_families, [])
return remove_list_redundancies(all_mobjects)
@ -403,10 +403,10 @@ class Mobject(object):
## Alignment
def align_data(self, mobject):
self.align_sub_mobjects(mobject)
self.align_submobjects(mobject)
self.align_points(mobject)
#Recurse
for m1, m2 in zip(self.sub_mobjects, mobject.sub_mobjects):
for m1, m2 in zip(self.submobjects, mobject.submobjects):
m1.align_data(m2)
def get_point_mobject(self, center = None):
@ -429,7 +429,7 @@ class Mobject(object):
def align_points_with_larger(self, larger_mobject):
raise Exception("Not implemented")
def align_sub_mobjects(self, mobject):
def align_submobjects(self, mobject):
#If one is empty, and the other is not,
#push it into its submobject list
self_has_points, mob_has_points = [
@ -437,30 +437,31 @@ class Mobject(object):
for mob in self, mobject
]
if self_has_points and not mob_has_points:
self.push_self_into_sub_mobjects()
self.push_self_into_submobjects()
elif mob_has_points and not self_has_points:
mob.push_self_into_sub_mobjects()
self_count = len(self.sub_mobjects)
mob_count = len(mobject.sub_mobjects)
mob.push_self_into_submobjects()
self_count = len(self.submobjects)
mob_count = len(mobject.submobjects)
diff = abs(self_count-mob_count)
if self_count < mob_count:
self.add_n_more_sub_mobjects(diff)
self.add_n_more_submobjects(diff)
elif mob_count < self_count:
mobject.add_n_more_sub_mobjects(diff)
mobject.add_n_more_submobjects(diff)
return self
def push_self_into_sub_mobjects(self):
def push_self_into_submobjects(self):
copy = self.copy()
copy.sub_mobjects = []
copy.submobjects = []
self.points = np.zeros((0, self.dim))
self.add(copy)
return self
def add_n_more_sub_mobjects(self, n):
if n > 0 and len(self.sub_mobjects) == 0:
def add_n_more_submobjects(self, n):
if n > 0 and len(self.submobjects) == 0:
self.add(self.copy())
n = n-1
for i in range(n):
self.add(self.sub_mobjects[i].copy())
self.add(self.submobjects[i].copy())
return self
def interpolate(self, mobject1, mobject2, alpha, path_func):