diff --git a/lib/matplotlib/quiver.py b/lib/matplotlib/quiver.py index 9b63e7734198..e062ce1c817f 100644 --- a/lib/matplotlib/quiver.py +++ b/lib/matplotlib/quiver.py @@ -381,27 +381,49 @@ def contains(self, mouseevent): # This is a helper function that parses out the various combination of # arguments for doing colored vector plots. Pulling it out here # allows both Quiver and Barbs to use it -def _parse_args(*args): +def _parse_args(*args, **kw): X, Y, U, V, C = [None] * 5 - args = list(args) - - # The use of atleast_1d allows for handling scalar arguments while also - # keeping masked arrays - if len(args) == 3 or len(args) == 5: - C = np.atleast_1d(args.pop(-1)) - V = np.atleast_1d(args.pop(-1)) - U = np.atleast_1d(args.pop(-1)) - if U.ndim == 1: - nr, nc = 1, U.shape[0] - else: - nr, nc = U.shape - if len(args) == 2: # remaining after removing U,V,C - X, Y = [np.array(a).ravel() for a in args] - if len(X) == nc and len(Y) == nr: - X, Y = [a.ravel() for a in np.meshgrid(X, Y)] - else: - indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) - X, Y = [np.ravel(a) for a in indexgrid] + args = list(args) # The use of atleast_1d allows for handling + if len(args) != 0: # scalar arguments while also + if len(args) == 3 or len(args) == 5: # keeping masked arrays + C = np.atleast_1d(args.pop(-1)) + V = np.atleast_1d(args.pop(-1)) + U = np.atleast_1d(args.pop(-1)) + if U.ndim == 1: + nr, nc = 1, U.shape[0] + else: + nr, nc = U.shape + if len(args) == 2: # remaining after removing U,V,C + X, Y = [np.array(a).ravel() for a in args] + if len(X) == nc and len(Y) == nr: + X, Y = [a.ravel() for a in np.meshgrid(X, Y)] + else: + indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) + X, Y = [np.ravel(a) for a in indexgrid] + if len(kw) != 0: # some keyword arguments + # The use of atleast_1d allows for handling scalar arguments while also + # keeping masked arrays + if len(kw) == 3 or len(kw) == 5: + if (kw.get('C') is not None): + C = np.atleast_1d(kw.pop('C')) + if (kw.get('V') is not None): + V = np.atleast_1d(kw.pop('V')) + if (kw.get('U') is not None): + U = np.atleast_1d(kw.pop('U')) + if U.ndim == 1: + nr, nc = 1, U.shape[0] + else: + nr, nc = U.shape + if len(kw) == 2: # remaining after removing U,V,C. CASE 1 + if kw.get('X') is not None: + X = np.array(kw.get('X')).ravel() + if kw.get('Y') is not None: + Y = np.array(kw.get('Y')).ravel() + if len(X) == nc and len(Y) == nr: + X, Y = [a.ravel() for a in np.meshgrid(X, Y)] + elif len(kw) == 0: # CASE 2 + indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) + X, Y = [np.ravel(a) for a in indexgrid] return X, Y, U, V, C @@ -443,7 +465,17 @@ def __init__(self, ax, *args, %s """ self.ax = ax - X, Y, U, V, C = _parse_args(*args) + X, Y, U, V, C = _parse_args(*args, **kw) + if kw.get('U') is not None: # Resetting **kw to the + kw.pop('U') # way it was without these + if kw.get('V') is not None: + kw.pop('V') + if kw.get('X') is not None: + kw.pop('X') + if kw.get('Y') is not None: + kw.pop('Y') + if kw.get('C') is not None: + kw.pop('C') self.X = X self.Y = Y self.XY = np.column_stack((X, Y)) @@ -950,7 +982,17 @@ def __init__(self, ax, *args, kw['linewidth'] = 1 # Parse out the data arrays from the various configurations supported - x, y, u, v, c = _parse_args(*args) + x, y, u, v, c = _parse_args(*args, **kw) + if kw.get('U') is not None: # Resetting **kw to the way + kw.pop('U') # it was without these + if kw.get('V') is not None: + kw.pop('V') + if kw.get('X') is not None: + kw.pop('X') + if kw.get('Y') is not None: + kw.pop('Y') + if kw.get('C') is not None: + kw.pop('C') self.x = x self.y = y xy = np.column_stack((x, y)) diff --git a/lib/matplotlib/tests/test_quiver.py b/lib/matplotlib/tests/test_quiver.py index 4470e02fac8c..eaaa7857df47 100644 --- a/lib/matplotlib/tests/test_quiver.py +++ b/lib/matplotlib/tests/test_quiver.py @@ -186,6 +186,26 @@ def test_quiver_xy(): ax.grid() +def test_quiver_keyword_arguments(): + ax = plt.axes() + x, y = np.arange(8), np.arange(10) + u = v = np.linspace(0, 10, 80).reshape(10, 8) + q = plt.quiver(X=x, Y=y, U=u, V=v) + assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + + ax = plt.axes() + x, y = np.arange(8), np.arange(10) + u = v = np.linspace(0, 10, 80).reshape(10, 8) + q = plt.quiver(x, y, U=u, V=v) + assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + + ax = plt.axes() + x, y = np.arange(8), np.arange(10) + u = v = np.linspace(0, 10, 80).reshape(10, 8) + q = plt.quiver(X=x, Y=y, U=u, V=v, C=(1, 1, 1)) + assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + + def test_quiverkey_angles(): # Check that only a single arrow is plotted for a quiverkey when an array # of angles is given to the original quiver plot