diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index bc7e4df568b4..6451a908baa6 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -2244,6 +2244,21 @@ def _reshape_2D(X): return X +def ensure_3d(arr): + """ + Return a version of arr with ndim==3, with extra dimensions added + at the end of arr.shape as needed. + """ + arr = np.asanyarray(arr) + if arr.ndim == 1: + arr = arr[:, None, None] + elif arr.ndim == 2: + arr = arr[:, :, None] + elif arr.ndim > 3 or arr.ndim < 1: + raise ValueError("cannot convert arr to 3-dimensional") + return arr + + def violin_stats(X, method, points=100): ''' Returns a list of dictionaries of data which can be used to draw a series diff --git a/lib/matplotlib/path.py b/lib/matplotlib/path.py index 2e0d86456167..ca80f4cbb6b7 100644 --- a/lib/matplotlib/path.py +++ b/lib/matplotlib/path.py @@ -24,7 +24,7 @@ from numpy import ma from matplotlib import _path -from matplotlib.cbook import simple_linear_interpolation, maxdict +from matplotlib.cbook import simple_linear_interpolation, maxdict, ensure_3d from matplotlib import rcParams @@ -988,7 +988,7 @@ def get_path_collection_extents( if len(paths) == 0: raise ValueError("No paths provided") return Bbox.from_extents(*_path.get_path_collection_extents( - master_transform, paths, np.atleast_3d(transforms), + master_transform, paths, ensure_3d(transforms), offsets, offset_transform)) diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 2b916b08566f..b244e761cf8f 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -376,3 +376,11 @@ def test_step_fails(): np.arange(12)) assert_raises(ValueError, cbook._step_validation, np.arange(12), np.arange(3)) + + +def test_ensure_3d(): + assert_array_equal([[[1]], [[2]], [[3]]], + cbook.ensure_3d([1, 2, 3])) + assert_array_equal([[[1], [2]], [[3], [4]]], + cbook.ensure_3d([[1, 2], [3, 4]])) + assert_raises(ValueError, cbook.ensure_3d, [[[[1]]]]) diff --git a/lib/matplotlib/transforms.py b/lib/matplotlib/transforms.py index 812047910e11..1eac27fc52ee 100644 --- a/lib/matplotlib/transforms.py +++ b/lib/matplotlib/transforms.py @@ -48,6 +48,7 @@ from sets import Set as set from .path import Path +from .cbook import ensure_3d DEBUG = False # we need this later, but this is very expensive to set up @@ -667,7 +668,7 @@ def count_overlaps(self, bboxes): bboxes is a sequence of :class:`BboxBase` objects """ return count_bboxes_overlapping_bbox( - self, np.atleast_3d([np.array(x) for x in bboxes])) + self, ensure_3d([np.array(x) for x in bboxes])) def expanded(self, sw, sh): """