From 3c345ec21ae119b99caa73126cd2d3ba2a643822 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Thu, 8 Nov 2018 21:08:57 +0100 Subject: [PATCH] Include scatter plots in Qt figure options editor. Essentially all the image-handling code can be reused; the only difference is that collections don't have an `interpolation` field so we need to check for that. The rest of the PR is just replacing "image" by "sm" ("ScalarMappable") throughout... --- .../backends/qt_editor/figureoptions.py | 69 ++++++++++--------- lib/matplotlib/tests/test_backend_qt.py | 1 + 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/lib/matplotlib/backends/qt_editor/figureoptions.py b/lib/matplotlib/backends/qt_editor/figureoptions.py index 6620f5870920..df308796da24 100644 --- a/lib/matplotlib/backends/qt_editor/figureoptions.py +++ b/lib/matplotlib/backends/qt_editor/figureoptions.py @@ -137,39 +137,43 @@ def prepare_data(d, init): # Is there a curve displayed? has_curve = bool(curves) - # Get / Images - imagedict = {} - for image in axes.get_images(): - label = image.get_label() - if label == '_nolegend_': + # Get ScalarMappables. + mappabledict = {} + for mappable in [*axes.images, *axes.collections]: + label = mappable.get_label() + if label == '_nolegend_' or mappable.get_array() is None: continue - imagedict[label] = image - imagelabels = sorted(imagedict, key=cmp_key) - images = [] + mappabledict[label] = mappable + mappablelabels = sorted(mappabledict, key=cmp_key) + mappables = [] cmaps = [(cmap, name) for name, cmap in sorted(cm.cmap_d.items())] - for label in imagelabels: - image = imagedict[label] - cmap = image.get_cmap() + for label in mappablelabels: + mappable = mappabledict[label] + cmap = mappable.get_cmap() if cmap not in cm.cmap_d.values(): - cmaps = [(cmap, cmap.name)] + cmaps - low, high = image.get_clim() - imagedata = [ + cmaps = [(cmap, cmap.name), *cmaps] + low, high = mappable.get_clim() + mappabledata = [ ('Label', label), ('Colormap', [cmap.name] + cmaps), ('Min. value', low), ('Max. value', high), - ('Interpolation', - [image.get_interpolation()] - + [(name, name) for name in sorted(mimage.interpolations_names)])] - images.append([imagedata, label, ""]) - # Is there an image displayed? - has_image = bool(images) + ] + if hasattr(mappable, "get_interpolation"): # Images. + interpolations = [ + (name, name) for name in sorted(mimage.interpolations_names)] + mappabledata.append(( + 'Interpolation', + [mappable.get_interpolation(), *interpolations])) + mappables.append([mappabledata, label, ""]) + # Is there a scalarmappable displayed? + has_sm = bool(mappables) datalist = [(general, "Axes", "")] if curves: datalist.append((curves, "Curves", "")) - if images: - datalist.append((images, "Images", "")) + if mappables: + datalist.append((mappables, "Images, etc.", "")) def apply_callback(data): """This function will be called to apply changes""" @@ -178,7 +182,7 @@ def apply_callback(data): general = data.pop(0) curves = data.pop(0) if has_curve else [] - images = data.pop(0) if has_image else [] + mappables = data.pop(0) if has_sm else [] if data: raise ValueError("Unexpected field") @@ -223,14 +227,17 @@ def apply_callback(data): line.set_markerfacecolor(markerfacecolor) line.set_markeredgecolor(markeredgecolor) - # Set / Images - for index, image_settings in enumerate(images): - image = imagedict[imagelabels[index]] - label, cmap, low, high, interpolation = image_settings - image.set_label(label) - image.set_cmap(cm.get_cmap(cmap)) - image.set_clim(*sorted([low, high])) - image.set_interpolation(interpolation) + # Set ScalarMappables. + for index, mappable_settings in enumerate(mappables): + mappable = mappabledict[mappablelabels[index]] + if len(mappable_settings) == 5: + label, cmap, low, high, interpolation = mappable_settings + mappable.set_interpolation(interpolation) + elif len(mappable_settings) == 4: + label, cmap, low, high = mappable_settings + mappable.set_label(label) + mappable.set_cmap(cm.get_cmap(cmap)) + mappable.set_clim(*sorted([low, high])) # re-generate legend, if checkbox is checked if generate_legend: diff --git a/lib/matplotlib/tests/test_backend_qt.py b/lib/matplotlib/tests/test_backend_qt.py index cecd7abcb175..5f8f0e50900d 100644 --- a/lib/matplotlib/tests/test_backend_qt.py +++ b/lib/matplotlib/tests/test_backend_qt.py @@ -250,6 +250,7 @@ def test_figureoptions(): fig, ax = plt.subplots() ax.plot([1, 2]) ax.imshow([[1]]) + ax.scatter(range(3), range(3), c=range(3)) with mock.patch( "matplotlib.backends.qt_editor.formlayout.FormDialog.exec_", lambda self: None):