From b13e31a2799f3d3227afc2d83187aa0eba6b373b Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Thu, 7 Nov 2024 23:32:19 -0500 Subject: [PATCH] TST: Calculate RMS and diff image in C++ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The current implementation is not slow, but uses a lot of memory per image. In `compare_images`, we have: - one actual and one expected image as uint8 (2×image) - both converted to int16 (though original is thrown away) (4×) which adds up to 4× the image allocated in this function. Then it calls `calculate_rms`, which has: - a difference between them as int16 (2×) - the difference cast to 64-bit float (8×) - the square of the difference as 64-bit float (though possibly the original difference was thrown away) (8×) which at its peak has 16× the image allocated in parallel. If the RMS is over the desired tolerance, then `save_diff_image` is called, which: - loads the actual and expected images _again_ as uint8 (2× image) - converts both to 64-bit float (throwing away the original) (16×) - calculates the difference (8×) - calculates the absolute value (8×) - multiples that by 10 (in-place, so no allocation) - clips to 0-255 (8×) - casts to uint8 (1×) which at peak uses 32× the image. So at their peak, `compare_images`→`calculate_rms` will have 20× the image allocated, and then `compare_images`→`save_diff_image` will have 36× the image allocated. This is generally not a problem, but on resource-constrained places like WASM, it can sometimes run out of memory just in `calculate_rms`. This implementation in C++ always allocates the diff image, even when not needed, but doesn't have all the temporaries, so it's a maximum of 3× the image size (plus a few scalar temporaries). --- lib/matplotlib/testing/compare.py | 13 ++-- lib/matplotlib/tests/test_compare_images.py | 27 +++++++ src/_image_wrapper.cpp | 79 +++++++++++++++++++++ 3 files changed, 110 insertions(+), 9 deletions(-) diff --git a/lib/matplotlib/testing/compare.py b/lib/matplotlib/testing/compare.py index 67897e76edcb..fa5cd89481b5 100644 --- a/lib/matplotlib/testing/compare.py +++ b/lib/matplotlib/testing/compare.py @@ -19,7 +19,7 @@ from PIL import Image import matplotlib as mpl -from matplotlib import cbook +from matplotlib import cbook, _image from matplotlib.testing.exceptions import ImageComparisonFailure _log = logging.getLogger(__name__) @@ -412,7 +412,7 @@ def compare_images(expected, actual, tol, in_decorator=False): The two given filenames may point to files which are convertible to PNG via the `!converter` dictionary. The underlying RMS is calculated - with the `.calculate_rms` function. + in a similar way to the `.calculate_rms` function. Parameters ---------- @@ -483,17 +483,12 @@ def compare_images(expected, actual, tol, in_decorator=False): if np.array_equal(expected_image, actual_image): return None - # convert to signed integers, so that the images can be subtracted without - # overflow - expected_image = expected_image.astype(np.int16) - actual_image = actual_image.astype(np.int16) - - rms = calculate_rms(expected_image, actual_image) + rms, abs_diff = _image.calculate_rms_and_diff(expected_image, actual_image) if rms <= tol: return None - save_diff_image(expected, actual, diff_image) + Image.fromarray(abs_diff).save(diff_image, format="png") results = dict(rms=rms, expected=str(expected), actual=str(actual), diff=str(diff_image), tol=tol) diff --git a/lib/matplotlib/tests/test_compare_images.py b/lib/matplotlib/tests/test_compare_images.py index 6023f3d05468..96b76f790ccd 100644 --- a/lib/matplotlib/tests/test_compare_images.py +++ b/lib/matplotlib/tests/test_compare_images.py @@ -1,11 +1,14 @@ from pathlib import Path import shutil +import numpy as np import pytest from pytest import approx +from matplotlib import _image from matplotlib.testing.compare import compare_images from matplotlib.testing.decorators import _image_directories +from matplotlib.testing.exceptions import ImageComparisonFailure # Tests of the image comparison algorithm. @@ -71,3 +74,27 @@ def test_image_comparison_expect_rms(im1, im2, tol, expect_rms, tmp_path, else: assert results is not None assert results['rms'] == approx(expect_rms, abs=1e-4) + + +def test_invalid_input(): + img = np.zeros((16, 16, 4), dtype=np.uint8) + + with pytest.raises(ImageComparisonFailure, + match='must be 3-dimensional, but is 2-dimensional'): + _image.calculate_rms_and_diff(img[:, :, 0], img) + with pytest.raises(ImageComparisonFailure, + match='must be 3-dimensional, but is 5-dimensional'): + _image.calculate_rms_and_diff(img, img[:, :, :, np.newaxis, np.newaxis]) + with pytest.raises(ImageComparisonFailure, + match='must be RGB or RGBA but has depth 2'): + _image.calculate_rms_and_diff(img[:, :, :2], img) + + with pytest.raises(ImageComparisonFailure, + match=r'expected size: \(16, 16, 4\) actual size \(8, 16, 4\)'): + _image.calculate_rms_and_diff(img, img[:8, :, :]) + with pytest.raises(ImageComparisonFailure, + match=r'expected size: \(16, 16, 4\) actual size \(16, 6, 4\)'): + _image.calculate_rms_and_diff(img, img[:, :6, :]) + with pytest.raises(ImageComparisonFailure, + match=r'expected size: \(16, 16, 4\) actual size \(16, 16, 3\)'): + _image.calculate_rms_and_diff(img, img[:, :, :3]) diff --git a/src/_image_wrapper.cpp b/src/_image_wrapper.cpp index 87d2b3b288ec..6528c4a9270c 100644 --- a/src/_image_wrapper.cpp +++ b/src/_image_wrapper.cpp @@ -1,6 +1,8 @@ #include #include +#include + #include "_image_resample.h" #include "py_converters.h" @@ -202,6 +204,80 @@ image_resample(py::array input_array, } +// This is used by matplotlib.testing.compare to calculate RMS and a difference image. +static py::tuple +calculate_rms_and_diff(py::array_t expected_image, + py::array_t actual_image) +{ + for (const auto & [image, name] : {std::pair{expected_image, "Expected"}, + std::pair{actual_image, "Actual"}}) + { + if (image.ndim() != 3) { + auto exceptions = py::module_::import("matplotlib.testing.exceptions"); + auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure"); + py::set_error( + ImageComparisonFailure, + "{name} image must be 3-dimensional, but is {ndim}-dimensional"_s.format( + "name"_a=name, "ndim"_a=image.ndim())); + throw py::error_already_set(); + } + } + + auto height = expected_image.shape(0); + auto width = expected_image.shape(1); + auto depth = expected_image.shape(2); + + if (depth != 3 && depth != 4) { + auto exceptions = py::module_::import("matplotlib.testing.exceptions"); + auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure"); + py::set_error( + ImageComparisonFailure, + "Image must be RGB or RGBA but has depth {depth}"_s.format( + "depth"_a=depth)); + throw py::error_already_set(); + } + + if (height != actual_image.shape(0) || width != actual_image.shape(1) || + depth != actual_image.shape(2)) { + auto exceptions = py::module_::import("matplotlib.testing.exceptions"); + auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure"); + py::set_error( + ImageComparisonFailure, + "Image sizes do not match expected size: {expected_image.shape} "_s + "actual size {actual_image.shape}"_s.format( + "expected_image"_a=expected_image, "actual_image"_a=actual_image)); + throw py::error_already_set(); + } + auto expected = expected_image.unchecked<3>(); + auto actual = actual_image.unchecked<3>(); + + py::ssize_t diff_dims[3] = {height, width, 3}; + py::array_t diff_image(diff_dims); + auto diff = diff_image.mutable_unchecked<3>(); + + double total = 0.0; + for (auto i = 0; i < height; i++) { + for (auto j = 0; j < width; j++) { + for (auto k = 0; k < depth; k++) { + auto pixel_diff = static_cast(expected(i, j, k)) - + static_cast(actual(i, j, k)); + + total += pixel_diff*pixel_diff; + + if (k != 3) { // Hard-code a fully solid alpha channel by omitting it. + diff(i, j, k) = static_cast(std::clamp( + abs(pixel_diff) * 10, // Expand differences in luminance domain. + 0.0, 255.0)); + } + } + } + } + total = total / (width * height * depth); + + return py::make_tuple(sqrt(total), diff_image); +} + + PYBIND11_MODULE(_image, m, py::mod_gil_not_used()) { py::enum_(m, "_InterpolationType") @@ -234,4 +310,7 @@ PYBIND11_MODULE(_image, m, py::mod_gil_not_used()) "norm"_a = false, "radius"_a = 1, image_resample__doc__); + + m.def("calculate_rms_and_diff", &calculate_rms_and_diff, + "expected_image"_a, "actual_image"_a); }