|
9 | 9 | import numpy as np |
10 | 10 |
|
11 | 11 | from matplotlib import _api |
12 | | -from numpy.polynomial.polynomial import polyval as _polyval |
13 | | - |
14 | | - |
15 | | -def _bisect_root_finder(f, a, b, tol=1e-12, max_iter=64): |
16 | | - """Find root of f in [a, b] using bisection. Assumes sign change exists.""" |
17 | | - fa = f(a) |
18 | | - for _ in range(max_iter): |
19 | | - mid = (a + b) * 0.5 |
20 | | - fm = f(mid) |
21 | | - if abs(fm) < tol or (b - a) < tol: |
22 | | - return mid |
23 | | - if fa * fm < 0: |
24 | | - b = mid |
25 | | - else: |
26 | | - a, fa = mid, fm |
27 | | - return (a + b) * 0.5 |
28 | | - |
29 | | - |
30 | | -def _bisected_roots_in_01(coeffs): |
31 | | - """ |
32 | | - Find real roots of polynomial in [0, 1] using sampling and bisection. |
33 | | - coeffs in ascending order: c0 + c1*x + c2*x**2 + ... |
34 | | - """ |
35 | | - deg = len(coeffs) - 1 |
36 | | - n_samples = max(8, deg * 2) |
37 | | - ts = np.linspace(0, 1, n_samples) |
38 | | - vals = _polyval(ts, coeffs) |
39 | | - |
40 | | - signs = np.sign(vals) |
41 | | - sign_changes = np.where((signs[:-1] != signs[1:]) & (signs[:-1] != 0))[0] |
42 | | - |
43 | | - roots = [] |
44 | | - |
45 | | - def f(t): |
46 | | - return _polyval(t, coeffs) |
47 | | - |
48 | | - max_iter = 53 # float64 fractional precision for [0, 1] interval |
49 | | - for i in sign_changes: |
50 | | - roots.append(_bisect_root_finder(f, ts[i], ts[i + 1], max_iter=max_iter)) |
51 | | - |
52 | | - # Check endpoints |
53 | | - if abs(vals[0]) < 1e-12: |
54 | | - roots.insert(0, 0.0) |
55 | | - if abs(vals[-1]) < 1e-12 and (not roots or abs(roots[-1] - 1.0) > 1e-10): |
56 | | - roots.append(1.0) |
57 | | - |
58 | | - return np.asarray(roots) |
59 | 12 |
|
60 | 13 |
|
61 | 14 | def _quadratic_roots_in_01(c0, c1, c2): |
@@ -91,10 +44,10 @@ def _real_roots_in_01(coeffs): |
91 | 44 | """ |
92 | 45 | Find real roots of a polynomial in the interval [0, 1]. |
93 | 46 |
|
94 | | - This is optimized for finding roots only in [0, 1], which is faster than |
95 | | - computing all roots with `numpy.roots` and filtering. For polynomials of |
96 | | - degree <= 2, closed-form solutions are used. For higher degrees, sampling |
97 | | - and bisection are used. |
| 47 | + For polynomials of degree <= 2, closed-form solutions are used. |
| 48 | + For higher degrees, `numpy.roots` is used as a fallback. In practice, |
| 49 | + matplotlib only ever uses cubic bezier curves and axis_aligned_extrema() |
| 50 | + differentiates, so we only use find roots for degree <= 2. |
98 | 51 |
|
99 | 52 | Parameters |
100 | 53 | ---------- |
@@ -123,7 +76,13 @@ def _real_roots_in_01(coeffs): |
123 | 76 | elif deg == 2: |
124 | 77 | roots = _quadratic_roots_in_01(coeffs[0], coeffs[1], coeffs[2]) |
125 | 78 | else: |
126 | | - roots = _bisected_roots_in_01(coeffs[:deg + 1]) |
| 79 | + # np.roots expects descending order (highest power first) |
| 80 | + eps = 1e-10 |
| 81 | + all_roots = np.roots(coeffs[deg::-1]) |
| 82 | + real_mask = np.abs(all_roots.imag) < eps |
| 83 | + real_roots = all_roots[real_mask].real |
| 84 | + in_range = (real_roots >= -eps) & (real_roots <= 1 + eps) |
| 85 | + roots = np.clip(real_roots[in_range], 0, 1) |
127 | 86 |
|
128 | 87 | return np.sort(roots) |
129 | 88 |
|
|
0 commit comments