Rewrote ParametricFunction to have less buggy interpolation

This commit is contained in:
Grant Sanderson
2019-02-06 15:18:11 -08:00
parent 16e8a76c6a
commit 47f6d6ba38
8 changed files with 58 additions and 29 deletions

View File

@ -31,7 +31,6 @@ class Axes(VGroup):
"x_max": FRAME_X_RADIUS, "x_max": FRAME_X_RADIUS,
"y_min": -FRAME_Y_RADIUS, "y_min": -FRAME_Y_RADIUS,
"y_max": FRAME_Y_RADIUS, "y_max": FRAME_Y_RADIUS,
"default_num_graph_points": 100,
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -68,14 +67,12 @@ class Axes(VGroup):
]) ])
def get_graph( def get_graph(
self, function, num_graph_points=None, self, function,
x_min=None, x_min=None,
x_max=None, x_max=None,
**kwargs **kwargs
): ):
kwargs["fill_opacity"] = kwargs.get("fill_opacity", 0) kwargs["fill_opacity"] = kwargs.get("fill_opacity", 0)
kwargs["num_anchor_points"] = \
num_graph_points or self.default_num_graph_points
x_min = x_min or self.x_min x_min = x_min or self.x_min
x_max = x_max or self.x_max x_max = x_max or self.x_max
graph = ParametricFunction( graph = ParametricFunction(

View File

@ -7,22 +7,54 @@ class ParametricFunction(VMobject):
CONFIG = { CONFIG = {
"t_min": 0, "t_min": 0,
"t_max": 1, "t_max": 1,
"num_anchor_points": 100, # TODO, be smarter about choosing this number
"step_size": 0.01,
"dt": 1e-8,
# TODO, be smar about figuring these out?
"discontinuities": [],
} }
def __init__(self, function, **kwargs): def __init__(self, function, **kwargs):
self.function = function self.function = function
VMobject.__init__(self, **kwargs) VMobject.__init__(self, **kwargs)
def get_function(self):
return self.function
def get_point_from_function(self, t):
return self.function(t)
def generate_points(self): def generate_points(self):
n_points = 3 * self.num_anchor_points - 2 t_min, t_max = self.t_min, self.t_max
self.points = np.zeros((n_points, self.dim)) dt = self.dt
self.points[:, 0] = np.linspace( step_size = self.step_size
self.t_min, self.t_max, n_points
discontinuities = filter(
lambda t: t_min <= t <= t_max,
self.discontinuities
) )
# VMobject.apply_function takes care of preserving discontinuities = np.array(list(discontinuities))
# desirable tangent line properties at anchor points boundary_times = [
self.apply_function(lambda p: self.function(p[0])) self.t_min, self.t_max,
*(discontinuities - dt),
*(discontinuities + dt),
]
boundary_times.sort()
print(boundary_times)
for t1, t2 in zip(boundary_times[0::2], boundary_times[1::2]):
t_range = list(np.arange(t1, t2, step_size))
if t_range[-1] != t2:
t_range.append(t2)
points = np.array([self.function(t) for t in t_range])
valid_indices = np.apply_along_axis(
np.all, 1, np.isfinite(points)
)
points = points[valid_indices]
if len(points) > 0:
self.start_new_path(points[0])
self.add_points_as_corners(points[1:])
self.make_smooth()
return self
class FunctionGraph(ParametricFunction): class FunctionGraph(ParametricFunction):
@ -34,12 +66,11 @@ class FunctionGraph(ParametricFunction):
def __init__(self, function, **kwargs): def __init__(self, function, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
self.parametric_function = \
def parametric_function(t): lambda t: np.array([t, function(t), 0])
return t * RIGHT + function(t) * UP
ParametricFunction.__init__( ParametricFunction.__init__(
self, self,
parametric_function, self.parametric_function,
t_min=self.x_min, t_min=self.x_min,
t_max=self.x_max, t_max=self.x_max,
**kwargs **kwargs
@ -48,3 +79,6 @@ class FunctionGraph(ParametricFunction):
def get_function(self): def get_function(self):
return self.function return self.function
def get_point_from_function(self, x):
return self.parametric_function(x)

View File

@ -485,11 +485,12 @@ class VMobject(Mobject):
return points return points
def set_points_as_corners(self, points): def set_points_as_corners(self, points):
if len(points) == 0: nppcc = self.n_points_per_cubic_curve
return # TODO, raise warning? points = np.array(points)
self.clear_points() self.set_anchors_and_handles(*[
self.start_new_path(points[0]) interpolate(points[:-1], points[1:], a)
self.add_points_as_corners(points[1:]) for a in np.linspace(0, 1, nppcc)
])
return self return self
def set_points_smoothly(self, points): def set_points_smoothly(self, points):

View File

@ -44,7 +44,6 @@ class GraphScene(Scene):
"axes_color": GREY, "axes_color": GREY,
"graph_origin": 2.5 * DOWN + 4 * LEFT, "graph_origin": 2.5 * DOWN + 4 * LEFT,
"exclude_zero_label": True, "exclude_zero_label": True,
"num_graph_anchor_points": 25,
"default_graph_colors": [BLUE, GREEN, YELLOW], "default_graph_colors": [BLUE, GREEN, YELLOW],
"default_derivative_color": GREEN, "default_derivative_color": GREEN,
"default_input_color": YELLOW, "default_input_color": YELLOW,
@ -149,6 +148,7 @@ class GraphScene(Scene):
color=None, color=None,
x_min=None, x_min=None,
x_max=None, x_max=None,
**kwargs
): ):
if color is None: if color is None:
color = next(self.default_graph_colors_cycle) color = next(self.default_graph_colors_cycle)
@ -167,7 +167,7 @@ class GraphScene(Scene):
graph = ParametricFunction( graph = ParametricFunction(
parameterized_function, parameterized_function,
color=color, color=color,
num_anchor_points=self.num_graph_anchor_points, **kwargs
) )
graph.underlying_function = func graph.underlying_function = func
return graph return graph

View File

@ -97,7 +97,6 @@ class ContrastAbstractAndConcrete(Scene):
ParametricFunction( ParametricFunction(
lambda t : (t/denom)*RIGHT+np.sin(t)*UP+np.cos(t)*OUT, lambda t : (t/denom)*RIGHT+np.sin(t)*UP+np.cos(t)*OUT,
t_max = 12*np.pi, t_max = 12*np.pi,
num_anchor_points = 100,
) )
for denom in (12.0, 4.0) for denom in (12.0, 4.0)
] ]

