@@ -503,6 +503,31 @@ def is_scalar_or_string(val):
503
503
return isinstance (val , str ) or not np .iterable (val )
504
504
505
505
506
def duplicate_if_scalar(obj, n=2, raises=True):
    """
    Return *obj* expanded to a sequence of length *n* if it is scalar-like.

    Scalars and strings are repeated *n* times into a list.  Length-1
    containers are likewise repeated into a list of length *n*.  Containers
    that already have length *n* are returned unchanged.  Any other length
    raises a ``ValueError`` when *raises* is true; when *raises* is false,
    an empty container is duplicated as ``[obj] * n`` and other mismatched
    containers are returned as-is.
    """
    # Scalars (and strings, which are iterable but treated as atoms) are
    # simply repeated.
    if is_scalar_or_string(obj):
        return [obj] * n

    length = len(obj)

    if length == 0:
        if not raises:
            return [obj] * n
        raise ValueError(f'Cannot duplicate empty {type(obj)}.')

    # A single-element container is broadcast to length n.
    if length == 1:
        return list(obj) * n

    if raises and length != n:
        raise ValueError(
            f'Input object of type {type(obj)} has incorrect size. Expected '
            f'either a scalar type object, or a Container with length in {{1, '
            f'{n}}}.'
        )

    return obj
529
+
530
+
506
531
@_api .delete_parameter (
507
532
"3.8" , "np_load" , alternative = "open(get_sample_data(..., asfileobj=False))" )
508
533
def get_sample_data (fname , asfileobj = True , * , np_load = True ):
@@ -567,6 +592,23 @@ def flatten(seq, scalarp=is_scalar_or_string):
567
592
yield from flatten (item , scalarp )
568
593
569
594
595
def pairwise(iterable):
    """
    Yield successive overlapping pairs taken from *iterable*.

    take(4, pairwise(count()))
    [(0, 1), (1, 2), (2, 3), (3, 4)]

    From more_itertools:
    https://more-itertools.readthedocs.io/en/stable/_modules/more_itertools/recipes.html#pairwise

    Can be removed on python >3.10 in favour of itertools.pairwise
    """
    it = iter(iterable)
    try:
        previous = next(it)
    except StopIteration:
        # Fewer than two items: nothing to pair.
        return
    for current in it:
        yield previous, current
        previous = current
