Interpolation for vectorized mobjects implemented

This commit is contained in:
Grant Sanderson
2016-04-10 12:34:28 -07:00
parent 330b8870ba
commit 26c5aa8e67
12 changed files with 210 additions and 159 deletions

View File

@ -18,7 +18,7 @@ class Mobject(object):
#Number of numbers used to describe a point (3 for pos, 3 for normal vector)
CONFIG = {
"color" : WHITE,
"point_thickness" : DEFAULT_POINT_THICKNESS,
"stroke_width" : DEFAULT_POINT_THICKNESS,
"name" : None,
"display_mode" : "points", #TODO, REMOVE
"dim" : 3,
@ -109,7 +109,7 @@ class Mobject(object):
def rotate(self, angle, axis = OUT, axes = []):
if len(axes) == 0:
axes = [axis]
rot_matrix = np.identity(self.DIM)
rot_matrix = np.identity(self.dim)
for axis in axes:
rot_matrix = np.dot(rot_matrix, rotation_matrix(angle, axis))
t_rot_matrix = np.transpose(rot_matrix)
@ -135,7 +135,7 @@ class Mobject(object):
alphas = alphas**wag_factor
mob.points += np.dot(
alphas.reshape((len(alphas), 1)),
np.array(direction).reshape((1, mob.DIM))
np.array(direction).reshape((1, mob.dim))
)
return self
@ -310,7 +310,7 @@ class Mobject(object):
return 0
def get_merged_array(self, array_attr):
result = np.zeros((0, self.DIM))
result = np.zeros((0, self.dim))
for mob in self.nonempty_family_members():
result = np.append(result, getattr(mob, array_attr), 0)
return result
@ -327,7 +327,7 @@ class Mobject(object):
return len(self.points)
def get_critical_point(self, direction):
result = np.zeros(self.DIM)
result = np.zeros(self.dim)
for dim in [0, 1]:
if direction[dim] <= 0:
min_point = self.reduce_across_dimension(np.min, np.min, dim)
@ -350,7 +350,7 @@ class Mobject(object):
return self.get_critical_point(direction)
def get_center(self):
return self.get_critical_point(np.zeros(self.DIM))
return self.get_critical_point(np.zeros(self.dim))
def get_center_of_mass(self):
return np.apply_along_axis(np.mean, 0, self.get_all_points())
@ -404,52 +404,52 @@ class Mobject(object):
self.submobject_family()
)
## Alignment
@staticmethod
def align_data(mobject1, mobject2):
count1 = len(mobject1.points)
count2 = len(mobject2.points)
if count1 != count2:
if count1 < count2:
smaller = mobject1
target_size = count2
else:
smaller = mobject2
target_size = count1
if len(smaller.points) == 0:
smaller.add_points(
[np.zeros(smaller.DIM)],
color = BLACK
)
smaller.apply_over_attr_arrays(
lambda a : streth_array_to_length(a, target_size)
)
## Alignment
def align_data(self, mobject):
self.align_points(mobject)
#Recurse
diff = len(mobject1.sub_mobjects) - len(mobject2.sub_mobjects)
if diff < 0:
larger, smaller = mobject2, mobject1
elif diff > 0:
larger, smaller = mobject1, mobject2
if diff != 0:
diff = len(self.sub_mobjects) - len(mobject.sub_mobjects)
if diff != 0:
if diff < 0:
larger, smaller = mobject, self
elif diff > 0:
larger, smaller = self, mobject
for sub_mob in larger.sub_mobjects[-abs(diff):]:
smaller.add(Point(sub_mob.get_center()))
for m1, m2 in zip(mobject1.sub_mobjects, mobject2.sub_mobjects):
Mobject.align_data(m1, m2)
smaller.add(sub_mob.get_point_mobject())
for m1, m2 in zip(self.sub_mobjects, mobject.sub_mobjects):
m1.align_data(m2)
def interpolate(self, mobject1, mobject2, alpha):
def get_point_mobject(self):
"""
The simplest mobject to be transformed to or from self.
Should by a point of the appropriate type
"""
raise Exception("Not implemented")
def align_points(self, mobject):
count1 = self.get_num_points()
count2 = mobject.get_num_points()
if count1 < count2:
self.align_points_with_larger(mobject)
elif count2 < count1:
mobject.align_points_with_larger(self)
return self
def align_points_with_larger(self, larger_mobject):
raise Exception("Not implemented")
def interpolate(self, mobject1, mobject2, alpha, path_func):
"""
Turns target_mobject into an interpolation between mobject1
and mobject2.
"""
#TODO
Mobject.align_data(mobject1, mobject2)
for attr in self.get_array_attrs():
setattr(self, attr, interpolate(
getattr(mobject1, attr),
getattr(mobject2, attr),
alpha))
self.points = path_func(
mobject1.points, mobject2.points, alpha
)
self.interpolate_color(mobject1, mobject2, alpha)
def interpolate_color(self, mobject1, mobject2, alpha):
raise Exception("Not implemented")