diff --git a/lib/matplotlib/quiver.py b/lib/matplotlib/quiver.py index 9b63e7734198..6f080e94c874 100644 --- a/lib/matplotlib/quiver.py +++ b/lib/matplotlib/quiver.py @@ -18,7 +18,6 @@ import weakref import numpy as np - from numpy import ma import matplotlib.collections as mcollections import matplotlib.transforms as transforms @@ -381,27 +380,60 @@ def contains(self, mouseevent): # This is a helper function that parses out the various combination of # arguments for doing colored vector plots. Pulling it out here # allows both Quiver and Barbs to use it -def _parse_args(*args): +def _parse_args(*args, **kw): + """ + >>> print(_parse_args(1, 1, 2, 2, (0, 0, 0))) + (array([1]), array([1]), array([2]), array([2]), array([0, 0, 0])) + >>> _parse_args(2, 2, 1) + (array([0]), array([0]), array([2]), array([2]), array([1])) + + >>> print(_parse_args(X = 1, Y = 1, U = 2, V = 2, C = (0, 0, 0))) + (array([0]), array([0]), array([2]), array([2]), array([0, 0, 0])) + >>> _parse_args(U = 2, V = 2, C = 1) + (array([0]), array([0]), array([2]), array([2]), array([1])) + + """ X, Y, U, V, C = [None] * 5 - args = list(args) - - # The use of atleast_1d allows for handling scalar arguments while also - # keeping masked arrays - if len(args) == 3 or len(args) == 5: - C = np.atleast_1d(args.pop(-1)) - V = np.atleast_1d(args.pop(-1)) - U = np.atleast_1d(args.pop(-1)) - if U.ndim == 1: - nr, nc = 1, U.shape[0] - else: - nr, nc = U.shape - if len(args) == 2: # remaining after removing U,V,C - X, Y = [np.array(a).ravel() for a in args] - if len(X) == nc and len(Y) == nr: - X, Y = [a.ravel() for a in np.meshgrid(X, Y)] - else: - indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) - X, Y = [np.ravel(a) for a in indexgrid] + if len(kw) == 0:#No keyword arguments. + args = list(args) + # The use of atleast_1d allows for handling scalar arguments while also + # keeping masked arrays + if len(args) == 3 or len(args) == 5: + C = np.atleast_1d(args.pop(-1)) + V = np.atleast_1d(args.pop(-1)) + U = np.atleast_1d(args.pop(-1)) + if U.ndim == 1: + nr, nc = 1, U.shape[0] + else: + nr, nc = U.shape + if len(args) == 2: # remaining after removing U,V,C CASE-1 + X, Y = [np.array(a).ravel() for a in args] + if len(X) == nc and len(Y) == nr: + X, Y = [a.ravel() for a in np.meshgrid(X, Y)] + else:#CASE 2 + indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) + X, Y = [np.ravel(a) for a in indexgrid] + + elif len(args) == 0:#only keyword arguments + # The use of atleast_1d allows for handling scalar arguments while also + # keeping masked arrays + C = np.atleast_1d(kw.get('C')) + V = np.atleast_1d(kw.get('V')) + U = np.atleast_1d(kw.get('U')) + if U.ndim == 1: + nr, nc = 1, U.shape[0] + else: + nr, nc = U.shape + + X = np.array(kw.get('X')).ravel() + Y = np.array(kw.get('Y')).ravel() + if (X == [None] and Y == [None]) or (X != [None] and Y != [None]):#CASE 2 of *args + indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) + X, Y = [np.ravel(a) for a in indexgrid] + else:#CASE 1 of *args + if len(X) == nc and len(Y) == nr: + X, Y = [a.ravel() for a in np.meshgrid(X, Y)] + return X, Y, U, V, C @@ -442,8 +474,9 @@ def __init__(self, ax, *args, by the following pyplot interface documentation: %s """ + self.ax = ax - X, Y, U, V, C = _parse_args(*args) + X, Y, U, V, C = _parse_args(*args, **kw) self.X = X self.Y = Y self.XY = np.column_stack((X, Y))