Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4adb996

Browse files
committed
Correctly apply PNG palette when building ImageBase through Pillow.
1 parent e551980 commit 4adb996

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

lib/matplotlib/image.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,15 @@ def set_data(self, A):
665665
666666
Parameters
667667
----------
668-
A : array-like
668+
A : array-like or `PIL.Image.Image`
669669
"""
670+
try:
671+
from PIL import Image
672+
except ImportError:
673+
pass
674+
else:
675+
if isinstance(A, Image.Image):
676+
A = pil_to_array(A) # Needed e.g. to apply png palette.
670677
self._A = cbook.safe_masked_invalid(A, copy=True)
671678

672679
if (self._A.dtype != np.uint8 and

lib/matplotlib/tests/test_image.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numpy.testing import assert_array_equal
1515

1616
from matplotlib import (
17-
colors, image as mimage, patches, pyplot as plt,
17+
colors, image as mimage, patches, pyplot as plt, style,
1818
rc_context, rcParams)
1919
from matplotlib.cbook import MatplotlibDeprecationWarning
2020
from matplotlib.image import (AxesImage, BboxImage, FigureImage,
@@ -117,11 +117,15 @@ def test_image_python_io():
117117

118118
@check_figures_equal()
119119
def test_imshow_pil(fig_test, fig_ref):
120-
pytest.importorskip("PIL")
121-
img = plt.imread(os.path.join(os.path.dirname(__file__),
122-
'baseline_images', 'test_image', 'uint16.tif'))
123-
fig_test.subplots().imshow(img)
124-
fig_ref.subplots().imshow(np.asarray(img))
120+
style.use("default")
121+
PIL = pytest.importorskip("PIL")
122+
png_path = Path(__file__).parent / "baseline_images/pngsuite/basn3p04.png"
123+
tiff_path = Path(__file__).parent / "baseline_images/test_image/uint16.tif"
124+
axs[0].imshow(PIL.Image.open(png_path))
125+
axs[1].imshow(PIL.Image.open(tiff_path))
126+
axs = fig_ref.subplots(2)
127+
axs[0].imshow(plt.imread(str(png_path)))
128+
axs[1].imshow(plt.imread(tiff_path))
125129

126130

127131
def test_imread_pil_uint16():

0 commit comments

Comments
 (0)