9
9
10
10
11
11
"""
12
- from pylab import *
12
+ import matplotlib .pyplot as plt
13
+ import numpy as np
13
14
14
- rates_to_bases = {'r1' : 'AT' , 'r2' : 'TA' , 'r3' : 'GA' , 'r4' : 'AG' , 'r5' : 'CA' , 'r6' : 'AC' ,
15
- 'r7' : 'GT' , 'r8' : 'TG' , 'r9' : 'CT' , 'r10' : 'TC' , 'r11' : 'GC' , 'r12' : 'CG' }
15
+ rates_to_bases = {'r1' : 'AT' , 'r2' : 'TA' , 'r3' : 'GA' , 'r4' : 'AG' , 'r5' : 'CA' ,
16
+ 'r6' : 'AC' , 'r7' : 'GT' , 'r8' : 'TG' , 'r9' : 'CT' , 'r10' : 'TC' ,
17
+ 'r11' : 'GC' , 'r12' : 'CG' }
16
18
numbered_bases_to_rates = dict ([(v , k ) for k , v in rates_to_bases .items ()])
17
19
lettered_bases_to_rates = dict ([(v , 'r' + v ) for k , v in rates_to_bases .items ()])
18
20
@@ -45,17 +47,17 @@ def make_arrow_plot(data, size=4, display='length', shape='right',
45
47
linewidth and edgecolor.
46
48
"""
47
49
48
- xlim (- 0.5 , 1.5 )
49
- ylim (- 0.5 , 1.5 )
50
- gcf ().set_size_inches (size , size )
51
- xticks ([])
52
- yticks ([])
50
+ plt . xlim (- 0.5 , 1.5 )
51
+ plt . ylim (- 0.5 , 1.5 )
52
+ plt . gcf ().set_size_inches (size , size )
53
+ plt . xticks ([])
54
+ plt . yticks ([])
53
55
max_text_size = size * 12
54
56
min_text_size = size
55
57
label_text_size = size * 2.5
56
58
text_params = {'ha' : 'center' , 'va' : 'center' , 'family' : 'sans-serif' ,
57
59
'fontweight' : 'bold' }
58
- r2 = sqrt (2 )
60
+ r2 = np . sqrt (2 )
59
61
60
62
deltas = {
61
63
'AT' : (1 , 0 ),
@@ -103,13 +105,13 @@ def make_arrow_plot(data, size=4, display='length', shape='right',
103
105
}
104
106
105
107
def do_fontsize (k ):
106
- return float (clip (max_text_size * sqrt (data [k ]),
108
+ return float (np . clip (max_text_size * np . sqrt (data [k ]),
107
109
min_text_size , max_text_size ))
108
110
109
- A = text (0 , 1 , '$A_3$' , color = 'r' , size = do_fontsize ('A' ), ** text_params )
110
- T = text (1 , 1 , '$T_3$' , color = 'k' , size = do_fontsize ('T' ), ** text_params )
111
- G = text (0 , 0 , '$G_3$' , color = 'g' , size = do_fontsize ('G' ), ** text_params )
112
- C = text (1 , 0 , '$C_3$' , color = 'b' , size = do_fontsize ('C' ), ** text_params )
111
+ A = plt . text (0 , 1 , '$A_3$' , color = 'r' , size = do_fontsize ('A' ), ** text_params )
112
+ T = plt . text (1 , 1 , '$T_3$' , color = 'k' , size = do_fontsize ('T' ), ** text_params )
113
+ G = plt . text (0 , 0 , '$G_3$' , color = 'g' , size = do_fontsize ('G' ), ** text_params )
114
+ C = plt . text (1 , 0 , '$C_3$' , color = 'b' , size = do_fontsize ('C' ), ** text_params )
113
115
114
116
arrow_h_offset = 0.25 # data coordinates, empirically determined
115
117
max_arrow_length = 1 - 2 * arrow_h_offset
@@ -119,7 +121,7 @@ def do_fontsize(k):
119
121
max_head_length = 2 * max_arrow_width
120
122
arrow_params = {'length_includes_head' : True , 'shape' : shape ,
121
123
'head_starts_at_zero' : head_starts_at_zero }
122
- ax = gca ()
124
+ ax = plt . gca ()
123
125
sf = 0.6 # max arrow size represents this in data coords
124
126
125
127
d = (r2 / 2 + arrow_h_offset - 0.5 )/ r2 # distance for diags
@@ -179,7 +181,7 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
179
181
180
182
x_scale , y_scale = deltas [pair ]
181
183
x_pos , y_pos = positions [pair ]
182
- arrow (x_pos , y_pos , x_scale * length , y_scale * length ,
184
+ plt . arrow (x_pos , y_pos , x_scale * length , y_scale * length ,
183
185
fc = fc , ec = ec , alpha = alpha , width = width , head_width = head_width ,
184
186
head_length = head_length , ** arrow_params )
185
187
@@ -192,24 +194,24 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
192
194
193
195
where = label_positions [pair ]
194
196
if where == 'left' :
195
- orig_position = 3 * array ([[max_arrow_width , max_arrow_width ]])
197
+ orig_position = 3 * np . array ([[max_arrow_width , max_arrow_width ]])
196
198
elif where == 'absolute' :
197
- orig_position = array ([[max_arrow_length / 2.0 , 3 * max_arrow_width ]])
199
+ orig_position = np . array ([[max_arrow_length / 2.0 , 3 * max_arrow_width ]])
198
200
elif where == 'right' :
199
- orig_position = array ([[length - 3 * max_arrow_width ,
201
+ orig_position = np . array ([[length - 3 * max_arrow_width ,
200
202
3 * max_arrow_width ]])
201
203
elif where == 'center' :
202
- orig_position = array ([[length / 2.0 , 3 * max_arrow_width ]])
204
+ orig_position = np . array ([[length / 2.0 , 3 * max_arrow_width ]])
203
205
else :
204
206
raise ValueError ("Got unknown position parameter %s" % where )
205
207
206
- M = array ([[cx , sx ], [- sx , cx ]])
207
- coords = dot (orig_position , M ) + [[x_pos , y_pos ]]
208
- x , y = ravel (coords )
208
+ M = np . array ([[cx , sx ], [- sx , cx ]])
209
+ coords = np . dot (orig_position , M ) + [[x_pos , y_pos ]]
210
+ x , y = np . ravel (coords )
209
211
orig_label = rate_labels [pair ]
210
212
label = '$%s_{_{\mathrm{%s}}}$' % (orig_label [0 ], orig_label [1 :])
211
213
212
- text (x , y , label , size = label_text_size , ha = 'center' , va = 'center' ,
214
+ plt . text (x , y , label , size = label_text_size , ha = 'center' , va = 'center' ,
213
215
color = labelcolor or fc )
214
216
215
217
for p in positions .keys ():
@@ -302,11 +304,11 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
302
304
display = 'length'
303
305
304
306
size = 4
305
- figure (figsize = (size , size ))
307
+ plt . figure (figsize = (size , size ))
306
308
307
309
make_arrow_plot (d , display = display , linewidth = 0.001 , edgecolor = None ,
308
310
normalize_data = scaled , head_starts_at_zero = True , size = size )
309
311
310
- draw ()
312
+ plt . draw ()
311
313
312
- show ()
314
+ plt . show ()
0 commit comments