Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7e3b70f

Browse files
Simplify bezier root finding
1 parent 52cd7dd commit 7e3b70f

2 files changed

Lines changed: 25 additions & 60 deletions

File tree

lib/matplotlib/bezier.py

Lines changed: 11 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,53 +9,6 @@
99
import numpy as np
1010

1111
from matplotlib import _api
12-
from numpy.polynomial.polynomial import polyval as _polyval
13-
14-
15-
def _bisect_root_finder(f, a, b, tol=1e-12, max_iter=64):
16-
"""Find root of f in [a, b] using bisection. Assumes sign change exists."""
17-
fa = f(a)
18-
for _ in range(max_iter):
19-
mid = (a + b) * 0.5
20-
fm = f(mid)
21-
if abs(fm) < tol or (b - a) < tol:
22-
return mid
23-
if fa * fm < 0:
24-
b = mid
25-
else:
26-
a, fa = mid, fm
27-
return (a + b) * 0.5
28-
29-
30-
def _bisected_roots_in_01(coeffs):
31-
"""
32-
Find real roots of polynomial in [0, 1] using sampling and bisection.
33-
coeffs in ascending order: c0 + c1*x + c2*x**2 + ...
34-
"""
35-
deg = len(coeffs) - 1
36-
n_samples = max(8, deg * 2)
37-
ts = np.linspace(0, 1, n_samples)
38-
vals = _polyval(ts, coeffs)
39-
40-
signs = np.sign(vals)
41-
sign_changes = np.where((signs[:-1] != signs[1:]) & (signs[:-1] != 0))[0]
42-
43-
roots = []
44-
45-
def f(t):
46-
return _polyval(t, coeffs)
47-
48-
max_iter = 53 # float64 fractional precision for [0, 1] interval
49-
for i in sign_changes:
50-
roots.append(_bisect_root_finder(f, ts[i], ts[i + 1], max_iter=max_iter))
51-
52-
# Check endpoints
53-
if abs(vals[0]) < 1e-12:
54-
roots.insert(0, 0.0)
55-
if abs(vals[-1]) < 1e-12 and (not roots or abs(roots[-1] - 1.0) > 1e-10):
56-
roots.append(1.0)
57-
58-
return np.asarray(roots)
5912

6013

6114
def _quadratic_roots_in_01(c0, c1, c2):
@@ -91,10 +44,10 @@ def _real_roots_in_01(coeffs):
9144
"""
9245
Find real roots of a polynomial in the interval [0, 1].
9346
94-
This is optimized for finding roots only in [0, 1], which is faster than
95-
computing all roots with `numpy.roots` and filtering. For polynomials of
96-
degree <= 2, closed-form solutions are used. For higher degrees, sampling
97-
and bisection are used.
47+
For polynomials of degree <= 2, closed-form solutions are used.
48+
For higher degrees, `numpy.roots` is used as a fallback. In practice,
49+
matplotlib only ever uses cubic bezier curves and axis_aligned_extrema()
50+
differentiates, so we only use find roots for degree <= 2.
9851
9952
Parameters
10053
----------
@@ -123,7 +76,13 @@ def _real_roots_in_01(coeffs):
12376
elif deg == 2:
12477
roots = _quadratic_roots_in_01(coeffs[0], coeffs[1], coeffs[2])
12578
else:
126-
roots = _bisected_roots_in_01(coeffs[:deg + 1])
79+
# np.roots expects descending order (highest power first)
80+
eps = 1e-10
81+
all_roots = np.roots(coeffs[deg::-1])
82+
real_mask = np.abs(all_roots.imag) < eps
83+
real_roots = all_roots[real_mask].real
84+
in_range = (real_roots >= -eps) & (real_roots <= 1 + eps)
85+
roots = np.clip(real_roots[in_range], 0, 1)
12786

12887
return np.sort(roots)
12988

lib/matplotlib/tests/test_bezier.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,26 @@ def test_real_roots_in_01_known_cases(coeffs, expected):
3939
"""Test _real_roots_in_01 against known values and np.roots reference."""
4040
result = _real_roots_in_01(coeffs)
4141
np_expected = _np_real_roots_in_01(coeffs)
42-
assert_allclose(result, expected, atol=1e-10)
43-
assert_allclose(result, np_expected, atol=1e-10)
42+
assert_allclose(result, expected, atol=1e-8)
43+
assert len(result) == len(np_expected)
44+
if len(result) > 0:
45+
assert_allclose(result, np_expected, atol=1e-4)
4446

4547

4648
@pytest.mark.parametrize("degree", range(1, 11))
4749
def test_real_roots_in_01_random(degree):
4850
"""Test random polynomials against np.roots."""
4951
rng = np.random.default_rng(seed=0)
50-
coeffs = rng.uniform(-10, 10, size=degree + 1)
51-
result = _real_roots_in_01(coeffs)
52-
expected = _np_real_roots_in_01(coeffs)
53-
assert len(result) == len(expected)
54-
if len(result) > 0:
55-
assert_allclose(result, expected, atol=1e-8)
52+
for _ in range(50):
53+
coeffs = rng.uniform(-10, 10, size=degree + 1)
54+
result = _real_roots_in_01(coeffs)
55+
expected = _np_real_roots_in_01(coeffs)
56+
assert len(result) == len(expected), (
57+
f"degree={degree}, coeffs={coeffs}: "
58+
f"got {result}, expected {expected}"
59+
)
60+
if len(result) > 0:
61+
assert_allclose(result, expected, atol=1e-8)
5662

5763

5864
def test_split_bezier_with_large_values():

0 commit comments

Comments
 (0)