Starting to vectorize

This commit is contained in:
Grant Sanderson
2016-04-09 20:03:57 -07:00
parent 8414022b81
commit 330b8870ba
16 changed files with 481 additions and 5304 deletions

View File

@ -8,76 +8,51 @@ from colour import Color
from helpers import *
#TODO: Explain array_attrs
class Mobject(object):
"""
Mathematical Object
"""
#Number of numbers used to describe a point (3 for pos, 3 for normal vector)
CONFIG = {
"color" : WHITE,
"point_thickness" : DEFAULT_POINT_THICKNESS,
"name" : None,
"display_mode" : "points"
"display_mode" : "points", #TODO, REMOVE
"dim" : 3,
}
DIM = 3
def __init__(self, *sub_mobjects, **kwargs):
digest_config(self, kwargs)
self.sub_mobjects = list(sub_mobjects)
self.color = Color(self.color)
if self.name is None:
self.name = self.__class__.__name__
self.has_normals = hasattr(self, 'unit_normal')
self.init_points()
self.init_colors()
self.generate_points()
if self.has_normals:
self.unit_normals = np.apply_along_axis(
self.unit_normal,
1,
self.points,
)
def __str__(self):
return self.name
def init_points(self):
for attr in self.get_array_attrs():
setattr(self, attr, np.zeros((0, 3)))
self.points = np.zeros((0, self.dim))
def init_colors(self):
#For subclasses
pass
def generate_points(self):
#Typically implemented in subclass, unless purposefully left blank
pass
def add_points(self, points, rgbs = None, color = None):
"""
points must be a Nx3 numpy array, as must rgbs if it is not None
"""
if not isinstance(points, np.ndarray):
points = np.array(points)
num_new_points = points.shape[0]
self.points = np.append(self.points, points, axis = 0)
if rgbs is None:
color = Color(color) if color else self.color
rgbs = np.array([color.get_rgb()] * num_new_points)
elif rgbs.shape != points.shape:
raise Exception("points and rgbs must have same shape")
self.rgbs = np.append(self.rgbs, rgbs, axis = 0)
if self.has_normals:
self.unit_normals = np.append(
self.unit_normals,
np.apply_along_axis(self.unit_normal, 1, points),
axis = 0
)
return self
def add(self, *mobjects):
self.sub_mobjects = list_update(self.sub_mobjects, mobjects)
return self
def get_array_attrs(self):
result = ["points", "rgbs"]
if self.has_normals:
result.append("unit_normals")
return result
return ["points"]
def digest_mobject_attrs(self):
"""
@ -91,33 +66,40 @@ class Mobject(object):
self.sub_mobjects = list_update(self.sub_mobjects, mobject_attrs)
return self
def apply_over_attr_arrays(self, func):
for attr in self.get_array_attrs():
setattr(self, attr, func(getattr(self, attr)))
return self
def show(self):
def get_image(self):
from camera import Camera
camera = Camera()
camera.capture_mobject(self)
Image.fromarray(camera.get_image()).show()
return Image.fromarray(camera.get_image())
def show(self):
self.get_image().show()
def save_image(self, name = None):
Image.fromarray(disp.paint_mobject(self)).save(
self.get_image().save(
os.path.join(MOVIE_DIR, (name or str(self)) + ".png")
)
def copy(self):
return deepcopy(self)
#### Fundamental operations ######
#### Transforming operations ######
def apply_to_family(self, func):
for mob in self.nonempty_family_members():
func(mob)
def shift(self, *vectors):
total_vector = reduce(op.add, vectors)
for mob in self.nonempty_family_members():
mob.points += total_vector
return self
mob.points += total_vector
return self
def scale(self, scale_factor):
for mob in self.nonempty_family_members():
@ -133,8 +115,6 @@ class Mobject(object):
t_rot_matrix = np.transpose(rot_matrix)
for mob in self.nonempty_family_members():
mob.points = np.dot(mob.points, t_rot_matrix)
if mob.has_normals:
mob.unit_normals = np.dot(mob.unit_normals, t_rot_matrix)
return self
def stretch(self, factor, dim):
@ -159,68 +139,6 @@ class Mobject(object):
)
return self
def highlight(self, color = YELLOW_C, condition = None):
"""
Condition is function which takes in one arguments, (x, y, z).
"""
rgb = Color(color).get_rgb()
for mob in self.nonempty_family_members():
if condition:
to_change = np.apply_along_axis(condition, 1, mob.points)
mob.rgbs[to_change, :] = rgb
else:
mob.rgbs[:,:] = rgb
return self
def gradient_highlight(self, start_color, end_color):
start_rgb, end_rgb = [
np.array(Color(color).get_rgb())
for color in start_color, end_color
]
for mob in self.nonempty_family_members():
num_points = mob.get_num_points()
mob.rgbs = np.array([
interpolate(start_rgb, end_rgb, alpha)
for alpha in np.arange(num_points)/float(num_points)
])
return self
def match_colors(self, mobject):
Mobject.align_data(self, mobject)
self.rgbs = np.array(mobject.rgbs)
return self
def filter_out(self, condition):
for mob in self.nonempty_family_members():
to_eliminate = ~np.apply_along_axis(condition, 1, mob.points)
mob.points = mob.points[to_eliminate]
mob.rgbs = mob.rgbs[to_eliminate]
return self
def thin_out(self, factor = 5):
"""
Removes all but every nth point for n = factor
"""
for mob in self.nonempty_family_members():
num_points = self.get_num_points()
mob.apply_over_attr_arrays(
lambda arr : arr[
np.arange(0, num_points, factor)
]
)
return self
def sort_points(self, function = lambda p : p[0]):
"""
function is any map from R^3 to R
"""
for mob in self.nonempty_family_members():
indices = np.argsort(
np.apply_along_axis(function, 1, mob.points)
)
mob.apply_over_attr_arrays(lambda arr : arr[indices])
return self
def reverse_points(self):
for mob in self.nonempty_family_members():
mob.apply_over_attr_arrays(
@ -228,7 +146,6 @@ class Mobject(object):
)
return self
def repeat(self, count):
"""
This can make transition animations nicer
@ -336,11 +253,16 @@ class Mobject(object):
self.shift(start-self.points[0])
return self
## Color functions
def apply_complex_function(self, function):
return self.apply_function(
lambda (x, y, z) : complex_to_R3(function(complex(x, y)))
)
def highlight(self, color = YELLOW_C, condition = None):
"""
Condition is function which takes in one arguments, (x, y, z).
"""
raise Exception("Not implemented")
def gradient_highlight(self, start_color, end_color):
raise Exception("Not implemented")
def set_color(self, color):
self.highlight(color)
@ -352,15 +274,27 @@ class Mobject(object):
return self
def fade_to(self, color, alpha):
self.rgbs = interpolate(self.rgbs, np.array(Color(color).rgb), alpha)
for mob in self.sub_mobjects:
mob.fade_to(color, alpha)
start = color_to_rgb(self.get_color())
end = color_to_rgb(color)
new_rgb = interpolate(start, end, alpha)
for mob in self.nonempty_family_members():
mob.highlight(Color(rgb = new_rgb))
return self
def fade(self, darkness = 0.5):
self.fade_to(BLACK, darkness)
return self
def get_color(self):
return self.color
##
def apply_complex_function(self, function):
return self.apply_function(
lambda (x, y, z) : complex_to_R3(function(complex(x, y)))
)
def reduce_across_dimension(self, points_func, reduce_func, dim):
try:
values = [points_func(self.points[:, dim])]
@ -384,32 +318,6 @@ class Mobject(object):
def get_all_points(self):
return self.get_merged_array("points")
def get_all_rgbs(self):
return self.get_merged_array("rgbs")
def ingest_sub_mobjects(self):
attrs = self.get_array_attrs()
arrays = map(self.get_merged_array, attrs)
for attr, array in zip(attrs, arrays):
setattr(self, attr, array)
self.sub_mobjects = []
return self
def split(self):
result = [self] if len(self.points) > 0 else []
return result + self.sub_mobjects
def submobject_family(self):
sub_families = map(Mobject.submobject_family, self.sub_mobjects)
all_mobjects = [self] + reduce(op.add, sub_families, [])
return remove_list_redundancies(all_mobjects)
def nonempty_family_members(self):
return filter(
lambda m : m.get_num_points() > 0,
self.submobject_family()
)
### Getters ###
def get_num_points(self, including_submobjects = False):
@ -476,13 +384,27 @@ class Mobject(object):
return self.length_over_dim(1)
def point_from_proportion(self, alpha):
index = alpha*(self.get_num_points()-1)
return self.points[index]
raise Exception("Not implemented")
def get_color(self):
color = Color()
color.set_rgb(self.rgbs[0, :])
return color
## Family matters
def split(self):
result = [self] if len(self.points) > 0 else []
return result + self.sub_mobjects
def submobject_family(self):
sub_families = map(Mobject.submobject_family, self.sub_mobjects)
all_mobjects = [self] + reduce(op.add, sub_families, [])
return remove_list_redundancies(all_mobjects)
def nonempty_family_members(self):
return filter(
lambda m : m.get_num_points() > 0,
self.submobject_family()
)
## Alignment
@staticmethod
def align_data(mobject1, mobject2):
@ -527,52 +449,7 @@ class Mobject(object):
setattr(self, attr, interpolate(
getattr(mobject1, attr),
getattr(mobject2, attr),
alpha))
class Point(Mobject):
CONFIG = {
"color" : BLACK,
}
def __init__(self, location = ORIGIN, **kwargs):
digest_locals(self)
Mobject.__init__(self, **kwargs)
def generate_points(self):
self.add_points([self.location])
#TODO, Make the two implementations bellow non-redundant
class Mobject1D(Mobject):
CONFIG = {
"density" : DEFAULT_POINT_DENSITY_1D,
}
def __init__(self, **kwargs):
digest_config(self, kwargs)
self.epsilon = 1.0 / self.density
Mobject.__init__(self, **kwargs)
def add_line(self, start, end, color = None):
start, end = map(np.array, [start, end])
length = np.linalg.norm(end - start)
if length == 0:
points = [start]
else:
epsilon = self.epsilon/length
points = [
interpolate(start, end, t)
for t in np.arange(0, 1, epsilon)
]
self.add_points(points, color = color)
class Mobject2D(Mobject):
CONFIG = {
"density" : DEFAULT_POINT_DENSITY_2D,
}
def __init__(self, **kwargs):
digest_config(self, kwargs)
self.epsilon = 1.0 / self.density
Mobject.__init__(self, **kwargs)
alpha))