Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Prev Previous commit
Next Next commit
Speed up computation when f is expensive.
Now we call axes.plot with stored values of f instead of calling f on x.
This will improve performance when f is an expensive function.
  • Loading branch information
dmcdougall committed Mar 29, 2012
commit 04bb1a0d233e1a6fd83393d2abd5b3e3a03ca61f
31 changes: 18 additions & 13 deletions lib/matplotlib/fplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def fplot(axes, f, limits, *args, **kwargs):
n = kwargs.pop('tol', 50)

x = np.linspace(limits[0], limits[1], n)
f_vals = [f(xi) for xi in x]

# Bisect abscissa until the gradient error changes by less than tol
within_tol = False

while not within_tol:
within_tol = True
new_pts = []
new_f = []
for i in xrange(len(x)-1):
# Make sure the step size is not pointlessly small.
# This is a numerical check to prevent silly roundoff errors.
Expand All @@ -70,14 +72,13 @@ def fplot(axes, f, limits, *args, **kwargs):
# If the function values are too close, the payoff is
# negligible, so skip them.
f_new = f(x_new) # Used later, so store it
f_i = f(x[i]) # Used later, so store it
if abs(x_new - x[i]) < min_step or abs(f_new - f_i) < min_step:
if abs(x_new - x[i]) < min_step or abs(f_new - f_vals[i]) < min_step:
continue

# Compare gradients of actual f and linear approximation
# FIXME: What if f(x[i]) or f(x[i+1]) is nan?
dx = abs(x[i+1] - x[i])
f_interp = (f(x[i+1]) + f_i)
f_interp = (f_vals[i+1] + f_vals[i])

# This line is the absolute error of the gradient
grad_error = np.abs(f_interp - 2.0 * f_new) / dx
Expand All @@ -87,26 +88,30 @@ def fplot(axes, f, limits, *args, **kwargs):
if grad_error > tol:
within_tol = False
new_pts.append(x_new)
new_f.append(f_new)

if not within_tol:
# Not sure this is the best way to do this...
# Merge the subdivision points into the array of abscissae
x = merge_pts(x, new_pts)
x, f_vals = merge_pts(x, new_pts, f_vals, new_f)

return axes.plot(x, f(x))
return axes.plot(x, f_vals)

def merge_pts(a, b):
def merge_pts(xs, xs_sub, fs, fs_sub):
x = []
f = []
ia = 0
ib = 0
while ib < len(b):
if b[ib] < a[ia]:
x.append(b[ib])
while ib < len(xs_sub):
if xs_sub[ib] < xs[ia]:
x.append(xs_sub[ib])
f.append(fs_sub[ib])
ib += 1
else:
x.append(a[ia])
x.append(xs[ia])
f.append(fs[ia])
ia += 1
if ia < len(a):
return np.append(x, a[ia::])
if ia < len(xs):
return np.append(x, xs[ia::]), np.append(f, fs[ia::])
else:
return np.array(x)
return np.array(x), np.array(f)