View File

@ -208,7 +208,6 @@ class DampenedSpring(Scene):
ParametricFunction( ParametricFunction(
lambda t : (t/denom)*RIGHT+np.sin(t)*UP+np.cos(t)*OUT, lambda t : (t/denom)*RIGHT+np.sin(t)*UP+np.cos(t)*OUT,
t_max = 12*np.pi, t_max = 12*np.pi,
num_anchor_points = 100,
color = GREY, color = GREY,
).shift(3*LEFT) ).shift(3*LEFT)
for denom in (12.0, 2.0) for denom in (12.0, 2.0)

View File

@ -3911,7 +3911,6 @@ class BoundsAtInfinity(SummarizeFormula):
number_line_config = { number_line_config = {
"include_tip" : False, "include_tip" : False,
}, },
default_num_graph_points = 1000,
) )
axes.x_axis.add_numbers(*list(filter( axes.x_axis.add_numbers(*list(filter(
lambda x : x != 0, lambda x : x != 0,

View File

@ -525,7 +525,7 @@ class ShowPlan(PiCreatureScene):
wave = FunctionGraph( wave = FunctionGraph(
lambda x : 0.3*np.sin(15*x)*np.sin(0.5*x), lambda x : 0.3*np.sin(15*x)*np.sin(0.5*x),
x_min = 0, x_max = 30, x_min = 0, x_max = 30,
num_anchor_points = 500, step_size = 0.001,
) )
wave.next_to(word, RIGHT) wave.next_to(word, RIGHT)
rect = BackgroundRectangle(wave, fill_opacity = 1) rect = BackgroundRectangle(wave, fill_opacity = 1)
@ -965,7 +965,7 @@ class VariousMusicalNotes(Scene):
a = graph_width_tracker.get_value() a = graph_width_tracker.get_value()
return FunctionGraph( return FunctionGraph(
lambda x : np.exp(-a*x**2)*np.sin(freq*x)-0.5, lambda x : np.exp(-a*x**2)*np.sin(freq*x)-0.5,
num_anchor_points = 500, step_size = 0.001,
) )
graph = get_graph() graph = get_graph()
def graph_update(graph): def graph_update(graph):
@ -1044,7 +1044,7 @@ class VariousMusicalNotes(Scene):
lambda x : 0.5*np.sin(freq*x), lambda x : 0.5*np.sin(freq*x),
x_min = -FRAME_WIDTH, x_min = -FRAME_WIDTH,
x_max = FRAME_WIDTH, x_max = FRAME_WIDTH,
num_anchor_points = 1000 n_components = 0.001
) )
long_graph.set_color(BLUE) long_graph.set_color(BLUE)
long_graph.next_to(graph, UP, MED_LARGE_BUFF) long_graph.next_to(graph, UP, MED_LARGE_BUFF)