From d0d2f5f4d31d3c622a0c41b88335dd166549914d Mon Sep 17 00:00:00 2001 From: Faruk Fakih Date: Wed, 28 Sep 2022 11:47:11 +0100 Subject: [PATCH] Created handler for PatchCollection --- .../patch_collection_handler.rst | 24 ++++++++ lib/matplotlib/legend.py | 3 +- lib/matplotlib/legend_handler.py | 60 ++++++++++++++----- lib/matplotlib/legend_handler.pyi | 13 ++++ lib/matplotlib/tests/test_legend.py | 22 +++++++ 5 files changed, 105 insertions(+), 17 deletions(-) create mode 100644 doc/users/next_whats_new/patch_collection_handler.rst diff --git a/doc/users/next_whats_new/patch_collection_handler.rst b/doc/users/next_whats_new/patch_collection_handler.rst new file mode 100644 index 000000000000..cce9912884e1 --- /dev/null +++ b/doc/users/next_whats_new/patch_collection_handler.rst @@ -0,0 +1,24 @@ +Legend handler for PatchCollection objects +------------------------------------------ + +PatchCollection objects are now supported in legends. The feature can be used as follows: + +.. plot:: + :include-source: true + + import matplotlib.pyplot as plt + from matplotlib.collections import PatchCollection + from matplotlib.patches import Polygon + + fig, axs = plt.subplots() + p1, p2 = Polygon([[0, 0], [100, 100], [200, 0]]), Polygon([[400, 0], [500, 100], [600, 0]]) + p3, p4 = Polygon([[700, 0], [800, 100], [900, 0]]), Polygon([[1000, 0], [1100, 100], [1200, 0]]) + p = PatchCollection([p1, p2], label="a", facecolors='red', edgecolors='black') + p2 = PatchCollection([p3, p4], label="ab", color='green') + axs.add_collection(p, autolim=True) + axs.add_collection(p2, autolim=True) + axs.set_xlim(right=1200) + axs.set_ylim(top=100) + axs.legend() + + plt.show() diff --git a/lib/matplotlib/legend.py b/lib/matplotlib/legend.py index 93ec3d32c0b7..0dd34f03be8e 100644 --- a/lib/matplotlib/legend.py +++ b/lib/matplotlib/legend.py @@ -38,7 +38,7 @@ StepPatch) from matplotlib.collections import ( Collection, CircleCollection, LineCollection, PathCollection, - PolyCollection, RegularPolyCollection) + PolyCollection, PatchCollection, RegularPolyCollection) from matplotlib.text import Text from matplotlib.transforms import Bbox, BboxBase, TransformedBbox from matplotlib.transforms import BboxTransformTo, BboxTransformFrom @@ -792,6 +792,7 @@ def draw(self, renderer): Patch: legend_handler.HandlerPatch(), StepPatch: legend_handler.HandlerStepPatch(), LineCollection: legend_handler.HandlerLineCollection(), + PatchCollection: legend_handler.HandlerPatchCollection(), RegularPolyCollection: legend_handler.HandlerRegularPolyCollection(), CircleCollection: legend_handler.HandlerCircleCollection(), BarContainer: legend_handler.HandlerPatch( diff --git a/lib/matplotlib/legend_handler.py b/lib/matplotlib/legend_handler.py index 5a929070e32d..656c39b76ffd 100644 --- a/lib/matplotlib/legend_handler.py +++ b/lib/matplotlib/legend_handler.py @@ -43,6 +43,19 @@ def update_from_first_child(tgt, src): tgt.update_from(first_child) +def _first_color(colors): + if colors.size == 0: + return (0, 0, 0, 0) + return tuple(colors[0]) + + +def _get_first(prop_array): + if len(prop_array): + return prop_array[0] + else: + return None + + class HandlerBase: """ A base class for default legend handlers. @@ -427,6 +440,32 @@ def create_artists(self, legend, orig_handle, return [legline] +class HandlerPatchCollection(HandlerPatch): + """ + Handler for `.PatchCollection` instances. + """ + def _default_update_prop(self, legend_handle, orig_handle): + lw = _get_first(orig_handle.get_linewidths()) + dashes = _get_first(orig_handle._us_linestyles) + facecolor = _first_color(orig_handle.get_facecolor()) + edgecolor = _first_color(orig_handle.get_edgecolor()) + legend_handle.set_facecolor(facecolor) + legend_handle.set_edgecolor(edgecolor) + legend_handle.set_linestyle(dashes) + legend_handle.set_linewidth(lw) + + def create_artists(self, legend, orig_handle, + xdescent, ydescent, width, height, fontsize, trans): + + p = self._create_patch(legend, orig_handle, + xdescent, ydescent, width, height, fontsize) + + self.update_prop(p, orig_handle, legend) + p.set_transform(trans) + + return [p] + + class HandlerRegularPolyCollection(HandlerNpointsYoffsets): r"""Handler for `.RegularPolyCollection`\s.""" @@ -775,21 +814,10 @@ class HandlerPolyCollection(HandlerBase): `~.Axes.stackplot`. """ def _update_prop(self, legend_handle, orig_handle): - def first_color(colors): - if colors.size == 0: - return (0, 0, 0, 0) - return tuple(colors[0]) - - def get_first(prop_array): - if len(prop_array): - return prop_array[0] - else: - return None - # orig_handle is a PolyCollection and legend_handle is a Patch. # Directly set Patch color attributes (must be RGBA tuples). - legend_handle._facecolor = first_color(orig_handle.get_facecolor()) - legend_handle._edgecolor = first_color(orig_handle.get_edgecolor()) + legend_handle._facecolor = _first_color(orig_handle.get_facecolor()) + legend_handle._edgecolor = _first_color(orig_handle.get_edgecolor()) legend_handle._original_facecolor = orig_handle._original_facecolor legend_handle._original_edgecolor = orig_handle._original_edgecolor legend_handle._fill = orig_handle.get_fill() @@ -797,9 +825,9 @@ def get_first(prop_array): # Hatch color is anomalous in having no getters and setters. legend_handle._hatch_color = orig_handle._hatch_color # Setters are fine for the remaining attributes. - legend_handle.set_linewidth(get_first(orig_handle.get_linewidths())) - legend_handle.set_linestyle(get_first(orig_handle.get_linestyles())) - legend_handle.set_transform(get_first(orig_handle.get_transforms())) + legend_handle.set_linewidth(_get_first(orig_handle.get_linewidths())) + legend_handle.set_linestyle(_get_first(orig_handle.get_linestyles())) + legend_handle.set_transform(_get_first(orig_handle.get_transforms())) legend_handle.set_figure(orig_handle.get_figure()) # Alpha is already taken into account by the color attributes. diff --git a/lib/matplotlib/legend_handler.pyi b/lib/matplotlib/legend_handler.pyi index db028a136a48..caf989ded91e 100644 --- a/lib/matplotlib/legend_handler.pyi +++ b/lib/matplotlib/legend_handler.pyi @@ -144,6 +144,19 @@ class HandlerLineCollection(HandlerLine2D): trans: Transform, ) -> Sequence[Artist]: ... +class HandlerPatchCollection(HandlerPatch): + def create_artists( + self, + legend: Legend, + orig_handle: Artist, + xdescent: float, + ydescent: float, + width: float, + height: float, + fontsize: float, + trans: Transform, + ) -> Sequence[Artist]: ... + _T = TypeVar("_T", bound=Artist) class HandlerRegularPolyCollection(HandlerNpointsYoffsets): diff --git a/lib/matplotlib/tests/test_legend.py b/lib/matplotlib/tests/test_legend.py index 90b0a3f38999..5956f6b06918 100644 --- a/lib/matplotlib/tests/test_legend.py +++ b/lib/matplotlib/tests/test_legend.py @@ -625,6 +625,28 @@ def test_linecollection_scaled_dashes(): assert oh.get_linestyles()[0] == lh._dash_pattern +def test_patch_collection_handler(): + fig, ax = plt.subplots() + pc = mcollections.PatchCollection([ + plt.Circle((0, 0), radius=1, facecolor='red', edgecolor='green', + linewidth=3, linestyle='--'), + plt.Rectangle((0.5, 0.5), 1, 1), + ], match_original=True, label='my_collection') + + ax.add_collection(pc) + _, labels = ax.get_legend_handles_labels() + assert len(labels) == 1 + assert labels[0] == 'my_collection' + + leg = ax.legend() + handles = leg.legend_handles + assert mpl.colors.same_color(handles[0].get_facecolor(), 'red') + assert mpl.colors.same_color(handles[0].get_edgecolor(), 'green') + assert handles[0].get_linewidth() == 3 + np.testing.assert_allclose(handles[0].get_linestyle()[1], + pc.get_linestyle()[0][1]) + + def test_handler_numpoints(): """Test legend handler with numpoints <= 1.""" # related to #6921 and PR #8478