|
11 | 11 | ) |
12 | 12 |
|
13 | 13 |
|
14 | | -def _np_real_roots_in_01(coeffs): |
15 | | - """Reference implementation using np.roots for comparison.""" |
16 | | - coeffs = np.asarray(coeffs) |
17 | | - # np.roots expects descending order (highest power first) |
18 | | - all_roots = np.roots(coeffs[::-1]) |
19 | | - # Filter to real roots in [0, 1] |
20 | | - real_mask = np.abs(all_roots.imag) < 1e-10 |
21 | | - real_roots = all_roots[real_mask].real |
22 | | - in_range = (real_roots >= -1e-10) & (real_roots <= 1 + 1e-10) |
23 | | - return np.sort(np.clip(real_roots[in_range], 0, 1)) |
24 | | - |
25 | | - |
26 | | -@pytest.mark.parametrize("coeffs, expected", [ |
27 | | - ([-0.5, 1], [0.5]), |
28 | | - ([-2, 1], []), # roots: [2.0], not in [0, 1] |
29 | | - ([0.1875, -1, 1], [0.25, 0.75]), |
30 | | - ([1, -2.5, 1], [0.5]), # roots: [0.5, 2.0], only one in [0, 1] |
31 | | - ([1, 0, 1], []), # roots: [+-i], not real |
32 | | - ([-0.08, 0.66, -1.5, 1], [0.2, 0.5, 0.8]), |
33 | | - ([5], []), |
34 | | - ([0, 0, 0], []), |
35 | | - ([0, -0.5, 1], [0.0, 0.5]), |
36 | | - ([0.5, -1.5, 1], [0.5, 1.0]), |
| 14 | +@pytest.mark.parametrize("roots, expected_in_01", [ |
| 15 | + ([0.5], [0.5]), |
| 16 | + ([0.25, 0.75], [0.25, 0.75]), |
| 17 | + ([0.2, 0.5, 0.8], [0.2, 0.5, 0.8]), |
| 18 | + ([0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]), |
| 19 | + ([0.0, 0.5], [0.0, 0.5]), |
| 20 | + ([0.5, 1.0], [0.5, 1.0]), |
| 21 | + ([2.0], []), # outside [0, 1] |
| 22 | + ([0.5, 2.0], [0.5]), # one in, one out |
| 23 | + ([-1j, 1j], []), # complex roots |
| 24 | + ([0.5, -1j, 1j], [0.5]), # mix of real and complex |
| 25 | + ([0.3, 0.3], [0.3, 0.3]), # repeated root |
37 | 26 | ]) |
38 | | -def test_real_roots_in_01_known_cases(coeffs, expected): |
39 | | - """Test _real_roots_in_01 against known values and np.roots reference.""" |
40 | | - result = _real_roots_in_01(coeffs) |
41 | | - np_expected = _np_real_roots_in_01(coeffs) |
42 | | - assert_allclose(result, expected, atol=1e-10) |
43 | | - assert_allclose(result, np_expected, atol=1e-10) |
| 27 | +def test_real_roots_in_01(roots, expected_in_01): |
| 28 | + roots = np.array(roots) |
| 29 | + coeffs = np.poly(roots)[::-1] # np.poly gives descending, we need ascending |
| 30 | + result = _real_roots_in_01(coeffs.real) |
| 31 | + assert_allclose(result, expected_in_01, atol=1e-10) |
44 | 32 |
|
45 | 33 |
|
46 | | -@pytest.mark.parametrize("degree", range(1, 11)) |
47 | | -def test_real_roots_in_01_random(degree): |
48 | | - """Test random polynomials against np.roots.""" |
49 | | - rng = np.random.default_rng(seed=0) |
50 | | - coeffs = rng.uniform(-10, 10, size=degree + 1) |
51 | | - result = _real_roots_in_01(coeffs) |
52 | | - expected = _np_real_roots_in_01(coeffs) |
53 | | - assert len(result) == len(expected) |
54 | | - if len(result) > 0: |
55 | | - assert_allclose(result, expected, atol=1e-8) |
| 34 | +@pytest.mark.parametrize("coeffs", [[5], [0, 0, 0]]) |
| 35 | +def test_real_roots_in_01_no_roots(coeffs): |
| 36 | + assert len(_real_roots_in_01(coeffs)) == 0 |
56 | 37 |
|
57 | 38 |
|
58 | 39 | def test_split_bezier_with_large_values(): |
|
0 commit comments