diff --git a/doc/users/whats_new.rst b/doc/users/whats_new.rst index 362384db8c94..525118696633 100644 --- a/doc/users/whats_new.rst +++ b/doc/users/whats_new.rst @@ -246,6 +246,13 @@ volumetric model. Improvements ++++++++++++ +Add ``capstyle`` and ``joinstyle`` attributes to `Collection` +------------------------------------------------------------- + +The `Collection` class now has customizable ``capstyle`` and ``joinstyle`` +attributes. This allows the user for example to set the ``capstyle`` of +errorbars. + CheckButtons widget ``get_status`` function ------------------------------------------- diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index ad660e97ecbd..3009c3d45355 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -52,11 +52,16 @@ class Collection(artist.Artist, cm.ScalarMappable): prop[i % len(props)] + Exceptions are *capstyle* and *joinstyle* properties, these can + only be set globally for the whole collection. + Keyword arguments and default values: * *edgecolors*: None * *facecolors*: None * *linewidths*: None + * *capstyle*: None + * *joinstyle*: None * *antialiaseds*: None * *offsets*: None * *transOffset*: transforms.IdentityTransform() @@ -104,6 +109,8 @@ def __init__(self, facecolors=None, linewidths=None, linestyles='solid', + capstyle=None, + joinstyle=None, antialiaseds=None, offsets=None, transOffset=None, @@ -145,6 +152,16 @@ def __init__(self, self.set_offset_position(offset_position) self.set_zorder(zorder) + if capstyle: + self.set_capstyle(capstyle) + else: + self._capstyle = None + + if joinstyle: + self.set_joinstyle(joinstyle) + else: + self._joinstyle = None + self._offsets = np.zeros((1, 2)) self._uniform_offsets = None if offsets is not None: @@ -304,6 +321,12 @@ def draw(self, renderer): extents.height < height): do_single_path_optimization = True + if self._joinstyle: + gc.set_joinstyle(self._joinstyle) + + if self._capstyle: + gc.set_capstyle(self._capstyle) + if do_single_path_optimization: gc.set_foreground(tuple(edgecolors[0])) gc.set_linewidth(self._linewidths[0]) @@ -536,6 +559,42 @@ def set_linestyle(self, ls): self._linewidths, self._linestyles = self._bcast_lwls( self._us_lw, self._us_linestyles) + def set_capstyle(self, cs): + """ + Set the capstyle for the collection. The capstyle can + only be set globally for all elements in the collection + + Parameters + ---------- + cs : ['butt' | 'round' | 'projecting'] + The capstyle + """ + if cs in ('butt', 'round', 'projecting'): + self._capstyle = cs + else: + raise ValueError('Unrecognized cap style. Found %s' % cs) + + def get_capstyle(self): + return self._capstyle + + def set_joinstyle(self, js): + """ + Set the joinstyle for the collection. The joinstyle can only be + set globally for all elements in the collection. + + Parameters + ---------- + js : ['miter' | 'round' | 'bevel'] + The joinstyle + """ + if js in ('miter', 'round', 'bevel'): + self._joinstyle = js + else: + raise ValueError('Unrecognized join style. Found %s' % js) + + def get_joinstyle(self): + return self._joinstyle + @staticmethod def _bcast_lwls(linewidths, dashes): '''Internal helper function to broadcast + scale ls/lw diff --git a/lib/matplotlib/tests/baseline_images/test_collections/cap_and_joinstyle.png b/lib/matplotlib/tests/baseline_images/test_collections/cap_and_joinstyle.png new file mode 100644 index 000000000000..ef18c1311d89 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_collections/cap_and_joinstyle.png differ diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index c27aeb11a159..03d5c80b38e9 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -14,7 +14,7 @@ import matplotlib.pyplot as plt import matplotlib.collections as mcollections import matplotlib.transforms as mtransforms -from matplotlib.collections import Collection, EventCollection +from matplotlib.collections import Collection, LineCollection, EventCollection from matplotlib.testing.decorators import image_comparison @@ -626,6 +626,44 @@ def test_lslw_bcast(): assert_equal(col.get_linewidths(), [1, 2, 3]) +@pytest.mark.style('default') +def test_capstyle(): + col = mcollections.PathCollection([], capstyle='round') + assert_equal(col.get_capstyle(), 'round') + col.set_capstyle('butt') + assert_equal(col.get_capstyle(), 'butt') + + +@pytest.mark.style('default') +def test_joinstyle(): + col = mcollections.PathCollection([], joinstyle='round') + assert_equal(col.get_joinstyle(), 'round') + col.set_joinstyle('miter') + assert_equal(col.get_joinstyle(), 'miter') + + +@image_comparison(baseline_images=['cap_and_joinstyle'], + extensions=['png']) +def test_cap_and_joinstyle_image(): + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + ax.set_xlim([-0.5, 1.5]) + ax.set_ylim([-0.5, 2.5]) + + x = np.array([0.0, 1.0, 0.5]) + ys = np.array([[0.0], [0.5], [1.0]]) + np.array([[0.0, 0.0, 1.0]]) + + segs = np.zeros((3, 3, 2)) + segs[:, :, 0] = x + segs[:, :, 1] = ys + line_segments = LineCollection(segs, linewidth=[10, 15, 20]) + line_segments.set_capstyle("round") + line_segments.set_joinstyle("miter") + + ax.add_collection(line_segments) + ax.set_title('Line collection with customized caps and joinstyle') + + @image_comparison(baseline_images=['scatter_post_alpha'], extensions=['png'], remove_text=True, style='default')