From 0aa15c8785730d5b6dd650cddff92de74e822563 Mon Sep 17 00:00:00 2001 From: Eero Vaher Date: Tue, 29 Jul 2025 23:44:43 +0300 Subject: [PATCH] Improve logarithmic quantity tests parametrization Using `pytest.mark.parametrize()` is better than having a `for`-loop inside a test. One parameter causing a test failure does not prevent other parameters from being tested. Furthermore, the output of `pytest` is much better if a test fails or if `pytest` is invoked with the `--verbose` option. --- astropy/units/tests/test_logarithmic.py | 90 ++++++++++++------------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/astropy/units/tests/test_logarithmic.py b/astropy/units/tests/test_logarithmic.py index db4a0b1cfb12..8772e94bd13e 100644 --- a/astropy/units/tests/test_logarithmic.py +++ b/astropy/units/tests/test_logarithmic.py @@ -22,6 +22,12 @@ pu_sample = (u.dimensionless_unscaled, u.m, u.g / u.s**2, u.Jy) +mJy = np.arange(1.0, 5.0).reshape(2, 2) * u.mag(u.Jy) +m1 = np.arange(1.0, 5.5, 0.5).reshape(3, 3) * u.mag() +log_quantity_parametrization = pytest.mark.parametrize( + "mag", [mJy, m1], ids=lambda x: str(x.unit.physical_unit.physical_type) +) + class TestLogUnitCreation: def test_logarithmic_units(self): @@ -973,11 +979,6 @@ def test_comparison(self): class TestLogQuantityMethods: - def setup_method(self): - self.mJy = np.arange(1.0, 5.0).reshape(2, 2) * u.mag(u.Jy) - self.m1 = np.arange(1.0, 5.5, 0.5).reshape(3, 3) * u.mag() - self.mags = (self.mJy, self.m1) - @pytest.mark.parametrize( "method", ( @@ -992,59 +993,56 @@ def setup_method(self): "ediff1d", ), ) - def test_always_ok(self, method): - for mag in self.mags: - res = getattr(mag, method)() - assert np.all(res.value == getattr(mag._function_view, method)().value) - if method in ("std", "diff", "ediff1d"): - assert res.unit == u.mag() - elif method == "var": - assert res.unit == u.mag**2 - else: - assert res.unit == mag.unit + @log_quantity_parametrization + def test_always_ok(self, method, mag): + res = getattr(mag, method)() + assert np.all(res.value == getattr(mag._function_view, method)().value) + if method in ("std", "diff", "ediff1d"): + assert res.unit == u.mag() + elif method == "var": + assert res.unit == u.mag**2 + else: + assert res.unit == mag.unit @pytest.mark.skipif(not NUMPY_LT_2_0, reason="ptp method removed in numpy 2.0") - def test_always_ok_ptp(self): - for mag in self.mags: - res = mag.ptp() - assert np.all(res.value == mag._function_view.ptp().value) - assert res.unit == u.mag() + @log_quantity_parametrization + def test_always_ok_ptp(self, mag): + res = mag.ptp() + assert np.all(res.value == mag._function_view.ptp().value) + assert res.unit == u.mag() + + @log_quantity_parametrization + def test_clip(self, mag): + assert np.all( + mag.clip(2.0 * mag.unit, 4.0 * mag.unit).value == mag.value.clip(2.0, 4.0) + ) - def test_clip(self): - for mag in self.mags: - assert np.all( - mag.clip(2.0 * mag.unit, 4.0 * mag.unit).value - == mag.value.clip(2.0, 4.0) - ) + @pytest.mark.parametrize("method", ("sum", "cumsum")) + def test_ok_if_dimensionless(self, method): + res = getattr(m1, method)() + assert np.all(res.value == getattr(m1, method)().value) + assert res.unit == m1.unit @pytest.mark.parametrize("method", ("sum", "cumsum")) - def test_only_ok_if_dimensionless(self, method): - res = getattr(self.m1, method)() - assert np.all(res.value == getattr(self.m1._function_view, method)().value) - assert res.unit == self.m1.unit + def test_not_ok_if_not_dimensionless(self, method): with pytest.raises(TypeError): - getattr(self.mJy, method)() + getattr(mJy, method)() def test_dot(self): - assert np.all(self.m1.dot(self.m1).value == self.m1.value.dot(self.m1.value)) + assert np.all(m1.dot(m1).value == m1.value.dot(m1.value)) @pytest.mark.parametrize("method", ("prod", "cumprod")) - def test_never_ok(self, method): - with pytest.raises(TypeError): - getattr(self.mJy, method)() + @log_quantity_parametrization + def test_never_ok(self, method, mag): with pytest.raises(TypeError): - getattr(self.m1, method)() + getattr(mag, method)() class TestLogQuantityFunctions: # TODO: add tests for all supported functions! - def setup_method(self): - self.mJy = np.arange(1.0, 5.0).reshape(2, 2) * u.mag(u.Jy) - self.m1 = np.arange(1.0, 5.5, 0.5).reshape(3, 3) * u.mag() - self.mags = (self.mJy, self.m1) - - def test_ptp(self): - for mag in self.mags: - res = np.ptp(mag) - assert np.all(res.value == np.ptp(mag._function_view).value) - assert res.unit == u.mag() + + @log_quantity_parametrization + def test_ptp(self, mag): + res = np.ptp(mag) + assert np.all(res.value == np.ptp(mag._function_view).value) + assert res.unit == u.mag()