Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b8d7e25

Browse files
committed
Merge pull request #4660 from domspad/MEP12-on-arrow_demo.py
MEP12-on-arrow_demo.py
2 parents d1d0122 + b26a521 commit b8d7e25

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

examples/pylab_examples/arrow_demo.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
1010
1111
"""
12-
from pylab import *
12+
import matplotlib.pyplot as plt
13+
import numpy as np
1314

14-
rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA', 'r6': 'AC',
15-
'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC', 'r11': 'GC', 'r12': 'CG'}
15+
rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA',
16+
'r6': 'AC', 'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC',
17+
'r11': 'GC', 'r12': 'CG'}
1618
numbered_bases_to_rates = dict([(v, k) for k, v in rates_to_bases.items()])
1719
lettered_bases_to_rates = dict([(v, 'r' + v) for k, v in rates_to_bases.items()])
1820

@@ -45,17 +47,17 @@ def make_arrow_plot(data, size=4, display='length', shape='right',
4547
linewidth and edgecolor.
4648
"""
4749

48-
xlim(-0.5, 1.5)
49-
ylim(-0.5, 1.5)
50-
gcf().set_size_inches(size, size)
51-
xticks([])
52-
yticks([])
50+
plt.xlim(-0.5, 1.5)
51+
plt.ylim(-0.5, 1.5)
52+
plt.gcf().set_size_inches(size, size)
53+
plt.xticks([])
54+
plt.yticks([])
5355
max_text_size = size*12
5456
min_text_size = size
5557
label_text_size = size*2.5
5658
text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif',
5759
'fontweight': 'bold'}
58-
r2 = sqrt(2)
60+
r2 = np.sqrt(2)
5961

6062
deltas = {
6163
'AT': (1, 0),
@@ -103,13 +105,13 @@ def make_arrow_plot(data, size=4, display='length', shape='right',
103105
}
104106

105107
def do_fontsize(k):
106-
return float(clip(max_text_size*sqrt(data[k]),
108+
return float(np.clip(max_text_size*np.sqrt(data[k]),
107109
min_text_size, max_text_size))
108110

109-
A = text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
110-
T = text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
111-
G = text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
112-
C = text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
111+
A = plt.text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
112+
T = plt.text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
113+
G = plt.text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
114+
C = plt.text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
113115

114116
arrow_h_offset = 0.25 # data coordinates, empirically determined
115117
max_arrow_length = 1 - 2*arrow_h_offset
@@ -119,7 +121,7 @@ def do_fontsize(k):
119121
max_head_length = 2*max_arrow_width
120122
arrow_params = {'length_includes_head': True, 'shape': shape,
121123
'head_starts_at_zero': head_starts_at_zero}
122-
ax = gca()
124+
ax = plt.gca()
123125
sf = 0.6 # max arrow size represents this in data coords
124126

125127
d = (r2/2 + arrow_h_offset - 0.5)/r2 # distance for diags
@@ -179,7 +181,7 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
179181

180182
x_scale, y_scale = deltas[pair]
181183
x_pos, y_pos = positions[pair]
182-
arrow(x_pos, y_pos, x_scale*length, y_scale*length,
184+
plt.arrow(x_pos, y_pos, x_scale*length, y_scale*length,
183185
fc=fc, ec=ec, alpha=alpha, width=width, head_width=head_width,
184186
head_length=head_length, **arrow_params)
185187

@@ -192,24 +194,24 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
192194

193195
where = label_positions[pair]
194196
if where == 'left':
195-
orig_position = 3*array([[max_arrow_width, max_arrow_width]])
197+
orig_position = 3*np.array([[max_arrow_width, max_arrow_width]])
196198
elif where == 'absolute':
197-
orig_position = array([[max_arrow_length/2.0, 3*max_arrow_width]])
199+
orig_position = np.array([[max_arrow_length/2.0, 3*max_arrow_width]])
198200
elif where == 'right':
199-
orig_position = array([[length - 3*max_arrow_width,
201+
orig_position = np.array([[length - 3*max_arrow_width,
200202
3*max_arrow_width]])
201203
elif where == 'center':
202-
orig_position = array([[length/2.0, 3*max_arrow_width]])
204+
orig_position = np.array([[length/2.0, 3*max_arrow_width]])
203205
else:
204206
raise ValueError("Got unknown position parameter %s" % where)
205207

206-
M = array([[cx, sx], [-sx, cx]])
207-
coords = dot(orig_position, M) + [[x_pos, y_pos]]
208-
x, y = ravel(coords)
208+
M = np.array([[cx, sx], [-sx, cx]])
209+
coords = np.dot(orig_position, M) + [[x_pos, y_pos]]
210+
x, y = np.ravel(coords)
209211
orig_label = rate_labels[pair]
210212
label = '$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])
211213

212-
text(x, y, label, size=label_text_size, ha='center', va='center',
214+
plt.text(x, y, label, size=label_text_size, ha='center', va='center',
213215
color=labelcolor or fc)
214216

215217
for p in positions.keys():
@@ -302,11 +304,11 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
302304
display = 'length'
303305

304306
size = 4
305-
figure(figsize=(size, size))
307+
plt.figure(figsize=(size, size))
306308

307309
make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
308310
normalize_data=scaled, head_starts_at_zero=True, size=size)
309311

310-
draw()
312+
plt.draw()
311313

312-
show()
314+
plt.show()

0 commit comments

Comments
 (0)