Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 30daf2f

Browse files
committed
added Yannick Copin's updated sanke demo
svn path=/trunk/matplotlib/; revision=8124
1 parent 2286433 commit 30daf2f

1 file changed

Lines changed: 150 additions & 67 deletions

File tree

examples/api/sankey_demo.py

Lines changed: 150 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,188 @@
11
#!/usr/bin/env python
2-
# Time-stamp: <2010-02-10 01:49:08 ycopin>
32

4-
import numpy as np
5-
import matplotlib.pyplot as plt
6-
import matplotlib.patches as mpatches
7-
from matplotlib.path import Path
3+
__author__ = "Yannick Copin <[email protected]>"
4+
__version__ = "Time-stamp: <10/02/2010 16:49 [email protected]>"
85

9-
def sankey(ax, losses, labels=None,
10-
dx=40, dy=10, angle=45, w=3, dip=10, offset=2, **kwargs):
6+
import numpy as N
7+
8+
def sankey(ax,
9+
outputs=[100.], outlabels=None,
10+
inputs=[100.], inlabels='',
11+
dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
1112
"""Draw a Sankey diagram.
1213
13-
losses: array of losses, should sum up to 100%
14-
labels: loss labels (same length as losses),
15-
or None (use default labels) or '' (no labels)
14+
outputs: array of outputs, should sum up to 100%
15+
outlabels: output labels (same length as outputs),
16+
or None (use default labels) or '' (no labels)
17+
inputs and inlabels: similar for inputs
1618
dx: horizontal elongation
1719
dy: vertical elongation
18-
angle: arrow angle [deg]
19-
w: arrow shoulder
20-
dip: input dip
20+
outangle: output arrow angle [deg]
21+
w: output arrow shoulder
22+
inangle: input dip angle
2123
offset: text offset
2224
**kwargs: propagated to Patch (e.g. fill=False)
2325
24-
Return (patch,texts)."""
26+
Return (patch,[intexts,outtexts])."""
27+
28+
import matplotlib.patches as mpatches
29+
from matplotlib.path import Path
30+
31+
outs = N.absolute(outputs)
32+
outsigns = N.sign(outputs)
33+
outsigns[-1] = 0 # Last output
2534

26-
assert sum(losses)==100, "Input losses don't sum up to 100%"
35+
ins = N.absolute(inputs)
36+
insigns = N.sign(inputs)
37+
insigns[0] = 0 # First input
2738

28-
def add_loss(loss, last=False):
29-
h = (loss/2+w)*np.tan(angle/180.*np.pi) # Arrow tip height
39+
assert sum(outs)==100, "Outputs don't sum up to 100%"
40+
assert sum(ins)==100, "Inputs don't sum up to 100%"
41+
42+
def add_output(path, loss, sign=1):
43+
h = (loss/2+w)*N.tan(outangle/180.*N.pi) # Arrow tip height
3044
move,(x,y) = path[-1] # Use last point as reference
31-
if last: # Final loss (horizontal)
45+
if sign==0: # Final loss (horizontal)
3246
path.extend([(Path.LINETO,[x+dx,y]),
3347
(Path.LINETO,[x+dx,y+w]),
3448
(Path.LINETO,[x+dx+h,y-loss/2]), # Tip
3549
(Path.LINETO,[x+dx,y-loss-w]),
3650
(Path.LINETO,[x+dx,y-loss])])
37-
tips.append(path[-3][1])
51+
outtips.append((sign,path[-3][1]))
3852
else: # Intermediate loss (vertical)
39-
path.extend([(Path.LINETO,[x+dx/2,y]),
40-
(Path.CURVE3,[x+dx,y]),
41-
(Path.CURVE3,[x+dx,y+dy]),
42-
(Path.LINETO,[x+dx-w,y+dy]),
43-
(Path.LINETO,[x+dx+loss/2,y+dy+h]), # Tip
44-
(Path.LINETO,[x+dx+loss+w,y+dy]),
45-
(Path.LINETO,[x+dx+loss,y+dy]),
46-
(Path.CURVE3,[x+dx+loss,y-loss]),
47-
(Path.CURVE3,[x+dx/2+loss,y-loss])])
48-
tips.append(path[-5][1])
49-
50-
tips = [] # Arrow tip positions
51-
path = [(Path.MOVETO,[0,100])] # 1st point
52-
for i,loss in enumerate(losses):
53-
add_loss(loss, last=(i==(len(losses)-1)))
54-
path.extend([(Path.LINETO,[0,0]),
55-
(Path.LINETO,[dip,50]), # Dip
56-
(Path.CLOSEPOLY,[0,100])])
53+
path.extend([(Path.CURVE4,[x+dx/2,y]),
54+
(Path.CURVE4,[x+dx,y]),
55+
(Path.CURVE4,[x+dx,y+sign*dy]),
56+
(Path.LINETO,[x+dx-w,y+sign*dy]),
57+
(Path.LINETO,[x+dx+loss/2,y+sign*(dy+h)]), # Tip
58+
(Path.LINETO,[x+dx+loss+w,y+sign*dy]),
59+
(Path.LINETO,[x+dx+loss,y+sign*dy]),
60+
(Path.CURVE3,[x+dx+loss,y-sign*loss]),
61+
(Path.CURVE3,[x+dx/2+loss,y-sign*loss])])
62+
outtips.append((sign,path[-5][1]))
63+
64+
def add_input(path, gain, sign=1):
65+
h = (gain/2)*N.tan(inangle/180.*N.pi) # Dip depth
66+
move,(x,y) = path[-1] # Use last point as reference
67+
if sign==0: # First gain (horizontal)
68+
path.extend([(Path.LINETO,[x-dx,y]),
69+
(Path.LINETO,[x-dx+h,y+gain/2]), # Dip
70+
(Path.LINETO,[x-dx,y+gain])])
71+
xd,yd = path[-2][1] # Dip position
72+
indips.append((sign,[xd-h,yd]))
73+
else: # Intermediate gain (vertical)
74+
path.extend([(Path.CURVE4,[x-dx/2,y]),
75+
(Path.CURVE4,[x-dx,y]),
76+
(Path.CURVE4,[x-dx,y+sign*dy]),
77+
(Path.LINETO,[x-dx-gain/2,y+sign*(dy-h)]), # Dip
78+
(Path.LINETO,[x-dx-gain,y+sign*dy]),
79+
(Path.CURVE3,[x-dx-gain,y-sign*gain]),
80+
(Path.CURVE3,[x-dx/2-gain,y-sign*gain])])
81+
xd,yd = path[-4][1] # Dip position
82+
indips.append((sign,[xd,yd+sign*h]))
83+
84+
outtips = [] # Output arrow tip dir. and positions
85+
urpath = [(Path.MOVETO,[0,100])] # 1st point of upper right path
86+
lrpath = [(Path.LINETO,[0,0])] # 1st point of lower right path
87+
for loss,sign in zip(outs,outsigns):
88+
add_output(sign>=0 and urpath or lrpath, loss, sign=sign)
89+
90+
indips = [] # Input arrow tip dir. and positions
91+
llpath = [(Path.LINETO,[0,0])] # 1st point of lower left path
92+
ulpath = [(Path.MOVETO,[0,100])] # 1st point of upper left path
93+
for gain,sign in zip(ins,insigns)[::-1]:
94+
add_input(sign<=0 and llpath or ulpath, gain, sign=sign)
95+
96+
def revert(path):
97+
"""A path is not just revertable by path[::-1] because of Bezier
98+
curves."""
99+
rpath = []
100+
nextmove = Path.LINETO
101+
for move,pos in path[::-1]:
102+
rpath.append((nextmove,pos))
103+
nextmove = move
104+
return rpath
105+
106+
# Concatenate subpathes in correct order
107+
path = urpath + revert(lrpath) + llpath + revert(ulpath)
108+
57109
codes,verts = zip(*path)
58-
verts = np.array(verts)
110+
verts = N.array(verts)
59111

60112
# Path patch
61113
path = Path(verts,codes)
62114
patch = mpatches.PathPatch(path, **kwargs)
63115
ax.add_patch(patch)
64116

117+
if False: # DEBUG
118+
print "urpath", urpath
119+
print "lrpath", revert(lrpath)
120+
print "llpath", llpath
121+
print "ulpath", revert(ulpath)
122+
123+
xs,ys = zip(*verts)
124+
ax.plot(xs,ys,'go-')
125+
65126
# Labels
66-
if labels=='': # No labels
67-
pass
68-
elif labels is None: # Default labels
69-
labels = [ '%2d%%' % loss for loss in losses ]
70-
else:
71-
assert len(labels)==len(losses)
72-
73-
texts = []
74-
for i,label in enumerate(labels):
75-
x,y = tips[i] # Label position
76-
last = (i==(len(losses)-1))
77-
if last:
78-
t = ax.text(x+offset,y,label, ha='left', va='center')
127+
128+
def set_labels(labels,values):
129+
"""Set or check labels according to values."""
130+
if labels=='': # No labels
131+
return labels
132+
elif labels is None: # Default labels
133+
return [ '%2d%%' % val for val in values ]
79134
else:
80-
t = ax.text(x,y+offset,label, ha='center', va='bottom')
81-
texts.append(t)
135+
assert len(labels)==len(values)
136+
return labels
137+
138+
def put_labels(labels,positions,output=True):
139+
"""Put labels to positions."""
140+
texts = []
141+
lbls = output and labels or labels[::-1]
142+
for i,label in enumerate(lbls):
143+
s,(x,y) = positions[i] # Label direction and position
144+
if s==0:
145+
t = ax.text(x+offset,y,label,
146+
ha=output and 'left' or 'right', va='center')
147+
elif s>0:
148+
t = ax.text(x,y+offset,label, ha='center', va='bottom')
149+
else:
150+
t = ax.text(x,y-offset,label, ha='center', va='top')
151+
texts.append(t)
152+
return texts
153+
154+
outlabels = set_labels(outlabels, outs)
155+
outtexts = put_labels(outlabels, outtips, output=True)
156+
157+
inlabels = set_labels(inlabels, ins)
158+
intexts = put_labels(inlabels, indips, output=False)
82159

83160
# Axes management
84-
ax.set_xlim(verts[:,0].min()-10, verts[:,0].max()+40)
85-
ax.set_ylim(verts[:,1].min()-10, verts[:,1].max()+20)
161+
ax.set_xlim(verts[:,0].min()-dx, verts[:,0].max()+dx)
162+
ax.set_ylim(verts[:,1].min()-dy, verts[:,1].max()+dy)
86163
ax.set_aspect('equal', adjustable='datalim')
87-
ax.set_xticks([])
88-
ax.set_yticks([])
89164

90-
return patch,texts
165+
return patch,[intexts,outtexts]
91166

92167
if __name__=='__main__':
93168

94-
losses = [10.,20.,5.,15.,10.,40.]
95-
labels = ['First','Second','Third','Fourth','Fifth','Hurray!']
96-
labels = [ s+'\n%d%%' % l for l,s in zip(losses,labels) ]
169+
import matplotlib.pyplot as P
170+
171+
outputs = [10.,-20.,5.,15.,-10.,40.]
172+
outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!']
173+
outlabels = [ s+'\n%d%%' % abs(l) for l,s in zip(outputs,outlabels) ]
174+
175+
inputs = [60.,-25.,15.]
97176

98-
fig = plt.figure()
99-
ax = fig.add_subplot(1,1,1)
177+
fig = P.figure()
178+
ax = fig.add_subplot(1,1,1, xticks=[],yticks=[],
179+
title="Sankey diagram"
180+
)
100181

101-
patch,texts = sankey(ax, losses, labels, fc='g', alpha=0.2)
102-
texts[1].set_color('r')
103-
texts[-1].set_fontweight('bold')
182+
patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels,
183+
inputs=inputs, inlabels=None,
184+
fc='g', alpha=0.2)
185+
outtexts[1].set_color('r')
186+
outtexts[-1].set_fontweight('bold')
104187

105-
plt.show()
188+
P.show()

0 commit comments

Comments
 (0)