1+ """
2+ """
3+
14import scipy .interpolate as si
25import numpy as np
36from functools import reduce
47
5- # uncomment this to set the backend
6- # import matplotlib
7- # matplotlib.use('Qt4Agg')
88import matplotlib .pyplot as plt
99
1010
@@ -13,73 +13,107 @@ class TooFewPointsException(Exception):
1313
1414
1515class SplineFitter :
16- def __init__ (self , ax , pix_err = 1 ):
17- self .canvas = ax .get_figure ().canvas
18- self .cid = None
19- self .pt_lst = []
20- self .pt_plot = ax .plot ([], [], marker = 'o' ,
21- linestyle = 'none' , zorder = 5 )[0 ]
22- self .sp_plot = ax .plot ([], [], lw = 3 , color = 'r' )[0 ]
23- self .pix_err = pix_err
24- self .connect_sf ()
25-
26- def clear (self ):
27- '''Clears the points'''
28- self .pt_lst = []
29- self .redraw ()
16+ def click_event (self , event ):
17+ '''Extracts locations from the user interaction
3018
31- def connect_sf (self ):
32- if self .cid is None :
33- self .cid = self .canvas .mpl_connect ('button_press_event' ,
34- self .click_event )
19+ Intended to be subscribed to 'button_press_event'
3520
36- def disconnect_sf ( self ):
37- if self . cid is not None :
38- self . canvas . mpl_disconnect ( self . cid )
39- self . cid = None
21+ Parameters
22+ ----------
23+ event : MouseEvent
24+ The
4025
41- def click_event (self , event ):
42- ''' Extracts locations from the user'''
26+ '''
27+ # stash the last event for debugging!
28+ self .ev = event
29+ # if shift is down, clear and bail
4330 if event .key == 'shift' :
4431 self .clear ()
4532 return
33+ # if no x or y data, bail
4634 if event .xdata is None or event .ydata is None :
4735 return
36+ # if not in our Axes, bail
37+ if event .inaxes is not self .ax :
38+ return
39+ # if left-click, append to points list
4840 if event .button == 1 :
4941 self .pt_lst .append ((event .xdata , event .ydata ))
42+ # if right-click, remove the closest point
5043 elif event .button == 3 :
5144 self .remove_pt ((event .xdata , event .ydata ))
52- self .ev = event
45+
46+ # re-draw (if needed)
5347 self .redraw ()
5448
5549 def remove_pt (self , loc ):
50+ """Remove the nearest point.
51+
52+ Parameters
53+ ----------
54+ loc : Tuple[float, float]
55+ The x, y location of the cilck
56+ """
5657 if len (self .pt_lst ) > 0 :
5758 self .pt_lst .pop (np .argmin (list (map (lambda x :
5859 np .sqrt ((x [0 ] - loc [0 ]) ** 2 +
5960 (x [1 ] - loc [1 ]) ** 2 ),
6061 self .pt_lst ))))
6162
6263 def redraw (self ):
64+ """Redraw the canvas given the current set of points
65+ """
66+ # get the current selected points
67+ if len (self .pt_lst ) > 0 :
68+ x , y = zip (* self .pt_lst )
69+ else :
70+ x , y = [], []
71+ # and update the Line2D with the
72+ self .pt_plot .set_xdata (x )
73+ self .pt_plot .set_ydata (y )
74+
75+ # if we have more than 5 points, create a best-fit closed spline
6376 if len (self .pt_lst ) > 5 :
6477 SC = SplineCurve .from_pts (self .pt_lst , pix_err = self .pix_err )
6578 new_pts = SC .q_phi_to_xy (0 , np .linspace (0 , 2 * np .pi , 1000 ))
6679 center = SC .cntr
67- self .sp_plot .set_xdata (new_pts [0 ])
68- self .sp_plot .set_ydata (new_pts [1 ])
6980 self .pt_lst .sort (key = lambda x :
7081 np .arctan2 (x [1 ] - center [1 ], x [0 ] - center [0 ]))
7182 else :
72- self .sp_plot .set_xdata ([])
73- self .sp_plot .set_ydata ([])
74- if len (self .pt_lst ) > 0 :
75- x , y = zip (* self .pt_lst )
76- else :
77- x , y = [], []
78- self .pt_plot .set_xdata (x )
79- self .pt_plot .set_ydata (y )
83+ new_pts = ([], [])
84+
85+ # and update the data in the spline Line2D objcet
86+ self .sp_plot .set_xdata (new_pts [0 ])
87+ self .sp_plot .set_ydata (new_pts [1 ])
8088
8189 self .canvas .draw_idle ()
8290
91+ def __init__ (self , ax , pix_err = 1 ):
92+ self .canvas = ax .get_figure ().canvas
93+ self .ax = ax
94+ self .cid = None
95+ self .pt_lst = []
96+ self .pt_plot = ax .plot ([], [], marker = 'o' ,
97+ linestyle = 'none' , zorder = 5 )[0 ]
98+ self .sp_plot = ax .plot ([], [], lw = 3 , color = 'r' )[0 ]
99+ self .pix_err = pix_err
100+ self .connect_sf ()
101+
102+ def clear (self ):
103+ '''Clears the points'''
104+ self .pt_lst = []
105+ self .redraw ()
106+
107+ def connect_sf (self ):
108+ if self .cid is None :
109+ self .cid = self .canvas .mpl_connect ('button_press_event' ,
110+ self .click_event )
111+
112+ def disconnect_sf (self ):
113+ if self .cid is not None :
114+ self .canvas .mpl_disconnect (self .cid )
115+ self .cid = None
116+
83117 @property
84118 def points (self ):
85119 '''Returns the clicked points in the format the rest of the
@@ -269,7 +303,7 @@ def q_phi_to_xy(self, q, phi, cross=None):
269303
270304
271305fig , ax = plt .subplots ()
272- ax .set_title ('left-click to add points, right-click to remove' )
306+ ax .set_title ('left-click to add points, right-click to remove, shift-click to clear ' )
273307sp = SplineFitter (ax , .001 )
274308plt .show ()
275309
0 commit comments