@@ -4243,108 +4243,54 @@ def __init__(self, xyA, xyB, coordsA, coordsB=None,
4243
4243
# if True, draw annotation only if self.xy is inside the axes
4244
4244
self ._annotation_clip = None
4245
4245
4246
- def _get_xy (self , x , y , s , axes = None ):
4246
+ def _get_xy (self , xy , s , axes = None ):
4247
4247
"""Calculate the pixel position of given point."""
4248
+ s0 = s # For the error message, if needed.
4248
4249
if axes is None :
4249
4250
axes = self .axes
4251
+ xy = np .array (xy )
4252
+ if s in ["figure points" , "axes points" ]:
4253
+ xy *= self .figure .dpi / 72
4254
+ s = s .replace ("points" , "pixels" )
4255
+ elif s == "figure fraction" :
4256
+ s = self .figure .transFigure
4257
+ elif s == "axes fraction" :
4258
+ s = axes .transAxes
4259
+ x , y = xy
4250
4260
4251
4261
if s == 'data' :
4252
4262
trans = axes .transData
4253
4263
x = float (self .convert_xunits (x ))
4254
4264
y = float (self .convert_yunits (y ))
4255
4265
return trans .transform ((x , y ))
4256
4266
elif s == 'offset points' :
4257
- # convert the data point
4258
- dx , dy = self .xy
4259
-
4260
- # prevent recursion
4261
- if self .xycoords == 'offset points' :
4262
- return self ._get_xy (dx , dy , 'data' )
4263
-
4264
- dx , dy = self ._get_xy (dx , dy , self .xycoords )
4265
-
4266
- # convert the offset
4267
- dpi = self .figure .get_dpi ()
4268
- x *= dpi / 72.
4269
- y *= dpi / 72.
4270
-
4271
- # add the offset to the data point
4272
- x += dx
4273
- y += dy
4274
-
4275
- return x , y
4267
+ if self .xycoords == 'offset points' : # prevent recursion
4268
+ return self ._get_xy (self .xy , 'data' )
4269
+ return (
4270
+ self ._get_xy (self .xy , self .xycoords ) # converted data point
4271
+ + xy * self .figure .dpi / 72 ) # converted offset
4276
4272
elif s == 'polar' :
4277
4273
theta , r = x , y
4278
4274
x = r * np .cos (theta )
4279
4275
y = r * np .sin (theta )
4280
4276
trans = axes .transData
4281
4277
return trans .transform ((x , y ))
4282
- elif s == 'figure points' :
4283
- # points from the lower left corner of the figure
4284
- dpi = self .figure .dpi
4285
- l , b , w , h = self .figure .bbox .bounds
4286
- r = l + w
4287
- t = b + h
4288
-
4289
- x *= dpi / 72.
4290
- y *= dpi / 72.
4291
- if x < 0 :
4292
- x = r + x
4293
- if y < 0 :
4294
- y = t + y
4295
- return x , y
4296
4278
elif s == 'figure pixels' :
4297
4279
# pixels from the lower left corner of the figure
4298
- l , b , w , h = self .figure .bbox .bounds
4299
- r = l + w
4300
- t = b + h
4301
- if x < 0 :
4302
- x = r + x
4303
- if y < 0 :
4304
- y = t + y
4305
- return x , y
4306
- elif s == 'figure fraction' :
4307
- # (0, 0) is lower left, (1, 1) is upper right of figure
4308
- trans = self .figure .transFigure
4309
- return trans .transform ((x , y ))
4310
- elif s == 'axes points' :
4311
- # points from the lower left corner of the axes
4312
- dpi = self .figure .dpi
4313
- l , b , w , h = axes .bbox .bounds
4314
- r = l + w
4315
- t = b + h
4316
- if x < 0 :
4317
- x = r + x * dpi / 72.
4318
- else :
4319
- x = l + x * dpi / 72.
4320
- if y < 0 :
4321
- y = t + y * dpi / 72.
4322
- else :
4323
- y = b + y * dpi / 72.
4280
+ bb = self .figure .bbox
4281
+ x = bb .x0 + x if x >= 0 else bb .x1 + x
4282
+ y = bb .y0 + y if y >= 0 else bb .y1 + y
4324
4283
return x , y
4325
4284
elif s == 'axes pixels' :
4326
4285
# pixels from the lower left corner of the axes
4327
- l , b , w , h = axes .bbox .bounds
4328
- r = l + w
4329
- t = b + h
4330
- if x < 0 :
4331
- x = r + x
4332
- else :
4333
- x = l + x
4334
- if y < 0 :
4335
- y = t + y
4336
- else :
4337
- y = b + y
4286
+ bb = axes .bbox
4287
+ x = bb .x0 + x if x >= 0 else bb .x1 + x
4288
+ y = bb .y0 + y if y >= 0 else bb .y1 + y
4338
4289
return x , y
4339
- elif s == 'axes fraction' :
4340
- # (0, 0) is lower left, (1, 1) is upper right of axes
4341
- trans = axes .transAxes
4342
- return trans .transform ((x , y ))
4343
4290
elif isinstance (s , transforms .Transform ):
4344
- return s .transform (( x , y ) )
4291
+ return s .transform (xy )
4345
4292
else :
4346
- raise ValueError ("{} is not a valid coordinate "
4347
- "transformation." .format (s ))
4293
+ raise ValueError (f"{ s0 } is not a valid coordinate transformation" )
4348
4294
4349
4295
def set_annotation_clip (self , b ):
4350
4296
"""
@@ -4374,39 +4320,29 @@ def get_annotation_clip(self):
4374
4320
4375
4321
def get_path_in_displaycoord (self ):
4376
4322
"""Return the mutated path of the arrow in display coordinates."""
4377
-
4378
4323
dpi_cor = self .get_dpi_cor ()
4379
-
4380
- x , y = self .xy1
4381
- posA = self ._get_xy (x , y , self .coords1 , self .axesA )
4382
-
4383
- x , y = self .xy2
4384
- posB = self ._get_xy (x , y , self .coords2 , self .axesB )
4385
-
4386
- _path = self .get_connectionstyle ()(posA , posB ,
4387
- patchA = self .patchA ,
4388
- patchB = self .patchB ,
4389
- shrinkA = self .shrinkA * dpi_cor ,
4390
- shrinkB = self .shrinkB * dpi_cor
4391
- )
4392
-
4393
- _path , fillable = self .get_arrowstyle ()(
4394
- _path ,
4395
- self .get_mutation_scale () * dpi_cor ,
4396
- self .get_linewidth () * dpi_cor ,
4397
- self .get_mutation_aspect ()
4398
- )
4399
-
4400
- return _path , fillable
4324
+ posA = self ._get_xy (self .xy1 , self .coords1 , self .axesA )
4325
+ posB = self ._get_xy (self .xy2 , self .coords2 , self .axesB )
4326
+ path = self .get_connectionstyle ()(
4327
+ posA , posB ,
4328
+ patchA = self .patchA , patchB = self .patchB ,
4329
+ shrinkA = self .shrinkA * dpi_cor , shrinkB = self .shrinkB * dpi_cor ,
4330
+ )
4331
+ path , fillable = self .get_arrowstyle ()(
4332
+ path ,
4333
+ self .get_mutation_scale () * dpi_cor ,
4334
+ self .get_linewidth () * dpi_cor ,
4335
+ self .get_mutation_aspect ()
4336
+ )
4337
+ return path , fillable
4401
4338
4402
4339
def _check_xy (self , renderer ):
4403
4340
"""Check whether the annotation needs to be drawn."""
4404
4341
4405
4342
b = self .get_annotation_clip ()
4406
4343
4407
4344
if b or (b is None and self .coords1 == "data" ):
4408
- x , y = self .xy1
4409
- xy_pixel = self ._get_xy (x , y , self .coords1 , self .axesA )
4345
+ xy_pixel = self ._get_xy (self .xy1 , self .coords1 , self .axesA )
4410
4346
if self .axesA is None :
4411
4347
axes = self .axes
4412
4348
else :
@@ -4415,8 +4351,7 @@ def _check_xy(self, renderer):
4415
4351
return False
4416
4352
4417
4353
if b or (b is None and self .coords2 == "data" ):
4418
- x , y = self .xy2
4419
- xy_pixel = self ._get_xy (x , y , self .coords2 , self .axesB )
4354
+ xy_pixel = self ._get_xy (self .xy2 , self .coords2 , self .axesB )
4420
4355
if self .axesB is None :
4421
4356
axes = self .axes
4422
4357
else :
0 commit comments