Replacing apply_along_axis

This commit is contained in:
Grant Sanderson
2020-02-19 23:43:33 -08:00
parent 0176bda584
commit c794039b9d
4 changed files with 6 additions and 7 deletions

View File

@ -54,7 +54,7 @@ class ParametricFunction(VMobject):
full_t_range += list(np.linspace(s1, s2, n_inserts + 1)[1:]) full_t_range += list(np.linspace(s1, s2, n_inserts + 1)[1:])
points = np.array([self.function(t) for t in full_t_range]) points = np.array([self.function(t) for t in full_t_range])
valid_indices = np.apply_along_axis(np.all, 1, np.isfinite(points)) valid_indices = np.isfinite(points).all(1)
points = points[valid_indices] points = points[valid_indices]
if len(points) > 0: if len(points) > 0:
self.start_new_path(points[0]) self.start_new_path(points[0])

View File

@ -298,7 +298,7 @@ class Mobject(Container):
if len(kwargs) == 0: if len(kwargs) == 0:
kwargs["about_point"] = ORIGIN kwargs["about_point"] = ORIGIN
self.apply_points_function_about_point( self.apply_points_function_about_point(
lambda points: np.apply_along_axis(function, 1, points), lambda points: np.array([function(p) for p in points]),
**kwargs **kwargs
) )
return self return self
@ -775,7 +775,7 @@ class Mobject(Container):
return self.get_bounding_box_point(np.zeros(self.dim)) return self.get_bounding_box_point(np.zeros(self.dim))
def get_center_of_mass(self): def get_center_of_mass(self):
return np.apply_along_axis(np.mean, 0, self.get_all_points()) return self.get_all_points().mean(0)
def get_boundary_point(self, direction): def get_boundary_point(self, direction):
all_points = self.get_points_defining_boundary() all_points = self.get_points_defining_boundary()

View File

@ -645,8 +645,8 @@ class VMobject(Mobject):
for a in np.linspace(0, 1, n_sample_points) for a in np.linspace(0, 1, n_sample_points)
]) ])
diffs = points[1:] - points[:-1] diffs = points[1:] - points[:-1]
norms = np.apply_along_axis(get_norm, 1, diffs) norms = np.array([get_norm(d) for d in diffs])
return np.sum(norms) return norms.sum()
# Alignment # Alignment
def align_points(self, vmobject): def align_points(self, vmobject):

View File

@ -84,8 +84,7 @@ def interpolate_color(color1, color2, alpha):
def average_color(*colors): def average_color(*colors):
rgbs = np.array(list(map(color_to_rgb, colors))) rgbs = np.array(list(map(color_to_rgb, colors)))
mean_rgb = np.apply_along_axis(np.mean, 0, rgbs) return rgb_to_color(rgbs.mean(0))
return rgb_to_color(mean_rgb)
def random_bright_color(): def random_bright_color():