Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 605fd3c

Browse files
authored
Merge pull request #10220 from Zac-HD/imshow-rgb-fixes
FIX: Clip float RGB data to valid range for imshow
2 parents 2ed292d + 7f97698 commit 605fd3c

File tree

7 files changed

+68
-7
lines changed

7 files changed

+68
-7
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
`Axes.imshow` clips RGB values to the valid range
2+
-------------------------------------------------
3+
4+
When `Axes.imshow` is passed an RGB or RGBA value with out-of-range
5+
values, it now logs a warning and clips them to the valid range.
6+
The old behaviour, wrapping back in to the range, often hid outliers
7+
and made interpreting RGB images unreliable.

doc/users/credits.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ Yu Feng,
386386
Yunfei Yang,
387387
Yuri D'Elia,
388388
Yuval Langer,
389+
Zac Hatfield-Dodds,
389390
Zach Pincus,
390391
Zair Mubashar,
391392
alex,
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
`Axes.imshow` clips RGB values to the valid range
2+
-------------------------------------------------
3+
4+
When `Axes.imshow` is passed an RGB or RGBA value with out-of-range
5+
values, it now logs a warning and clips them to the valid range.
6+
The old behaviour, wrapping back in to the range, often hid outliers
7+
and made interpreting RGB images unreliable.

lib/matplotlib/axes/_axes.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5319,10 +5319,14 @@ def imshow(self, X, cmap=None, norm=None, aspect=None,
53195319
- MxNx3 -- RGB (float or uint8)
53205320
- MxNx4 -- RGBA (float or uint8)
53215321
5322-
The value for each component of MxNx3 and MxNx4 float arrays
5323-
should be in the range 0.0 to 1.0. MxN arrays are mapped
5324-
to colors based on the `norm` (mapping scalar to scalar)
5325-
and the `cmap` (mapping the normed scalar to a color).
5322+
MxN arrays are mapped to colors based on the `norm` (mapping
5323+
scalar to scalar) and the `cmap` (mapping the normed scalar to
5324+
a color).
5325+
5326+
Elements of RGB and RGBA arrays represent pixels of an MxN image.
5327+
All values should be in the range [0 .. 1] for floats or
5328+
[0 .. 255] for integers. Out-of-range values will be clipped to
5329+
these bounds.
53265330
53275331
cmap : `~matplotlib.colors.Colormap`, optional, default: None
53285332
If None, default to rc `image.cmap` value. `cmap` is ignored
@@ -5364,7 +5368,8 @@ def imshow(self, X, cmap=None, norm=None, aspect=None,
53645368
settings for `vmin` and `vmax` will be ignored.
53655369
53665370
alpha : scalar, optional, default: None
5367-
The alpha blending value, between 0 (transparent) and 1 (opaque)
5371+
The alpha blending value, between 0 (transparent) and 1 (opaque).
5372+
The ``alpha`` argument is ignored for RGBA input data.
53685373
53695374
origin : ['upper' | 'lower'], optional, default: None
53705375
Place the [0,0] index of the array in the upper left or lower left

lib/matplotlib/cm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
259259
xx = (xx * 255).astype(np.uint8)
260260
elif xx.dtype == np.uint8:
261261
if not bytes:
262-
xx = xx.astype(float) / 255
262+
xx = xx.astype(np.float32) / 255
263263
else:
264264
raise ValueError("Image RGB array must be uint8 or "
265265
"floating point; found %s" % xx.dtype)

lib/matplotlib/image.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from math import ceil
1515
import os
16+
import logging
1617

1718
import numpy as np
1819

@@ -34,6 +35,8 @@
3435
from matplotlib.transforms import (Affine2D, BboxBase, Bbox, BboxTransform,
3536
IdentityTransform, TransformedBbox)
3637

38+
_log = logging.getLogger(__name__)
39+
3740
# map interpolation strings to module constants
3841
_interpd_ = {
3942
'none': _image.NEAREST, # fall back to nearest when not supported
@@ -623,6 +626,23 @@ def set_data(self, A):
623626
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
624627
raise TypeError("Invalid dimensions for image data")
625628

629+
if self._A.ndim == 3:
630+
# If the input data has values outside the valid range (after
631+
# normalisation), we issue a warning and then clip X to the bounds
632+
# - otherwise casting wraps extreme values, hiding outliers and
633+
# making reliable interpretation impossible.
634+
high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1
635+
if self._A.min() < 0 or high < self._A.max():
636+
_log.warning(
637+
'Clipping input data to the valid range for imshow with '
638+
'RGB data ([0..1] for floats or [0..255] for integers).'
639+
)
640+
self._A = np.clip(self._A, 0, high)
641+
# Cast unsupported integer types to uint8
642+
if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype,
643+
np.integer):
644+
self._A = self._A.astype(np.uint8)
645+
626646
self._imcache = None
627647
self._rgbacache = None
628648
self.stale = True

lib/matplotlib/tests/test_image.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def test_minimized_rasterized():
620620
def test_load_from_url():
621621
req = six.moves.urllib.request.urlopen(
622622
"http://matplotlib.org/_static/logo_sidebar_horiz.png")
623-
Z = plt.imread(req)
623+
plt.imread(req)
624624

625625

626626
@image_comparison(baseline_images=['log_scale_image'],
@@ -813,6 +813,27 @@ def test_imshow_no_warn_invalid():
813813
assert len(warns) == 0
814814

815815

816+
@pytest.mark.parametrize(
817+
'dtype', [np.dtype(s) for s in 'u2 u4 i2 i4 i8 f4 f8'.split()])
818+
def test_imshow_clips_rgb_to_valid_range(dtype):
819+
arr = np.arange(300, dtype=dtype).reshape((10, 10, 3))
820+
if dtype.kind != 'u':
821+
arr -= 10
822+
too_low = arr < 0
823+
too_high = arr > 255
824+
if dtype.kind == 'f':
825+
arr = arr / 255
826+
_, ax = plt.subplots()
827+
out = ax.imshow(arr).get_array()
828+
assert (out[too_low] == 0).all()
829+
if dtype.kind == 'f':
830+
assert (out[too_high] == 1).all()
831+
assert out.dtype.kind == 'f'
832+
else:
833+
assert (out[too_high] == 255).all()
834+
assert out.dtype == np.uint8
835+
836+
816837
@image_comparison(baseline_images=['imshow_flatfield'],
817838
remove_text=True, style='mpl20',
818839
extensions=['png'])

0 commit comments

Comments
 (0)