From cb7f555cd567dbe50eebabc8d814f86ccb34a168 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 22 Jan 2019 21:16:39 +0100 Subject: [PATCH] In imsave()'s Pillow-handled case, don't create a temporary figure. Avoids accidentally changing the image shape when dividing and multiplying by dpi. --- lib/matplotlib/image.py | 47 +++++++++++++++++++++++------- lib/matplotlib/tests/test_image.py | 24 ++++++++++----- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index 2924b74899cc..e8ac3de31183 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -7,6 +7,7 @@ from math import ceil import os import logging +from pathlib import Path import urllib.parse import urllib.request @@ -1432,24 +1433,48 @@ def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, The DPI to store in the metadata of the file. This does not affect the resolution of the output image. """ - from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas from matplotlib.figure import Figure if isinstance(fname, os.PathLike): fname = os.fspath(fname) - if (format == 'png' - or (format is None - and isinstance(fname, str) - and fname.lower().endswith('.png'))): - image = AxesImage(None, cmap=cmap, origin=origin) - image.set_data(arr) - image.set_clim(vmin, vmax) - image.write_png(fname) - else: + if format is None: + format = (Path(fname).suffix[1:] if isinstance(fname, str) + else rcParams["savefig.format"]).lower() + if format in ["pdf", "ps", "eps", "svg"]: + # Vector formats that are not handled by PIL. fig = Figure(dpi=dpi, frameon=False) - FigureCanvas(fig) fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin, resize=True) fig.savefig(fname, dpi=dpi, format=format, transparent=True) + else: + # Don't bother creating an image; this avoids rounding errors on the + # size when dividing and then multiplying by dpi. + sm = cm.ScalarMappable(cmap=cmap) + sm.set_clim(vmin, vmax) + if origin is None: + origin = rcParams["image.origin"] + if origin == "lower": + arr = arr[::-1] + rgba = sm.to_rgba(arr, bytes=True) + if format == "png": + _png.write_png(rgba, fname, dpi=dpi) + else: + try: + from PIL import Image + except ImportError as exc: + raise ImportError( + f"Saving to {format} requires Pillow") from exc + pil_shape = (rgba.shape[1], rgba.shape[0]) + image = Image.frombuffer( + "RGBA", pil_shape, rgba, "raw", "RGBA", 0, 1) + if format in ["jpg", "jpeg"]: + format = "jpeg" # Pillow doesn't recognize "jpg". + color = tuple( + int(x * 255) + for x in mcolors.to_rgb(rcParams["savefig.facecolor"])) + background = Image.new("RGB", pil_shape, color) + background.paste(image, image) + image = background + image.save(fname, format=format, dpi=(dpi, dpi)) def pil_to_array(pilImage): diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index d3484fb5f838..02fec2731ea1 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -121,7 +121,12 @@ def test_imread_fspath(): assert np.sum(img) == 134184960 -def test_imsave(): +@pytest.mark.parametrize("fmt", ["png", "jpg", "jpeg", "tiff"]) +def test_imsave(fmt): + if fmt in ["jpg", "jpeg", "tiff"]: + pytest.importorskip("PIL") + has_alpha = fmt not in ["jpg", "jpeg"] + # The goal here is that the user can specify an output logical DPI # for the image, but this will not actually add any extra pixels # to the image, it will merely be used for metadata purposes. @@ -130,22 +135,25 @@ def test_imsave(): # == 100) and read the resulting PNG files back in and make sure # the data is 100% identical. np.random.seed(1) - data = np.random.rand(256, 128) + # The height of 1856 pixels was selected because going through creating an + # actual dpi=100 figure to save the image to a Pillow-provided format would + # cause a rounding error resulting in a final image of shape 1855. + data = np.random.rand(1856, 2) buff_dpi1 = io.BytesIO() - plt.imsave(buff_dpi1, data, dpi=1) + plt.imsave(buff_dpi1, data, format=fmt, dpi=1) buff_dpi100 = io.BytesIO() - plt.imsave(buff_dpi100, data, dpi=100) + plt.imsave(buff_dpi100, data, format=fmt, dpi=100) buff_dpi1.seek(0) - arr_dpi1 = plt.imread(buff_dpi1) + arr_dpi1 = plt.imread(buff_dpi1, format=fmt) buff_dpi100.seek(0) - arr_dpi100 = plt.imread(buff_dpi100) + arr_dpi100 = plt.imread(buff_dpi100, format=fmt) - assert arr_dpi1.shape == (256, 128, 4) - assert arr_dpi100.shape == (256, 128, 4) + assert arr_dpi1.shape == (1856, 2, 3 + has_alpha) + assert arr_dpi100.shape == (1856, 2, 3 + has_alpha) assert_array_equal(arr_dpi1, arr_dpi100)