Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7ede153

Browse files
committed
Take better advantage of numpy in quiver
svn path=/trunk/matplotlib/; revision=3400
1 parent 9ce996d commit 7ede153

File tree

1 file changed

+41
-36
lines changed

1 file changed

+41
-36
lines changed

lib/matplotlib/quiver.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
If U and V are 2-D arrays but X and Y are 1-D, and if
3333
len(X) and len(Y) match the column and row dimensions
3434
of U, then X and Y will be expanded with meshgrid.
35+
U, V, C may be masked arrays, but masked X, Y are not
36+
supported at present.
3537
3638
Keyword arguments (default given first):
3739
@@ -298,8 +300,7 @@ def _parse_args(self, *args):
298300
X, Y, U, V, C = [None]*5
299301
args = list(args)
300302
if len(args) == 3 or len(args) == 5:
301-
C = npy.ravel(args.pop(-1))
302-
#print 'in parse_args, C:', C
303+
C = ma.asarray(args.pop(-1)).ravel()
303304
V = ma.asarray(args.pop(-1))
304305
U = ma.asarray(args.pop(-1))
305306
nn = npy.shape(U)
@@ -308,9 +309,9 @@ def _parse_args(self, *args):
308309
if len(nn) > 1:
309310
nr = nn[1]
310311
if len(args) == 2:
311-
X, Y = [npy.ravel(a) for a in args]
312+
X, Y = [npy.array(a).ravel() for a in args]
312313
if len(X) == nc and len(Y) == nr:
313-
X, Y = [npy.ravel(a) for a in npy.meshgrid(X, Y)]
314+
X, Y = [a.ravel() for a in npy.meshgrid(X, Y)]
314315
else:
315316
indexgrid = npy.meshgrid(npy.arange(nc), npy.arange(nr))
316317
X, Y = [npy.ravel(a) for a in indexgrid]
@@ -333,15 +334,20 @@ def draw(self, renderer):
333334
self._init()
334335
if self._new_UV:
335336
verts = self._make_verts(self.U, self.V)
336-
self.set_verts(verts)
337+
# Using nan internally here is the easiest
338+
# way to support masked inputs; it doesn't
339+
# require adding mask support to PolyCollection,
340+
# and it keeps all array dimensions (X, Y, U, V, C)
341+
# intact.
342+
self.set_verts(verts.filled(npy.nan))
337343
self._new_UV = False
338344
collections.PolyCollection.draw(self, renderer)
339345

340346
def set_UVC(self, U, V, C=None):
341-
self.U = ma.ravel(U)
342-
self.V = ma.ravel(V)
347+
self.U = U.ravel()
348+
self.V = V.ravel()
343349
if C is not None:
344-
self.set_array(npy.ravel(C))
350+
self.set_array(C.ravel())
345351
self._new_UV = True
346352

347353
def _set_transform(self):
@@ -371,75 +377,74 @@ def _set_transform(self):
371377
return trans
372378

373379
def _make_verts(self, U, V):
374-
uv = U+V*1j
375-
uv = npy.ravel(ma.filled(uv,npy.nan))
376-
a = npy.absolute(uv)
380+
uv = ma.asarray(U+V*1j)
381+
a = ma.absolute(uv)
377382
if self.scale is None:
378383
sn = max(10, math.sqrt(self.N))
379-
380-
# get valid values for average
381-
# (complicated by support for 3 array packages)
382-
a_valid_cond = ~npy.isnan(a)
383-
a_valid_idx = npy.nonzero(a_valid_cond)
384-
if isinstance(a_valid_idx,tuple):
385-
# numpy.nonzero returns tuple
386-
a_valid_idx = a_valid_idx[0]
387-
valid_a = npy.take(a,a_valid_idx)
388-
389-
scale = 1.8 * npy.average(valid_a) * sn # crude auto-scaling
390-
scale = scale/self.span
384+
scale = 1.8 * a.mean() * sn / self.span # crude auto-scaling
391385
self.scale = scale
392386
length = a/(self.scale*self.width)
393387
X, Y = self._h_arrows(length)
394-
xy = (X+Y*1j) * npy.exp(1j*npy.angle(uv[...,npy.newaxis]))*self.width
388+
# There seems to be a ma bug such that indexing
389+
# a masked array with one element converts it to
390+
# an ndarray.
391+
theta = npy.angle(ma.asarray(uv[..., npy.newaxis]).filled(0))
392+
xy = (X+Y*1j) * npy.exp(1j*theta)*self.width
395393
xy = xy[:,:,npy.newaxis]
396-
XY = npy.concatenate((xy.real, xy.imag), axis=2)
394+
XY = ma.concatenate((xy.real, xy.imag), axis=2)
397395
return XY
398396

