|
1 | 1 | #!/usr/bin/env python |
2 | | -# Time-stamp: <2010-02-10 01:49:08 ycopin> |
3 | 2 |
|
4 | | -import numpy as np |
5 | | -import matplotlib.pyplot as plt |
6 | | -import matplotlib.patches as mpatches |
7 | | -from matplotlib.path import Path |
| 3 | +__author__ = "Yannick Copin <[email protected]>" |
| 4 | +__version__ = "Time-stamp: <10/02/2010 16:49 [email protected]>" |
8 | 5 |
|
9 | | -def sankey(ax, losses, labels=None, |
10 | | - dx=40, dy=10, angle=45, w=3, dip=10, offset=2, **kwargs): |
| 6 | +import numpy as N |
| 7 | + |
| 8 | +def sankey(ax, |
| 9 | + outputs=[100.], outlabels=None, |
| 10 | + inputs=[100.], inlabels='', |
| 11 | + dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs): |
11 | 12 | """Draw a Sankey diagram. |
12 | 13 |
|
13 | | - losses: array of losses, should sum up to 100% |
14 | | - labels: loss labels (same length as losses), |
15 | | - or None (use default labels) or '' (no labels) |
| 14 | + outputs: array of outputs, should sum up to 100% |
| 15 | + outlabels: output labels (same length as outputs), |
| 16 | + or None (use default labels) or '' (no labels) |
| 17 | + inputs and inlabels: similar for inputs |
16 | 18 | dx: horizontal elongation |
17 | 19 | dy: vertical elongation |
18 | | - angle: arrow angle [deg] |
19 | | - w: arrow shoulder |
20 | | - dip: input dip |
| 20 | + outangle: output arrow angle [deg] |
| 21 | + w: output arrow shoulder |
| 22 | + inangle: input dip angle |
21 | 23 | offset: text offset |
22 | 24 | **kwargs: propagated to Patch (e.g. fill=False) |
23 | 25 |
|
24 | | - Return (patch,texts).""" |
| 26 | + Return (patch,[intexts,outtexts]).""" |
| 27 | + |
| 28 | + import matplotlib.patches as mpatches |
| 29 | + from matplotlib.path import Path |
| 30 | + |
| 31 | + outs = N.absolute(outputs) |
| 32 | + outsigns = N.sign(outputs) |
| 33 | + outsigns[-1] = 0 # Last output |
25 | 34 |
|
26 | | - assert sum(losses)==100, "Input losses don't sum up to 100%" |
| 35 | + ins = N.absolute(inputs) |
| 36 | + insigns = N.sign(inputs) |
| 37 | + insigns[0] = 0 # First input |
27 | 38 |
|
28 | | - def add_loss(loss, last=False): |
29 | | - h = (loss/2+w)*np.tan(angle/180.*np.pi) # Arrow tip height |
| 39 | + assert sum(outs)==100, "Outputs don't sum up to 100%" |
| 40 | + assert sum(ins)==100, "Inputs don't sum up to 100%" |
| 41 | + |
| 42 | + def add_output(path, loss, sign=1): |
| 43 | + h = (loss/2+w)*N.tan(outangle/180.*N.pi) # Arrow tip height |
30 | 44 | move,(x,y) = path[-1] # Use last point as reference |
31 | | - if last: # Final loss (horizontal) |
| 45 | + if sign==0: # Final loss (horizontal) |
32 | 46 | path.extend([(Path.LINETO,[x+dx,y]), |
33 | 47 | (Path.LINETO,[x+dx,y+w]), |
34 | 48 | (Path.LINETO,[x+dx+h,y-loss/2]), # Tip |
35 | 49 | (Path.LINETO,[x+dx,y-loss-w]), |
36 | 50 | (Path.LINETO,[x+dx,y-loss])]) |
37 | | - tips.append(path[-3][1]) |
| 51 | + outtips.append((sign,path[-3][1])) |
38 | 52 | else: # Intermediate loss (vertical) |
39 | | - path.extend([(Path.LINETO,[x+dx/2,y]), |
40 | | - (Path.CURVE3,[x+dx,y]), |
41 | | - (Path.CURVE3,[x+dx,y+dy]), |
42 | | - (Path.LINETO,[x+dx-w,y+dy]), |
43 | | - (Path.LINETO,[x+dx+loss/2,y+dy+h]), # Tip |
44 | | - (Path.LINETO,[x+dx+loss+w,y+dy]), |
45 | | - (Path.LINETO,[x+dx+loss,y+dy]), |
46 | | - (Path.CURVE3,[x+dx+loss,y-loss]), |
47 | | - (Path.CURVE3,[x+dx/2+loss,y-loss])]) |
48 | | - tips.append(path[-5][1]) |
49 | | - |
50 | | - tips = [] # Arrow tip positions |
51 | | - path = [(Path.MOVETO,[0,100])] # 1st point |
52 | | - for i,loss in enumerate(losses): |
53 | | - add_loss(loss, last=(i==(len(losses)-1))) |
54 | | - path.extend([(Path.LINETO,[0,0]), |
55 | | - (Path.LINETO,[dip,50]), # Dip |
56 | | - (Path.CLOSEPOLY,[0,100])]) |
| 53 | + path.extend([(Path.CURVE4,[x+dx/2,y]), |
| 54 | + (Path.CURVE4,[x+dx,y]), |
| 55 | + (Path.CURVE4,[x+dx,y+sign*dy]), |
| 56 | + (Path.LINETO,[x+dx-w,y+sign*dy]), |
| 57 | + (Path.LINETO,[x+dx+loss/2,y+sign*(dy+h)]), # Tip |
| 58 | + (Path.LINETO,[x+dx+loss+w,y+sign*dy]), |
| 59 | + (Path.LINETO,[x+dx+loss,y+sign*dy]), |
| 60 | + (Path.CURVE3,[x+dx+loss,y-sign*loss]), |
| 61 | + (Path.CURVE3,[x+dx/2+loss,y-sign*loss])]) |
| 62 | + outtips.append((sign,path[-5][1])) |
| 63 | + |
| 64 | + def add_input(path, gain, sign=1): |
| 65 | + h = (gain/2)*N.tan(inangle/180.*N.pi) # Dip depth |
| 66 | + move,(x,y) = path[-1] # Use last point as reference |
| 67 | + if sign==0: # First gain (horizontal) |
| 68 | + path.extend([(Path.LINETO,[x-dx,y]), |
| 69 | + (Path.LINETO,[x-dx+h,y+gain/2]), # Dip |
| 70 | + (Path.LINETO,[x-dx,y+gain])]) |
| 71 | + xd,yd = path[-2][1] # Dip position |
| 72 | + indips.append((sign,[xd-h,yd])) |
| 73 | + else: # Intermediate gain (vertical) |
| 74 | + path.extend([(Path.CURVE4,[x-dx/2,y]), |
| 75 | + (Path.CURVE4,[x-dx,y]), |
| 76 | + (Path.CURVE4,[x-dx,y+sign*dy]), |
| 77 | + (Path.LINETO,[x-dx-gain/2,y+sign*(dy-h)]), # Dip |
| 78 | + (Path.LINETO,[x-dx-gain,y+sign*dy]), |
| 79 | + (Path.CURVE3,[x-dx-gain,y-sign*gain]), |
| 80 | + (Path.CURVE3,[x-dx/2-gain,y-sign*gain])]) |
| 81 | + xd,yd = path[-4][1] # Dip position |
| 82 | + indips.append((sign,[xd,yd+sign*h])) |
| 83 | + |
| 84 | + outtips = [] # Output arrow tip dir. and positions |
| 85 | + urpath = [(Path.MOVETO,[0,100])] # 1st point of upper right path |
| 86 | + lrpath = [(Path.LINETO,[0,0])] # 1st point of lower right path |
| 87 | + for loss,sign in zip(outs,outsigns): |
| 88 | + add_output(sign>=0 and urpath or lrpath, loss, sign=sign) |
| 89 | + |
| 90 | + indips = [] # Input arrow tip dir. and positions |
| 91 | + llpath = [(Path.LINETO,[0,0])] # 1st point of lower left path |
| 92 | + ulpath = [(Path.MOVETO,[0,100])] # 1st point of upper left path |
| 93 | + for gain,sign in zip(ins,insigns)[::-1]: |
| 94 | + add_input(sign<=0 and llpath or ulpath, gain, sign=sign) |
| 95 | + |
| 96 | + def revert(path): |
| 97 | + """A path is not just revertable by path[::-1] because of Bezier |
| 98 | + curves.""" |
| 99 | + rpath = [] |
| 100 | + nextmove = Path.LINETO |
| 101 | + for move,pos in path[::-1]: |
| 102 | + rpath.append((nextmove,pos)) |
| 103 | + nextmove = move |
| 104 | + return rpath |
| 105 | + |
| 106 | + # Concatenate subpathes in correct order |
| 107 | + path = urpath + revert(lrpath) + llpath + revert(ulpath) |
| 108 | + |
57 | 109 | codes,verts = zip(*path) |
58 | | - verts = np.array(verts) |
| 110 | + verts = N.array(verts) |
59 | 111 |
|
60 | 112 | # Path patch |
61 | 113 | path = Path(verts,codes) |
62 | 114 | patch = mpatches.PathPatch(path, **kwargs) |
63 | 115 | ax.add_patch(patch) |
64 | 116 |
|
| 117 | + if False: # DEBUG |
| 118 | + print "urpath", urpath |
| 119 | + print "lrpath", revert(lrpath) |
| 120 | + print "llpath", llpath |
| 121 | + print "ulpath", revert(ulpath) |
| 122 | + |
| 123 | + xs,ys = zip(*verts) |
| 124 | + ax.plot(xs,ys,'go-') |
| 125 | + |
65 | 126 | # Labels |
66 | | - if labels=='': # No labels |
67 | | - pass |
68 | | - elif labels is None: # Default labels |
69 | | - labels = [ '%2d%%' % loss for loss in losses ] |
70 | | - else: |
71 | | - assert len(labels)==len(losses) |
72 | | - |
73 | | - texts = [] |
74 | | - for i,label in enumerate(labels): |
75 | | - x,y = tips[i] # Label position |
76 | | - last = (i==(len(losses)-1)) |
77 | | - if last: |
78 | | - t = ax.text(x+offset,y,label, ha='left', va='center') |
| 127 | + |
| 128 | + def set_labels(labels,values): |
| 129 | + """Set or check labels according to values.""" |
| 130 | + if labels=='': # No labels |
| 131 | + return labels |
| 132 | + elif labels is None: # Default labels |
| 133 | + return [ '%2d%%' % val for val in values ] |
79 | 134 | else: |
80 | | - t = ax.text(x,y+offset,label, ha='center', va='bottom') |
81 | | - texts.append(t) |
| 135 | + assert len(labels)==len(values) |
| 136 | + return labels |
| 137 | + |
| 138 | + def put_labels(labels,positions,output=True): |
| 139 | + """Put labels to positions.""" |
| 140 | + texts = [] |
| 141 | + lbls = output and labels or labels[::-1] |
| 142 | + for i,label in enumerate(lbls): |
| 143 | + s,(x,y) = positions[i] # Label direction and position |
| 144 | + if s==0: |
| 145 | + t = ax.text(x+offset,y,label, |
| 146 | + ha=output and 'left' or 'right', va='center') |
| 147 | + elif s>0: |
| 148 | + t = ax.text(x,y+offset,label, ha='center', va='bottom') |
| 149 | + else: |
| 150 | + t = ax.text(x,y-offset,label, ha='center', va='top') |
| 151 | + texts.append(t) |
| 152 | + return texts |
| 153 | + |
| 154 | + outlabels = set_labels(outlabels, outs) |
| 155 | + outtexts = put_labels(outlabels, outtips, output=True) |
| 156 | + |
| 157 | + inlabels = set_labels(inlabels, ins) |
| 158 | + intexts = put_labels(inlabels, indips, output=False) |
82 | 159 |
|
83 | 160 | # Axes management |
84 | | - ax.set_xlim(verts[:,0].min()-10, verts[:,0].max()+40) |
85 | | - ax.set_ylim(verts[:,1].min()-10, verts[:,1].max()+20) |
| 161 | + ax.set_xlim(verts[:,0].min()-dx, verts[:,0].max()+dx) |
| 162 | + ax.set_ylim(verts[:,1].min()-dy, verts[:,1].max()+dy) |
86 | 163 | ax.set_aspect('equal', adjustable='datalim') |
87 | | - ax.set_xticks([]) |
88 | | - ax.set_yticks([]) |
89 | 164 |
|
90 | | - return patch,texts |
| 165 | + return patch,[intexts,outtexts] |
91 | 166 |
|
92 | 167 | if __name__=='__main__': |
93 | 168 |
|
94 | | - losses = [10.,20.,5.,15.,10.,40.] |
95 | | - labels = ['First','Second','Third','Fourth','Fifth','Hurray!'] |
96 | | - labels = [ s+'\n%d%%' % l for l,s in zip(losses,labels) ] |
| 169 | + import matplotlib.pyplot as P |
| 170 | + |
| 171 | + outputs = [10.,-20.,5.,15.,-10.,40.] |
| 172 | + outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!'] |
| 173 | + outlabels = [ s+'\n%d%%' % abs(l) for l,s in zip(outputs,outlabels) ] |
| 174 | + |
| 175 | + inputs = [60.,-25.,15.] |
97 | 176 |
|
98 | | - fig = plt.figure() |
99 | | - ax = fig.add_subplot(1,1,1) |
| 177 | + fig = P.figure() |
| 178 | + ax = fig.add_subplot(1,1,1, xticks=[],yticks=[], |
| 179 | + title="Sankey diagram" |
| 180 | + ) |
100 | 181 |
|
101 | | - patch,texts = sankey(ax, losses, labels, fc='g', alpha=0.2) |
102 | | - texts[1].set_color('r') |
103 | | - texts[-1].set_fontweight('bold') |
| 182 | + patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels, |
| 183 | + inputs=inputs, inlabels=None, |
| 184 | + fc='g', alpha=0.2) |
| 185 | + outtexts[1].set_color('r') |
| 186 | + outtexts[-1].set_fontweight('bold') |
104 | 187 |
|
105 | | - plt.show() |
| 188 | + P.show() |
0 commit comments