Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 44 additions & 46 deletions astropy/units/tests/test_logarithmic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@

pu_sample = (u.dimensionless_unscaled, u.m, u.g / u.s**2, u.Jy)

mJy = np.arange(1.0, 5.0).reshape(2, 2) * u.mag(u.Jy)
m1 = np.arange(1.0, 5.5, 0.5).reshape(3, 3) * u.mag()
log_quantity_parametrization = pytest.mark.parametrize(
"mag", [mJy, m1], ids=lambda x: str(x.unit.physical_unit.physical_type)
)


class TestLogUnitCreation:
def test_logarithmic_units(self):
Expand Down Expand Up @@ -973,11 +979,6 @@ def test_comparison(self):


class TestLogQuantityMethods:
def setup_method(self):
self.mJy = np.arange(1.0, 5.0).reshape(2, 2) * u.mag(u.Jy)
self.m1 = np.arange(1.0, 5.5, 0.5).reshape(3, 3) * u.mag()
self.mags = (self.mJy, self.m1)

@pytest.mark.parametrize(
"method",
(
Expand All @@ -992,59 +993,56 @@ def setup_method(self):
"ediff1d",
),
)
def test_always_ok(self, method):
for mag in self.mags:
res = getattr(mag, method)()
assert np.all(res.value == getattr(mag._function_view, method)().value)
if method in ("std", "diff", "ediff1d"):
assert res.unit == u.mag()
elif method == "var":
assert res.unit == u.mag**2
else:
assert res.unit == mag.unit
@log_quantity_parametrization
def test_always_ok(self, method, mag):
res = getattr(mag, method)()
assert np.all(res.value == getattr(mag._function_view, method)().value)
if method in ("std", "diff", "ediff1d"):
assert res.unit == u.mag()
elif method == "var":
assert res.unit == u.mag**2
else:
assert res.unit == mag.unit

@pytest.mark.skipif(not NUMPY_LT_2_0, reason="ptp method removed in numpy 2.0")
def test_always_ok_ptp(self):
for mag in self.mags:
res = mag.ptp()
assert np.all(res.value == mag._function_view.ptp().value)
assert res.unit == u.mag()
@log_quantity_parametrization
def test_always_ok_ptp(self, mag):
res = mag.ptp()
assert np.all(res.value == mag._function_view.ptp().value)
assert res.unit == u.mag()

@log_quantity_parametrization
def test_clip(self, mag):
assert np.all(
mag.clip(2.0 * mag.unit, 4.0 * mag.unit).value == mag.value.clip(2.0, 4.0)
)

def test_clip(self):
for mag in self.mags:
assert np.all(
mag.clip(2.0 * mag.unit, 4.0 * mag.unit).value
== mag.value.clip(2.0, 4.0)
)
@pytest.mark.parametrize("method", ("sum", "cumsum"))
def test_ok_if_dimensionless(self, method):
res = getattr(m1, method)()
assert np.all(res.value == getattr(m1, method)().value)
assert res.unit == m1.unit

@pytest.mark.parametrize("method", ("sum", "cumsum"))
def test_only_ok_if_dimensionless(self, method):
res = getattr(self.m1, method)()
assert np.all(res.value == getattr(self.m1._function_view, method)().value)
assert res.unit == self.m1.unit
def test_not_ok_if_not_dimensionless(self, method):
with pytest.raises(TypeError):
getattr(self.mJy, method)()
getattr(mJy, method)()

def test_dot(self):
assert np.all(self.m1.dot(self.m1).value == self.m1.value.dot(self.m1.value))
assert np.all(m1.dot(m1).value == m1.value.dot(m1.value))

@pytest.mark.parametrize("method", ("prod", "cumprod"))
def test_never_ok(self, method):
with pytest.raises(TypeError):
getattr(self.mJy, method)()
@log_quantity_parametrization
def test_never_ok(self, method, mag):
with pytest.raises(TypeError):
getattr(self.m1, method)()
getattr(mag, method)()


class TestLogQuantityFunctions:
# TODO: add tests for all supported functions!
def setup_method(self):
self.mJy = np.arange(1.0, 5.0).reshape(2, 2) * u.mag(u.Jy)
self.m1 = np.arange(1.0, 5.5, 0.5).reshape(3, 3) * u.mag()
self.mags = (self.mJy, self.m1)

def test_ptp(self):
for mag in self.mags:
res = np.ptp(mag)
assert np.all(res.value == np.ptp(mag._function_view).value)
assert res.unit == u.mag()

@log_quantity_parametrization
def test_ptp(self, mag):
res = np.ptp(mag)
assert np.all(res.value == np.ptp(mag._function_view).value)
assert res.unit == u.mag()
Loading