399397

400398
def _h_arrows(self, length):
401399
""" length is in arrow width units """
400+
# It might be possible to streamline the code
401+
# and speed it up a bit by using complex (x,y)
402+
# instead of separate arrays; but any gain would be slight.
402403
minsh = self.minshaft * self.headlength
403404
N = len(length)
404-
length = npy.reshape(length, (N,1))
405+
length = length.reshape(N, 1)
406+
# x, y: normal horizontal arrow
405407
x = npy.array([0, -self.headaxislength,
406408
-self.headlength, 0], npy.float64)
407409
x = x + npy.array([0,1,1,1]) * length
408410
y = 0.5 * npy.array([1, 1, self.headwidth, 0], npy.float64)
409411
y = npy.repeat(y[npy.newaxis,:], N, axis=0)
412+
# x0, y0: arrow without shaft, for short vectors
410413
x0 = npy.array([0, minsh-self.headaxislength,
411414
minsh-self.headlength, minsh], npy.float64)
412415
y0 = 0.5 * npy.array([1, 1, self.headwidth, 0], npy.float64)
413416
ii = [0,1,2,3,2,1,0]
414-
X = npy.take(x, ii, 1)
415-
Y = npy.take(y, ii, 1)
417+
X = x.take(ii, 1)
418+
Y = y.take(ii, 1)
416419
Y[:, 3:] *= -1
417-
X0 = npy.take(x0, ii)
418-
Y0 = npy.take(y0, ii)
420+
X0 = x0.take(ii)
421+
Y0 = y0.take(ii)
419422
Y0[3:] *= -1
420423
shrink = length/minsh
421424
X0 = shrink * X0[npy.newaxis,:]
422425
Y0 = shrink * Y0[npy.newaxis,:]
423426
short = npy.repeat(length < minsh, 7, axis=1)
424427
#print 'short', length < minsh
425-
X = npy.where(short, X0, X)
426-
Y = npy.where(short, Y0, Y)
428+
# Now select X0, Y0 if short, otherwise X, Y
429+
X = ma.where(short, X0, X)
430+
Y = ma.where(short, Y0, Y)
427431
if self.pivot[:3] == 'mid':
428432
X -= 0.5 * X[:,3, npy.newaxis]
429433
elif self.pivot[:3] == 'tip':
430434
X = X - X[:,3, npy.newaxis] #numpy bug? using -= does not
431435
# work here unless we multiply
432436
# by a float first, as with 'mid'.
433437
tooshort = length < self.minlength
434-
if npy.any(tooshort):
438+
if tooshort.any():
439+
# Use a heptagonal dot:
435440
th = npy.arange(0,7,1, npy.float64) * (npy.pi/3.0)
436441
x1 = npy.cos(th) * self.minlength * 0.5
437442
y1 = npy.sin(th) * self.minlength * 0.5
438443
X1 = npy.repeat(x1[npy.newaxis, :], N, axis=0)
439444
Y1 = npy.repeat(y1[npy.newaxis, :], N, axis=0)
440-
tooshort = npy.repeat(tooshort, 7, 1)
441-
X = npy.where(tooshort, X1, X)
442-
Y = npy.where(tooshort, Y1, Y)
445+
tooshort = ma.repeat(tooshort, 7, 1)
446+
X = ma.where(tooshort, X1, X)
447+
Y = ma.where(tooshort, Y1, Y)
443448
return X, Y
444449

445450
quiver_doc = _quiver_doc

0 commit comments

Comments
 (0)