Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e353367

Browse files
authored
Merge pull request #28458 from ianthomas23/28448_image_resample_dtype_comparisons
Correct numpy dtype comparisons in image_resample
2 parents f54dd51 + acfe975 commit e353367

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

lib/matplotlib/tests/test_image.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,3 +1576,20 @@ def test_non_transdata_image_does_not_touch_aspect():
15761576
assert ax.get_aspect() == 1
15771577
ax.imshow(im, transform=ax.transAxes, aspect=2)
15781578
assert ax.get_aspect() == 2
1579+
1580+
1581+
@pytest.mark.parametrize(
1582+
'dtype',
1583+
('float64', 'float32', 'int16', 'uint16', 'int8', 'uint8'),
1584+
)
1585+
@pytest.mark.parametrize('ndim', (2, 3))
1586+
def test_resample_dtypes(dtype, ndim):
1587+
# Issue 28448, incorrect dtype comparisons in C++ image_resample can raise
1588+
# ValueError: arrays must be of dtype byte, short, float32 or float64
1589+
rng = np.random.default_rng(4181)
1590+
shape = (2, 2) if ndim == 2 else (2, 2, 3)
1591+
data = rng.uniform(size=shape).astype(np.dtype(dtype, copy=True))
1592+
fig, ax = plt.subplots()
1593+
axes_image = ax.imshow(data)
1594+
# Before fix the following raises ValueError for some dtypes.
1595+
axes_image.make_image(None)[0]

src/_image_wrapper.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,20 +173,20 @@ image_resample(py::array input_array,
173173

174174
if (auto resampler =
175175
(ndim == 2) ? (
176-
(dtype.is(py::dtype::of<std::uint8_t>())) ? resample<agg::gray8> :
177-
(dtype.is(py::dtype::of<std::int8_t>())) ? resample<agg::gray8> :
178-
(dtype.is(py::dtype::of<std::uint16_t>())) ? resample<agg::gray16> :
179-
(dtype.is(py::dtype::of<std::int16_t>())) ? resample<agg::gray16> :
180-
(dtype.is(py::dtype::of<float>())) ? resample<agg::gray32> :
181-
(dtype.is(py::dtype::of<double>())) ? resample<agg::gray64> :
176+
(dtype.equal(py::dtype::of<std::uint8_t>())) ? resample<agg::gray8> :
177+
(dtype.equal(py::dtype::of<std::int8_t>())) ? resample<agg::gray8> :
178+
(dtype.equal(py::dtype::of<std::uint16_t>())) ? resample<agg::gray16> :
179+
(dtype.equal(py::dtype::of<std::int16_t>())) ? resample<agg::gray16> :
180+
(dtype.equal(py::dtype::of<float>())) ? resample<agg::gray32> :
181+
(dtype.equal(py::dtype::of<double>())) ? resample<agg::gray64> :
182182
nullptr) : (
183183
// ndim == 3
184-
(dtype.is(py::dtype::of<std::uint8_t>())) ? resample<agg::rgba8> :
185-
(dtype.is(py::dtype::of<std::int8_t>())) ? resample<agg::rgba8> :
186-
(dtype.is(py::dtype::of<std::uint16_t>())) ? resample<agg::rgba16> :
187-
(dtype.is(py::dtype::of<std::int16_t>())) ? resample<agg::rgba16> :
188-
(dtype.is(py::dtype::of<float>())) ? resample<agg::rgba32> :
189-
(dtype.is(py::dtype::of<double>())) ? resample<agg::rgba64> :
184+
(dtype.equal(py::dtype::of<std::uint8_t>())) ? resample<agg::rgba8> :
185+
(dtype.equal(py::dtype::of<std::int8_t>())) ? resample<agg::rgba8> :
186+
(dtype.equal(py::dtype::of<std::uint16_t>())) ? resample<agg::rgba16> :
187+
(dtype.equal(py::dtype::of<std::int16_t>())) ? resample<agg::rgba16> :
188+
(dtype.equal(py::dtype::of<float>())) ? resample<agg::rgba32> :
189+
(dtype.equal(py::dtype::of<double>())) ? resample<agg::rgba64> :
190190
nullptr)) {
191191
Py_BEGIN_ALLOW_THREADS
192192
resampler(

0 commit comments

Comments
 (0)