610
+
611
+
570
612
@_api .deprecated ("3.8" )
571
613
class Stack :
572
614
"""
@@ -1473,6 +1515,120 @@ def _reshape_2D(X, name):
1473
1515
return result
1474
1516
1475
1517
1518
def hexbin(x, y, C=None, gridsize=100,
           xscale='linear', yscale='linear', extent=None,
           reduce_C_function=np.mean, mincnt=None):
    """
    Bin the points *x*, *y* into a hexagonal grid and accumulate values.

    Parameters
    ----------
    x, y : array-like
        The data positions.  Assumed to be of equal length — TODO confirm
        against callers; the C-accumulation loop indexes both by the same i.
    C : array-like, optional
        If None, each hexagon accumulates the count of points falling in
        it.  Otherwise, the C values of those points are reduced to one
        float per hexagon with *reduce_C_function*.
    gridsize : int or (int, int), default: 100
        Number of hexagons in the x-direction, or a pair (nx, ny).  When a
        single int is given, ny is derived as ``int(nx / sqrt(3))``.
    xscale, yscale : {'linear', 'log'}, default: 'linear'
        If 'log', the binning is performed on log10 of the data, which must
        then be strictly positive.
    extent : (xmin, xmax, ymin, ymax), optional
        Limits of the binned region.  Defaults to the (scaled) data range,
        expanded to avoid singular limits.
    reduce_C_function : callable, default: `numpy.mean`
        Reduces the list of C values in one hexagon to a single value.
    mincnt : int, optional
        Hexagons containing fewer points than this get NaN as their
        accumulated value and are dropped from the output.  When C is
        given and *mincnt* is None, it defaults to 1.

    Returns
    -------
    (xout, yout, accum), (xmin, xmax), (ymin, ymax), (nx, ny)
        Hexagon center coordinates and their accumulated values (bins
        whose accumulation is NaN are removed), the final binning limits
        (after padding/expansion), and the grid shape used.

    Raises
    ------
    ValueError
        If log scaling is requested for non-positive data, or *extent*
        has a min greater than its max.
    """
    # local import to avoid circular import
    import matplotlib.transforms as mtransforms

    # Set the size of the hexagon grid
    if np.iterable(gridsize):
        nx, ny = gridsize
    else:
        nx = gridsize
        # sqrt(3) ratio makes the hexagons approximately regular.
        ny = int(nx / math.sqrt(3))

    # Will be log()'d if necessary, and then rescaled.
    tx = x
    ty = y

    if xscale == 'log':
        if np.any(x <= 0.0):
            raise ValueError(
                "x contains non-positive values, so cannot be log-scaled")
        tx = np.log10(tx)
    if yscale == 'log':
        if np.any(y <= 0.0):
            raise ValueError(
                "y contains non-positive values, so cannot be log-scaled")
        ty = np.log10(ty)
    if extent is not None:
        xmin, xmax, ymin, ymax = extent
        if xmin > xmax:
            raise ValueError("In extent, xmax must be greater than xmin")
        if ymin > ymax:
            raise ValueError("In extent, ymax must be greater than ymin")
    else:
        xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
        ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)

        # to avoid issues with singular data, expand the min/max pairs
        xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
        ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)

    # The hexagonal lattice is represented as two interleaved rectangular
    # grids of hexagon centers: grid 1 is (nx+1) x (ny+1), grid 2 is
    # nx x ny and offset by half a cell in both directions (see the
    # offsets construction below).
    nx1 = nx + 1
    ny1 = ny + 1
    nx2 = nx
    ny2 = ny
    n = nx1 * ny1 + nx2 * ny2

    # In the x-direction, the hexagons exactly cover the region from
    # xmin to xmax. Need some padding to avoid roundoff errors.
    padding = 1.e-9 * (xmax - xmin)
    xmin -= padding
    xmax += padding
    sx = (xmax - xmin) / nx
    sy = (ymax - ymin) / ny
    # Positions in hexagon index coordinates.
    ix = (tx - xmin) / sx
    iy = (ty - ymin) / sy
    # Nearest centers on each of the two grids: grid 1 by rounding,
    # grid 2 (half-cell-shifted) by flooring.
    ix1 = np.round(ix).astype(int)
    iy1 = np.round(iy).astype(int)
    ix2 = np.floor(ix).astype(int)
    iy2 = np.floor(iy).astype(int)
    # flat indices, plus one so that out-of-range points go to position 0.
    i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
                  ix1 * ny1 + iy1 + 1, 0)
    i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
                  ix2 * ny2 + iy2 + 1, 0)

    # Squared distances to each candidate center; the factor 3 on the
    # y-term compensates for the ny ~ nx/sqrt(3) grid aspect so the
    # comparison picks the truly nearest hexagon center.
    d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
    d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
    # bdist[i] is True when point i belongs to grid 1, False for grid 2.
    bdist = (d1 < d2)

    if C is None:  # [1:] drops out-of-range points.
        counts1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)[1:]
        counts2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)[1:]
        accum = np.concatenate([counts1, counts2]).astype(float)
        if mincnt is not None:
            accum[accum < mincnt] = np.nan

    else:
        # store the C values in a list per hexagon index
        Cs_at_i1 = [[] for _ in range(1 + nx1 * ny1)]
        Cs_at_i2 = [[] for _ in range(1 + nx2 * ny2)]
        for i in range(len(x)):
            if bdist[i]:
                Cs_at_i1[i1[i]].append(C[i])
            else:
                Cs_at_i2[i2[i]].append(C[i])
        if mincnt is None:
            mincnt = 1
        accum = np.array(
            [reduce_C_function(acc) if len(acc) >= mincnt else np.nan
             for Cs_at_i in [Cs_at_i1, Cs_at_i2]
             for acc in Cs_at_i[1:]],  # [1:] drops out-of-range points.
            float)

    # Bins to keep: those whose accumulation is not NaN (NaN marks bins
    # below mincnt).
    good_idxs = ~np.isnan(accum)

    # Hexagon center coordinates, first grid 1 then grid 2 (shifted by
    # half a cell), scaled back to data coordinates.
    offsets = np.zeros((n, 2), float)
    offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
    offsets[:nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
    offsets[nx1 * ny1:, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
    offsets[nx1 * ny1:, 1] = np.tile(np.arange(ny2), nx2) + 0.5
    offsets[:, 0] *= sx
    offsets[:, 1] *= sy
    offsets[:, 0] += xmin
    offsets[:, 1] += ymin
    # remove accumulation bins with no data
    offsets = offsets[good_idxs, :]
    accum = accum[good_idxs]

    return (*offsets.T, accum), (xmin, xmax), (ymin, ymax), (nx, ny)
1630
+
1631
+
1476
1632
def violin_stats (X , method , points = 100 , quantiles = None ):
1477
1633
"""
1478
1634
Return a list of dictionaries of data which can be used to draw a series
0 commit comments