77
88import numpy as np
99
10+
1011def sankey (ax ,
1112 outputs = [100. ], outlabels = None ,
1213 inputs = [100. ], inlabels = '' ,
1314 dx = 40 , dy = 10 , outangle = 45 , w = 3 , inangle = 30 , offset = 2 , ** kwargs ):
1415 """Draw a Sankey diagram.
1516
16- outputs: array of outputs, should sum up to 100%
17- outlabels: output labels (same length as outputs),
18- or None (use default labels) or '' (no labels)
19- inputs and inlabels: similar for inputs
20- dx: horizontal elongation
21- dy: vertical elongation
22- outangle: output arrow angle [deg]
23- w: output arrow shoulder
24- inangle: input dip angle
25- offset: text offset
26- **kwargs: propagated to Patch (e.g. fill=False)
27-
28- Return (patch,[intexts,outtexts])."""
29-
17+ outputs: array of outputs, should sum up to 100%
18+ outlabels: output labels (same length as outputs),
19+ or None (use default labels) or '' (no labels)
20+ inputs and inlabels: similar for inputs
21+ dx: horizontal elongation
22+ dy: vertical elongation
23+ outangle: output arrow angle [deg]
24+ w: output arrow shoulder
25+ inangle: input dip angle
26+ offset: text offset
27+ **kwargs: propagated to Patch (e.g. fill=False)
28+
29+ Return (patch,[intexts,outtexts]).
30+ """
3031 import matplotlib .patches as mpatches
3132 from matplotlib .path import Path
3233
3334 outs = np .absolute (outputs )
3435 outsigns = np .sign (outputs )
35- outsigns [- 1 ] = 0 # Last output
36+ outsigns [- 1 ] = 0 # Last output
3637
3738 ins = np .absolute (inputs )
3839 insigns = np .sign (inputs )
39- insigns [0 ] = 0 # First input
40+ insigns [0 ] = 0 # First input
4041
41- assert sum (outs )== 100 , "Outputs don't sum up to 100%"
42- assert sum (ins )== 100 , "Inputs don't sum up to 100%"
42+ assert sum (outs ) == 100 , "Outputs don't sum up to 100%"
43+ assert sum (ins ) == 100 , "Inputs don't sum up to 100%"
4344
4445 def add_output (path , loss , sign = 1 ):
45- h = (loss / 2 + w )* np .tan (outangle / 180. * np .pi ) # Arrow tip height
46- move ,(x ,y ) = path [- 1 ] # Use last point as reference
47- if sign == 0 : # Final loss (horizontal)
48- path .extend ([(Path .LINETO ,[x + dx ,y ]),
49- (Path .LINETO ,[x + dx ,y + w ]),
50- (Path .LINETO ,[x + dx + h ,y - loss / 2 ]), # Tip
51- (Path .LINETO ,[x + dx ,y - loss - w ]),
52- (Path .LINETO ,[x + dx ,y - loss ])])
53- outtips .append ((sign ,path [- 3 ][1 ]))
54- else : # Intermediate loss (vertical)
55- path .extend ([(Path .CURVE4 ,[x + dx / 2 ,y ]),
56- (Path .CURVE4 ,[x + dx ,y ]),
57- (Path .CURVE4 ,[x + dx ,y + sign * dy ]),
58- (Path .LINETO ,[x + dx - w ,y + sign * dy ]),
59- (Path .LINETO ,[x + dx + loss / 2 ,y + sign * (dy + h )]), # Tip
60- (Path .LINETO ,[x + dx + loss + w ,y + sign * dy ]),
61- (Path .LINETO ,[x + dx + loss ,y + sign * dy ]),
62- (Path .CURVE3 ,[x + dx + loss ,y - sign * loss ]),
63- (Path .CURVE3 ,[x + dx / 2 + loss ,y - sign * loss ])])
64- outtips .append ((sign ,path [- 5 ][1 ]))
46+ h = (loss / 2 + w )* np .tan (outangle / 180. * np .pi ) # Arrow tip height
47+ move , (x , y ) = path [- 1 ] # Use last point as reference
48+ if sign == 0 : # Final loss (horizontal)
49+ path .extend ([(Path .LINETO , [x + dx , y ]),
50+ (Path .LINETO , [x + dx , y + w ]),
51+ (Path .LINETO , [x + dx + h , y - loss / 2 ]), # Tip
52+ (Path .LINETO , [x + dx , y - loss - w ]),
53+ (Path .LINETO , [x + dx , y - loss ])])
54+ outtips .append ((sign , path [- 3 ][1 ]))
55+ else : # Intermediate loss (vertical)
56+ path .extend ([(Path .CURVE4 , [x + dx / 2 , y ]),
57+ (Path .CURVE4 , [x + dx , y ]),
58+ (Path .CURVE4 , [x + dx , y + sign * dy ]),
59+ (Path .LINETO , [x + dx - w , y + sign * dy ]),
60+ (Path .LINETO , [x + dx + loss / 2 , y + sign * (dy + h )]), # Tip
61+ (Path .LINETO , [x + dx + loss + w , y + sign * dy ]),
62+ (Path .LINETO , [x + dx + loss , y + sign * dy ]),
63+ (Path .CURVE3 , [x + dx + loss , y - sign * loss ]),
64+ (Path .CURVE3 , [x + dx / 2 + loss , y - sign * loss ])])
65+ outtips .append ((sign , path [- 5 ][1 ]))
6566
6667 def add_input (path , gain , sign = 1 ):
67- h = (gain / 2 )* np .tan (inangle / 180. * np .pi ) # Dip depth
68- move ,(x ,y ) = path [- 1 ] # Use last point as reference
69- if sign == 0 : # First gain (horizontal)
70- path .extend ([(Path .LINETO ,[x - dx ,y ]),
71- (Path .LINETO ,[x - dx + h ,y + gain / 2 ]), # Dip
72- (Path .LINETO ,[x - dx ,y + gain ])])
73- xd ,yd = path [- 2 ][1 ] # Dip position
74- indips .append ((sign ,[xd - h ,yd ]))
75- else : # Intermediate gain (vertical)
76- path .extend ([(Path .CURVE4 ,[x - dx / 2 ,y ]),
77- (Path .CURVE4 ,[x - dx ,y ]),
78- (Path .CURVE4 ,[x - dx ,y + sign * dy ]),
79- (Path .LINETO ,[x - dx - gain / 2 ,y + sign * (dy - h )]), # Dip
80- (Path .LINETO ,[x - dx - gain ,y + sign * dy ]),
81- (Path .CURVE3 ,[x - dx - gain ,y - sign * gain ]),
82- (Path .CURVE3 ,[x - dx / 2 - gain ,y - sign * gain ])])
83- xd ,yd = path [- 4 ][1 ] # Dip position
84- indips .append ((sign ,[xd ,yd + sign * h ]))
85-
86- outtips = [] # Output arrow tip dir. and positions
87- urpath = [(Path .MOVETO ,[0 ,100 ])] # 1st point of upper right path
88- lrpath = [(Path .LINETO ,[0 ,0 ])] # 1st point of lower right path
89- for loss ,sign in zip (outs ,outsigns ):
68+ h = (gain / 2 )* np .tan (inangle / 180. * np .pi ) # Dip depth
69+ move , (x , y ) = path [- 1 ] # Use last point as reference
70+ if sign == 0 : # First gain (horizontal)
71+ path .extend ([(Path .LINETO , [x - dx , y ]),
72+ (Path .LINETO , [x - dx + h , y + gain / 2 ]), # Dip
73+ (Path .LINETO , [x - dx , y + gain ])])
74+ xd , yd = path [- 2 ][1 ] # Dip position
75+ indips .append ((sign , [xd - h , yd ]))
76+ else : # Intermediate gain (vertical)
77+ path .extend ([(Path .CURVE4 , [x - dx / 2 , y ]),
78+ (Path .CURVE4 , [x - dx , y ]),
79+ (Path .CURVE4 , [x - dx , y + sign * dy ]),
80+ (Path .LINETO , [x - dx - gain / 2 , y + sign * (dy - h )]), # Dip
81+ (Path .LINETO , [x - dx - gain , y + sign * dy ]),
82+ (Path .CURVE3 , [x - dx - gain , y - sign * gain ]),
83+ (Path .CURVE3 , [x - dx / 2 - gain , y - sign * gain ])])
84+ xd , yd = path [- 4 ][1 ] # Dip position
85+ indips .append ((sign , [xd , yd + sign * h ]))
86+
87+ outtips = [] # Output arrow tip dir. and positions
88+ urpath = [(Path .MOVETO , [0 , 100 ])] # 1st point of upper right path
89+ lrpath = [(Path .LINETO , [0 , 0 ])] # 1st point of lower right path
90+ for loss , sign in zip (outs , outsigns ):
9091 add_output (sign >= 0 and urpath or lrpath , loss , sign = sign )
9192
92- indips = [] # Input arrow tip dir. and positions
93- llpath = [(Path .LINETO ,[0 ,0 ])] # 1st point of lower left path
94- ulpath = [(Path .MOVETO ,[0 ,100 ])] # 1st point of upper left path
95- for gain ,sign in zip (ins ,insigns )[:: - 1 ] :
93+ indips = [] # Input arrow tip dir. and positions
94+ llpath = [(Path .LINETO , [0 , 0 ])] # 1st point of lower left path
95+ ulpath = [(Path .MOVETO , [0 , 100 ])] # 1st point of upper left path
96+ for gain , sign in reversed ( list ( zip (ins , insigns ))) :
9697 add_input (sign <= 0 and llpath or ulpath , gain , sign = sign )
9798
9899 def revert (path ):
99100 """A path is not just revertable by path[::-1] because of Bezier
100- curves."""
101+ curves."""
101102 rpath = []
102103 nextmove = Path .LINETO
103- for move ,pos in path [::- 1 ]:
104- rpath .append ((nextmove ,pos ))
104+ for move , pos in path [::- 1 ]:
105+ rpath .append ((nextmove , pos ))
105106 nextmove = move
106107 return rpath
107108
108109 # Concatenate subpathes in correct order
109110 path = urpath + revert (lrpath ) + llpath + revert (ulpath )
110111
111- codes ,verts = zip (* path )
112+ codes , verts = zip (* path )
112113 verts = np .array (verts )
113114
114115 # Path patch
115- path = Path (verts ,codes )
116+ path = Path (verts , codes )
116117 patch = mpatches .PathPatch (path , ** kwargs )
117118 ax .add_patch (patch )
118119
119- if False : # DEBUG
120+ if False : # DEBUG
120121 print ("urpath" , urpath )
121122 print ("lrpath" , revert (lrpath ))
122123 print ("llpath" , llpath )
123124 print ("ulpath" , revert (ulpath ))
124-
125- xs ,ys = zip (* verts )
126- ax .plot (xs ,ys ,'go-' )
125+ xs , ys = zip (* verts )
126+ ax .plot (xs , ys , 'go-' )
127127
128128 # Labels
129129
130- def set_labels (labels ,values ):
130+ def set_labels (labels , values ):
131131 """Set or check labels according to values."""
132- if labels == '' : # No labels
132+ if labels == '' : # No labels
133133 return labels
134- elif labels is None : # Default labels
135- return [ '%2d%%' % val for val in values ]
134+ elif labels is None : # Default labels
135+ return ['%2d%%' % val for val in values ]
136136 else :
137- assert len (labels )== len (values )
137+ assert len (labels ) == len (values )
138138 return labels
139139
140- def put_labels (labels ,positions ,output = True ):
140+ def put_labels (labels , positions , output = True ):
141141 """Put labels to positions."""
142142 texts = []
143143 lbls = output and labels or labels [::- 1 ]
144- for i ,label in enumerate (lbls ):
145- s ,(x ,y ) = positions [i ] # Label direction and position
146- if s == 0 :
147- t = ax .text (x + offset ,y , label ,
144+ for i , label in enumerate (lbls ):
145+ s , (x , y ) = positions [i ] # Label direction and position
146+ if s == 0 :
147+ t = ax .text (x + offset , y , label ,
148148 ha = output and 'left' or 'right' , va = 'center' )
149- elif s > 0 :
150- t = ax .text (x ,y + offset ,label , ha = 'center' , va = 'bottom' )
149+ elif s > 0 :
150+ t = ax .text (x , y + offset , label , ha = 'center' , va = 'bottom' )
151151 else :
152- t = ax .text (x ,y - offset ,label , ha = 'center' , va = 'top' )
152+ t = ax .text (x , y - offset , label , ha = 'center' , va = 'top' )
153153 texts .append (t )
154154 return texts
155155
@@ -160,32 +160,30 @@ def put_labels(labels,positions,output=True):
160160 intexts = put_labels (inlabels , indips , output = False )
161161
162162 # Axes management
163- ax .set_xlim (verts [:,0 ].min ()- dx , verts [:,0 ].max ()+ dx )
164- ax .set_ylim (verts [:,1 ].min ()- dy , verts [:,1 ].max ()+ dy )
163+ ax .set_xlim (verts [:, 0 ].min ()- dx , verts [:, 0 ].max ()+ dx )
164+ ax .set_ylim (verts [:, 1 ].min ()- dy , verts [:, 1 ].max ()+ dy )
165165 ax .set_aspect ('equal' , adjustable = 'datalim' )
166166
167- return patch ,[intexts ,outtexts ]
167+ return patch , [intexts , outtexts ]
168+
168169
169170if __name__ == '__main__' :
170171
171172 import matplotlib .pyplot as plt
172173
173- outputs = [10. ,- 20. ,5. ,15. ,- 10. ,40. ]
174- outlabels = ['First' ,'Second' ,'Third' ,'Fourth' ,'Fifth' ,'Hurray!' ]
175- outlabels = [ s + '\n %d%%' % abs (l ) for l ,s in zip (outputs ,outlabels ) ]
174+ outputs = [10. , - 20. , 5. , 15. , - 10. , 40. ]
175+ outlabels = ['First' , 'Second' , 'Third' , 'Fourth' , 'Fifth' , 'Hurray!' ]
176+ outlabels = [s + '\n %d%%' % abs (l ) for l , s in zip (outputs , outlabels )]
176177
177- inputs = [60. ,- 25. ,15. ]
178+ inputs = [60. , - 25. , 15. ]
178179
179180 fig = plt .figure ()
180- ax = fig .add_subplot (1 ,1 ,1 , xticks = [],yticks = [],
181- title = "Sankey diagram"
182- )
181+ ax = fig .add_subplot (1 , 1 , 1 , xticks = [], yticks = [], title = "Sankey diagram" )
183182
184- patch ,(intexts ,outtexts ) = sankey (ax , outputs = outputs , outlabels = outlabels ,
185- inputs = inputs , inlabels = None ,
186- fc = 'g' , alpha = 0.2 )
183+ patch , (intexts , outtexts ) = sankey (ax , outputs = outputs ,
184+ outlabels = outlabels , inputs = inputs ,
185+ inlabels = None , fc = 'g' , alpha = 0.2 )
187186 outtexts [1 ].set_color ('r' )
188187 outtexts [- 1 ].set_fontweight ('bold' )
189188
190189 plt .show ()
191-
0 commit comments