|
| 1 | +#!/usr/bin/env python |
| 2 | +# Time-stamp: <2010-02-10 01:49:08 ycopin> |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import matplotlib.patches as mpatches |
| 7 | +from matplotlib.path import Path |
| 8 | + |
| 9 | +def sankey(ax, losses, labels=None, |
| 10 | + dx=40, dy=10, angle=45, w=3, dip=10, offset=2, **kwargs): |
| 11 | + """Draw a Sankey diagram. |
| 12 | +
|
| 13 | + losses: array of losses, should sum up to 100% |
| 14 | + labels: loss labels (same length as losses), |
| 15 | + or None (use default labels) or '' (no labels) |
| 16 | + dx: horizontal elongation |
| 17 | + dy: vertical elongation |
| 18 | + angle: arrow angle [deg] |
| 19 | + w: arrow shoulder |
| 20 | + dip: input dip |
| 21 | + offset: text offset |
| 22 | + **kwargs: propagated to Patch (e.g. fill=False) |
| 23 | +
|
| 24 | + Return (patch,texts).""" |
| 25 | + |
| 26 | + assert sum(losses)==100, "Input losses don't sum up to 100%" |
| 27 | + |
| 28 | + def add_loss(loss, last=False): |
| 29 | + h = (loss/2+w)*np.tan(angle/180.*np.pi) # Arrow tip height |
| 30 | + move,(x,y) = path[-1] # Use last point as reference |
| 31 | + if last: # Final loss (horizontal) |
| 32 | + path.extend([(Path.LINETO,[x+dx,y]), |
| 33 | + (Path.LINETO,[x+dx,y+w]), |
| 34 | + (Path.LINETO,[x+dx+h,y-loss/2]), # Tip |
| 35 | + (Path.LINETO,[x+dx,y-loss-w]), |
| 36 | + (Path.LINETO,[x+dx,y-loss])]) |
| 37 | + tips.append(path[-3][1]) |
| 38 | + else: # Intermediate loss (vertical) |
| 39 | + path.extend([(Path.LINETO,[x+dx/2,y]), |
| 40 | + (Path.CURVE3,[x+dx,y]), |
| 41 | + (Path.CURVE3,[x+dx,y+dy]), |
| 42 | + (Path.LINETO,[x+dx-w,y+dy]), |
| 43 | + (Path.LINETO,[x+dx+loss/2,y+dy+h]), # Tip |
| 44 | + (Path.LINETO,[x+dx+loss+w,y+dy]), |
| 45 | + (Path.LINETO,[x+dx+loss,y+dy]), |
| 46 | + (Path.CURVE3,[x+dx+loss,y-loss]), |
| 47 | + (Path.CURVE3,[x+dx/2+loss,y-loss])]) |
| 48 | + tips.append(path[-5][1]) |
| 49 | + |
| 50 | + tips = [] # Arrow tip positions |
| 51 | + path = [(Path.MOVETO,[0,100])] # 1st point |
| 52 | + for i,loss in enumerate(losses): |
| 53 | + add_loss(loss, last=(i==(len(losses)-1))) |
| 54 | + path.extend([(Path.LINETO,[0,0]), |
| 55 | + (Path.LINETO,[dip,50]), # Dip |
| 56 | + (Path.CLOSEPOLY,[0,100])]) |
| 57 | + codes,verts = zip(*path) |
| 58 | + verts = np.array(verts) |
| 59 | + |
| 60 | + # Path patch |
| 61 | + path = Path(verts,codes) |
| 62 | + patch = mpatches.PathPatch(path, **kwargs) |
| 63 | + ax.add_patch(patch) |
| 64 | + |
| 65 | + # Labels |
| 66 | + if labels=='': # No labels |
| 67 | + pass |
| 68 | + elif labels is None: # Default labels |
| 69 | + labels = [ '%2d%%' % loss for loss in losses ] |
| 70 | + else: |
| 71 | + assert len(labels)==len(losses) |
| 72 | + |
| 73 | + texts = [] |
| 74 | + for i,label in enumerate(labels): |
| 75 | + x,y = tips[i] # Label position |
| 76 | + last = (i==(len(losses)-1)) |
| 77 | + if last: |
| 78 | + t = ax.text(x+offset,y,label, ha='left', va='center') |
| 79 | + else: |
| 80 | + t = ax.text(x,y+offset,label, ha='center', va='bottom') |
| 81 | + texts.append(t) |
| 82 | + |
| 83 | + # Axes management |
| 84 | + ax.set_xlim(verts[:,0].min()-10, verts[:,0].max()+40) |
| 85 | + ax.set_ylim(verts[:,1].min()-10, verts[:,1].max()+20) |
| 86 | + ax.set_aspect('equal', adjustable='datalim') |
| 87 | + ax.set_xticks([]) |
| 88 | + ax.set_yticks([]) |
| 89 | + |
| 90 | + return patch,texts |
| 91 | + |
| 92 | +if __name__=='__main__': |
| 93 | + |
| 94 | + losses = [10.,20.,5.,15.,10.,40.] |
| 95 | + labels = ['First','Second','Third','Fourth','Fifth','Hurray!'] |
| 96 | + labels = [ s+'\n%d%%' % l for l,s in zip(losses,labels) ] |
| 97 | + |
| 98 | + fig = plt.figure() |
| 99 | + ax = fig.add_subplot(1,1,1) |
| 100 | + |
| 101 | + patch,texts = sankey(ax, losses, labels, fc='g', alpha=0.2) |
| 102 | + texts[1].set_color('r') |
| 103 | + texts[-1].set_fontweight('bold') |
| 104 | + |
| 105 | + plt.show() |
0 commit comments