@@ -1002,43 +1002,24 @@ def __init__(self, ax, labels, actives=None):
1002
1002
if actives is None :
1003
1003
actives = [False ] * len (labels )
1004
1004
1005
- if len (labels ) > 1 :
1006
- dy = 1. / (len (labels ) + 1 )
1007
- ys = np .linspace (1 - dy , dy , len (labels ))
1008
- else :
1009
- dy = 0.25
1010
- ys = [0.5 ]
1011
-
1012
- axcolor = ax .get_facecolor ()
1013
-
1014
- self .labels = []
1015
- self .lines = []
1016
- self .rectangles = []
1017
-
1018
- lineparams = {'color' : 'k' , 'linewidth' : 1.25 ,
1019
- 'transform' : ax .transAxes , 'solid_capstyle' : 'butt' }
1020
- for y , label , active in zip (ys , labels , actives ):
1021
- t = ax .text (0.25 , y , label , transform = ax .transAxes ,
1022
- horizontalalignment = 'left' ,
1023
- verticalalignment = 'center' )
1024
-
1025
- w , h = dy / 2 , dy / 2
1026
- x , y = 0.05 , y - h / 2
1027
-
1028
- p = Rectangle (xy = (x , y ), width = w , height = h , edgecolor = 'black' ,
1029
- facecolor = axcolor , transform = ax .transAxes )
1005
+ ys = np .linspace (1 , 0 , len (labels )+ 2 )[1 :- 1 ]
1006
+ text_size = mpl .rcParams ["font.size" ] / 2
1030
1007
1031
- l1 = Line2D ([x , x + w ], [y + h , y ], ** lineparams )
1032
- l2 = Line2D ([x , x + w ], [y , y + h ], ** lineparams )
1008
+ self .labels = [
1009
+ ax .text (0.25 , y , label , transform = ax .transAxes ,
1010
+ horizontalalignment = "left" , verticalalignment = "center" )
1011
+ for y , label in zip (ys , labels )]
1033
1012
1034
- l1 .set_visible (active )
1035
- l2 .set_visible (active )
1036
- self .labels .append (t )
1037
- self .rectangles .append (p )
1038
- self .lines .append ((l1 , l2 ))
1039
- ax .add_patch (p )
1040
- ax .add_line (l1 )
1041
- ax .add_line (l2 )
1013
+ self ._squares = ax .scatter (
1014
+ [0.15 ] * len (ys ), ys , marker = 's' , c = "none" , linewidth = 1 ,
1015
+ transform = ax .transAxes , edgecolor = "k"
1016
+ )
1017
+ mask = [not x for x in actives ]
1018
+ self ._crosses = ax .scatter (
1019
+ [0.15 ] * len (ys ), ys , marker = 'x' , linewidth = 1 ,
1020
+ c = ["k" if actives [i ] else "none" for i in range (len (ys ))],
1021
+ transform = ax .transAxes
1022
+ )
1042
1023
1043
1024
self .connect_event ('button_press_event' , self ._clicked )
1044
1025
@@ -1047,11 +1028,29 @@ def __init__(self, ax, labels, actives=None):
1047
1028
def _clicked (self , event ):
1048
1029
if self .ignore (event ) or event .button != 1 or event .inaxes != self .ax :
1049
1030
return
1050
- for i , (p , t ) in enumerate (zip (self .rectangles , self .labels )):
1051
- if (t .get_window_extent ().contains (event .x , event .y ) or
1052
- p .get_window_extent ().contains (event .x , event .y )):
1053
- self .set_active (i )
1054
- break
1031
+ pclicked = self .ax .transAxes .inverted ().transform ((event .x , event .y ))
1032
+ _ , square_inds = self ._squares .contains (event )
1033
+ coords = self ._squares .get_offset_transform ().transform (
1034
+ self ._squares .get_offsets ()
1035
+ )
1036
+ distances = {}
1037
+ if hasattr (self , "_rectangles" ):
1038
+ for i , (p , t ) in enumerate (zip (self ._rectangles , self .labels )):
1039
+ if (t .get_window_extent ().contains (event .x , event .y )
1040
+ or (
1041
+ p .get_x () < event .x < p .get_x () + p .get_width ()
1042
+ and p .get_y () < event .y < p .get_y ()
1043
+ + p .get_height ()
1044
+ )):
1045
+ distances [i ] = np .linalg .norm (pclicked - p .get_center ())
1046
+ else :
1047
+ for i , t in enumerate (self .labels ):
1048
+ if (i in square_inds ["ind" ]
1049
+ or t .get_window_extent ().contains (event .x , event .y )):
1050
+ distances [i ] = np .linalg .norm (pclicked - coords [i ])
1051
+ if len (distances ) > 0 :
1052
+ closest = min (distances , key = distances .get )
1053
+ self .set_active (closest )
1055
1054
1056
1055
def set_active (self , index ):
1057
1056
"""
@@ -1072,9 +1071,18 @@ def set_active(self, index):
1072
1071
if index not in range (len (self .labels )):
1073
1072
raise ValueError (f'Invalid CheckButton index: { index } ' )
1074
1073
1075
- l1 , l2 = self .lines [index ]
1076
- l1 .set_visible (not l1 .get_visible ())
1077
- l2 .set_visible (not l2 .get_visible ())
1074
+ if colors .same_color (
1075
+ self ._crosses .get_facecolor ()[index ], colors .to_rgba ("none" )
1076
+ ):
1077
+ self ._crosses .get_facecolor ()[index ] = colors .to_rgba ("k" )
1078
+ else :
1079
+ self ._crosses .get_facecolor ()[index ] = colors .to_rgba ("none" )
1080
+
1081
+ if hasattr (self , "_rectangles" ):
1082
+ for i , p in enumerate (self ._rectangles ):
1083
+ p .set_facecolor ("k" if colors .same_color (
1084
+ p .get_facecolor (), colors .to_rgba ("none" ))
1085
+ else "none" )
1078
1086
1079
1087
if self .drawon :
1080
1088
self .ax .figure .canvas .draw ()
@@ -1086,7 +1094,9 @@ def get_status(self):
1086
1094
"""
1087
1095
Return a tuple of the status (True/False) of all of the check buttons.
1088
1096
"""
1089
- return [l1 .get_visible () for (l1 , l2 ) in self .lines ]
1097
+ return [False if colors .same_color (
1098
+ self ._crosses .get_facecolors ()[i ], colors .to_rgba ("none" ))
1099
+ else True for i in range (len (self .labels ))]
1090
1100
1091
1101
def on_clicked (self , func ):
1092
1102
"""
@@ -1100,6 +1110,24 @@ def disconnect(self, cid):
1100
1110
"""Remove the observer with connection id *cid*."""
1101
1111
self ._observers .disconnect (cid )
1102
1112
1113
+ @property
1114
+ def rectangles (self ):
1115
+ if not hasattr (self , "rectangles" ):
1116
+ dy = 1. / (len (self .labels ) + 1 )
1117
+ w , h = dy / 2 , dy / 2
1118
+ rectangles = self ._rectangles = [
1119
+ Rectangle (xy = self ._squares .get_offsets ()[i ], width = w , height = h ,
1120
+ edgecolor = "black" ,
1121
+ facecolor = self ._squares .get_facecolor ()[i ],
1122
+ transform = self .ax .transAxes
1123
+ )
1124
+ for i in range (len (self .labels ))
1125
+ ]
1126
+ self ._squares .set_visible (False )
1127
+ for rectangle in rectangles :
1128
+ self .ax .add_patch (rectangle )
1129
+ return self ._rectangles
1130
+
1103
1131
1104
1132
class TextBox (AxesWidget ):
1105
1133
"""
0 commit comments