Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 797d499

Browse files
NelleVdmcdougall
authored andcommitted
PEP8 fixes on streamplot.py
1 parent c54d158 commit 797d499

File tree

1 file changed

+39
-33
lines changed

1 file changed

+39
-33
lines changed

lib/matplotlib/streamplot.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,17 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
4949
Minimum length of streamline in axes coordinates.
5050
5151
Returns:
52-
52+
5353
*stream_container* : StreamplotSet
5454
Container object with attributes
55-
lines : `matplotlib.collections.LineCollection` of streamlines
56-
arrows : collection of `matplotlib.patches.FancyArrowPatch` objects
57-
repesenting arrows half-way along stream lines.
58-
This container will probably change in the future to allow changes to
59-
the colormap, alpha, etc. for both lines and arrows, but these changes
60-
should be backward compatible.
61-
55+
lines: `matplotlib.collections.LineCollection` of streamlines
56+
arrows: collection of `matplotlib.patches.FancyArrowPatch`
57+
objects representing arrows half-way along stream
58+
lines.
59+
This container will probably change in the future to allow changes
60+
to the colormap, alpha, etc. for both lines and arrows, but these
61+
changes should be backward compatible.
62+
6263
"""
6364
grid = Grid(x, y)
6465
mask = StreamMask(density)
@@ -71,7 +72,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
7172
linewidth = matplotlib.rcParams['lines.linewidth']
7273

7374
line_kw = {}
74-
arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10*arrowsize)
75+
arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize)
7576

7677
use_multicolor_lines = isinstance(color, np.ndarray)
7778
if use_multicolor_lines:
@@ -104,7 +105,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
104105
if mask[ym, xm] == 0:
105106
xg, yg = dmap.mask2grid(xm, ym)
106107
t = integrate(xg, yg)
107-
if t != None:
108+
if t is not None:
108109
trajectories.append(t)
109110

110111
if use_multicolor_lines:
@@ -128,10 +129,10 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
128129
streamlines.extend(np.hstack([points[:-1], points[1:]]))
129130

130131
# Add arrows half way along each trajectory.
131-
s = np.cumsum(np.sqrt(np.diff(tx)**2 + np.diff(ty)**2))
132+
s = np.cumsum(np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2))
132133
n = np.searchsorted(s, s[-1] / 2.)
133134
arrow_tail = (tx[n], ty[n])
134-
arrow_head = (np.mean(tx[n:n+2]), np.mean(ty[n:n+2]))
135+
arrow_head = (np.mean(tx[n:n + 2]), np.mean(ty[n:n + 2]))
135136

136137
if isinstance(linewidth, np.ndarray):
137138
line_widths = interpgrid(linewidth, tgx, tgy)[:-1]
@@ -143,15 +144,15 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
143144
line_colors.extend(color_values)
144145
arrow_kw['color'] = cmap(norm(color_values[n]))
145146

146-
p = patches.FancyArrowPatch(arrow_tail,
147-
arrow_head,
148-
transform=transform,
147+
p = patches.FancyArrowPatch(arrow_tail,
148+
arrow_head,
149+
transform=transform,
149150
**arrow_kw)
150151
axes.add_patch(p)
151152
arrows.append(p)
152153

153-
lc = mcollections.LineCollection(streamlines,
154-
transform=transform,
154+
lc = mcollections.LineCollection(streamlines,
155+
transform=transform,
155156
**line_kw)
156157
if use_multicolor_lines:
157158
lc.set_array(np.asarray(line_colors))
@@ -275,7 +276,7 @@ def within_grid(self, xi, yi):
275276
"""Return True if point is a valid index of grid."""
276277
# Note that xi/yi can be floats; so, for example, we can't simply check
277278
# `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx`
278-
return xi >= 0 and xi <= self.nx-1 and yi >= 0 and yi <= self.ny-1
279+
return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1
279280

280281

281282
class StreamMask(object):
@@ -330,6 +331,7 @@ def _update_trajectory(self, xm, ym):
330331
class InvalidIndexError(Exception):
331332
pass
332333

334+
333335
class TerminateTrajectory(Exception):
334336
pass
335337

@@ -345,7 +347,7 @@ def get_integrator(u, v, dmap, minlength):
345347
# speed (path length) will be in axes-coordinates
346348
u_ax = u / dmap.grid.nx
347349
v_ax = v / dmap.grid.ny
348-
speed = np.ma.sqrt(u_ax**2 + v_ax**2)
350+
speed = np.ma.sqrt(u_ax ** 2 + v_ax ** 2)
349351

350352
def forward_time(xi, yi):
351353
ds_dt = interpgrid(speed, xi, yi)
@@ -382,7 +384,7 @@ def integrate(x0, y0):
382384

383385
if stotal > minlength:
384386
return x_traj, y_traj
385-
else: # reject short trajectories
387+
else: # reject short trajectories
386388
dmap.undo_trajectory()
387389
return None
388390

@@ -423,7 +425,7 @@ def _integrate_rk12(x0, y0, dmap, f):
423425
## increment the location gradually. However, due to the efficient
424426
## nature of the interpolation, this doesn't boost speed by much
425427
## for quite a bit of complexity.
426-
maxds = min(1./dmap.mask.nx, 1./dmap.mask.ny, 0.1)
428+
maxds = min(1. / dmap.mask.nx, 1. / dmap.mask.ny, 0.1)
427429

428430
ds = maxds
429431
stotal = 0
@@ -455,7 +457,7 @@ def _integrate_rk12(x0, y0, dmap, f):
455457

456458
nx, ny = dmap.grid.shape
457459
# Error is normalized to the axes coordinates
458-
error = np.sqrt(((dx2-dx1)/nx)**2 + ((dy2-dy1)/ny)**2)
460+
error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2)
459461

460462
# Only save step if within error tolerance
461463
if error < maxerror:
@@ -473,7 +475,7 @@ def _integrate_rk12(x0, y0, dmap, f):
473475
if error == 0:
474476
ds = maxds
475477
else:
476-
ds = min(maxds, 0.85 * ds * (maxerror/error)**0.5)
478+
ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5)
477479

478480
return stotal, xf_traj, yf_traj
479481

@@ -497,8 +499,8 @@ def _euler_step(xf_traj, yf_traj, dmap, f):
497499
else:
498500
dsy = (ny - 1 - yi) / cy
499501
ds = min(dsx, dsy)
500-
xf_traj.append(xi + cx*ds)
501-
yf_traj.append(yi + cy*ds)
502+
xf_traj.append(xi + cx * ds)
503+
yf_traj.append(yi + cy * ds)
502504
return ds, xf_traj, yf_traj
503505

504506

@@ -519,10 +521,14 @@ def interpgrid(a, xi, yi):
519521
x = np.int(xi)
520522
y = np.int(yi)
521523
# conditional is faster than clipping for integers
522-
if x == (Nx - 2): xn = x
523-
else: xn = x + 1
524-
if y == (Ny - 2): yn = y
525-
else: yn = y + 1
524+
if x == (Nx - 2):
525+
xn = x
526+
else:
527+
xn = x + 1
528+
if y == (Ny - 2):
529+
yn = y
530+
else:
531+
yn = y + 1
526532

527533
a00 = a[y, x]
528534
a01 = a[y, xn]
@@ -563,20 +569,20 @@ def _gen_starting_points(shape):
563569
if direction == 'right':
564570
x += 1
565571
if x >= xlast:
566-
xlast -=1
572+
xlast -= 1
567573
direction = 'up'
568574
elif direction == 'up':
569575
y += 1
570576
if y >= ylast:
571-
ylast -=1
577+
ylast -= 1
572578
direction = 'left'
573579
elif direction == 'left':
574580
x -= 1
575581
if x <= xfirst:
576-
xfirst +=1
582+
xfirst += 1
577583
direction = 'down'
578584
elif direction == 'down':
579585
y -= 1
580586
if y <= yfirst:
581-
yfirst +=1
587+
yfirst += 1
582588
direction = 'right'

0 commit comments

Comments
 (0)