@@ -4606,33 +4606,39 @@ def reduce_C_function(C: array) -> float
46064606 nx = gridsize
46074607 ny = int (nx / math .sqrt (3 ))
46084608 # Count the number of data in each hexagon
4609- x = np .array (x , float )
4610- y = np .array (y , float )
4609+ x = np .asarray (x , float )
4610+ y = np .asarray (y , float )
46114611
4612- if marginals :
4613- xorig = x . copy ()
4614- yorig = y . copy ()
4612+ # Will be log()'d if necessary, and then rescaled.
4613+ trfx = x
4614+ trfy = y
46154615
46164616 if xscale == 'log' :
46174617 if np .any (x <= 0.0 ):
4618- raise ValueError ("x contains non-positive values, so can not"
4619- " be log-scaled" )
4620- x = np .log10 (x )
4618+ raise ValueError ("x contains non-positive values, so can not "
4619+ "be log-scaled" )
4620+ trfx = np .log10 (trfx )
46214621 if yscale == 'log' :
46224622 if np .any (y <= 0.0 ):
4623- raise ValueError ("y contains non-positive values, so can not"
4624- " be log-scaled" )
4625- y = np .log10 (y )
4623+ raise ValueError ("y contains non-positive values, so can not "
4624+ "be log-scaled" )
4625+ trfy = np .log10 (trfy )
46264626 if extent is not None :
46274627 xmin , xmax , ymin , ymax = extent
46284628 else :
4629- xmin , xmax = (np .min (x ), np .max (x )) if len (x ) else (0 , 1 )
4630- ymin , ymax = (np .min (y ), np .max (y )) if len (y ) else (0 , 1 )
4629+ xmin , xmax = (np .min (trfx ), np .max (trfx )) if len (x ) else (0 , 1 )
4630+ ymin , ymax = (np .min (trfy ), np .max (trfy )) if len (y ) else (0 , 1 )
46314631
46324632 # to avoid issues with singular data, expand the min/max pairs
46334633 xmin , xmax = mtransforms .nonsingular (xmin , xmax , expander = 0.1 )
46344634 ymin , ymax = mtransforms .nonsingular (ymin , ymax , expander = 0.1 )
46354635
4636+ nx1 = nx + 1
4637+ ny1 = ny + 1
4638+ nx2 = nx
4639+ ny2 = ny
4640+ n = nx1 * ny1 + nx2 * ny2
4641+
46364642 # In the x-direction, the hexagons exactly cover the region from
46374643 # xmin to xmax. Need some padding to avoid roundoff errors.
46384644 padding = 1.e-9 * (xmax - xmin )
@@ -4641,75 +4647,49 @@ def reduce_C_function(C: array) -> float
46414647 sx = (xmax - xmin ) / nx
46424648 sy = (ymax - ymin ) / ny
46434649
4644- x = (x - xmin ) / sx
4645- y = (y - ymin ) / sy
4646- ix1 = np .round (x ).astype (int )
4647- iy1 = np .round (y ).astype (int )
4648- ix2 = np .floor (x ).astype (int )
4649- iy2 = np .floor (y ).astype (int )
4650-
4651- nx1 = nx + 1
4652- ny1 = ny + 1
4653- nx2 = nx
4654- ny2 = ny
4655- n = nx1 * ny1 + nx2 * ny2
4656-
4657- d1 = (x - ix1 ) ** 2 + 3.0 * (y - iy1 ) ** 2
4658- d2 = (x - ix2 - 0.5 ) ** 2 + 3.0 * (y - iy2 - 0.5 ) ** 2
4650+ trfx = (trfx - xmin ) / sx
4651+ trfy = (trfy - ymin ) / sy
4652+ ix1 = np .round (trfx ).astype (int )
4653+ iy1 = np .round (trfy ).astype (int )
4654+ ix2 = np .floor (trfx ).astype (int )
4655+ iy2 = np .floor (trfy ).astype (int )
4656+ # flat indices, plus one so that out-of-range points go to position 0.
4657+ i1 = np .where ((0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ),
4658+ ix1 * ny1 + iy1 + 1 , 0 )
4659+ i2 = np .where ((0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ),
4660+ ix2 * ny2 + iy2 + 1 , 0 )
4661+
4662+ d1 = (trfx - ix1 ) ** 2 + 3.0 * (trfy - iy1 ) ** 2
4663+ d2 = (trfx - ix2 - 0.5 ) ** 2 + 3.0 * (trfy - iy2 - 0.5 ) ** 2
46594664 bdist = (d1 < d2 )
4665+
46604666 if C is None :
4661- lattice1 = np .zeros ((nx1 , ny1 ))
4662- lattice2 = np .zeros ((nx2 , ny2 ))
4663- c1 = (0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ) & bdist
4664- c2 = (0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ) & ~ bdist
4665- np .add .at (lattice1 , (ix1 [c1 ], iy1 [c1 ]), 1 )
4666- np .add .at (lattice2 , (ix2 [c2 ], iy2 [c2 ]), 1 )
4667+ lattice1 = np .bincount (i1 [bdist ], minlength = 1 + nx1 * ny1 )
4668+ lattice2 = np .bincount (i2 [~ bdist ], minlength = 1 + nx2 * ny2 )
4669+ accum = np .concatenate ( # [1:] drops out-of-range points.
4670+ [lattice1 .ravel ()[1 :], lattice2 .ravel ()[1 :]]).astype (float )
46674671 if mincnt is not None :
4668- lattice1 [lattice1 < mincnt ] = np .nan
4669- lattice2 [lattice2 < mincnt ] = np .nan
4670- accum = np .concatenate ([lattice1 .ravel (), lattice2 .ravel ()])
4671- good_idxs = ~ np .isnan (accum )
4672+ accum [accum < mincnt ] = np .nan
4673+ C = np .ones (len (x ))
46724674
46734675 else :
4674- if mincnt is None :
4675- mincnt = 0
4676-
4677- # create accumulation arrays
4678- lattice1 = np .empty ((nx1 , ny1 ), dtype = object )
4679- for i in range (nx1 ):
4680- for j in range (ny1 ):
4681- lattice1 [i , j ] = []
4682- lattice2 = np .empty ((nx2 , ny2 ), dtype = object )
4683- for i in range (nx2 ):
4684- for j in range (ny2 ):
4685- lattice2 [i , j ] = []
4686-
4676+ # accumulation arrays
4677+ lattice1 = [[] for _ in range (1 + nx1 * ny1 )]
4678+ lattice2 = [[] for _ in range (1 + nx2 * ny2 )]
46874679 for i in range (len (x )):
46884680 if bdist [i ]:
4689- if 0 <= ix1 [i ] < nx1 and 0 <= iy1 [i ] < ny1 :
4690- lattice1 [ix1 [i ], iy1 [i ]].append (C [i ])
4681+ lattice1 [i1 [i ]].append (C [i ])
46914682 else :
4692- if 0 <= ix2 [i ] < nx2 and 0 <= iy2 [i ] < ny2 :
4693- lattice2 [ix2 [i ], iy2 [i ]].append (C [i ])
4694-
4695- for i in range (nx1 ):
4696- for j in range (ny1 ):
4697- vals = lattice1 [i , j ]
4698- if len (vals ) > mincnt :
4699- lattice1 [i , j ] = reduce_C_function (vals )
4700- else :
4701- lattice1 [i , j ] = np .nan
4702- for i in range (nx2 ):
4703- for j in range (ny2 ):
4704- vals = lattice2 [i , j ]
4705- if len (vals ) > mincnt :
4706- lattice2 [i , j ] = reduce_C_function (vals )
4707- else :
4708- lattice2 [i , j ] = np .nan
4683+ lattice2 [i2 [i ]].append (C [i ])
4684+ if mincnt is None :
4685+ mincnt = 0
4686+ accum = np .array (
4687+ [reduce_C_function (acc ) if len (acc ) > mincnt else np .nan
4688+ for lattice in [lattice1 , lattice2 ]
4689+ for acc in lattice [1 :]], # [1:] drops out-of-range points.
4690+ float )
47094691
4710- accum = np .concatenate ([lattice1 .astype (float ).ravel (),
4711- lattice2 .astype (float ).ravel ()])
4712- good_idxs = ~ np .isnan (accum )
4692+ good_idxs = ~ np .isnan (accum )
47134693
47144694 offsets = np .zeros ((n , 2 ), float )
47154695 offsets [:nx1 * ny1 , 0 ] = np .repeat (np .arange (nx1 ), ny1 )
@@ -4797,84 +4777,48 @@ def reduce_C_function(C: array) -> float
47974777 return collection
47984778
47994779 # Process marginals
4800- if C is None :
4801- C = np .ones (len (x ))
4780+ for zname , z , zmin , zmax , zscale , nbins in [
4781+ ("x" , x , xmin , xmax , xscale , nx ),
4782+ ("y" , y , ymin , ymax , yscale , 2 * ny ),
4783+ ]:
48024784
4803- def coarse_bin (x , y , bin_edges ):
4804- """
4805- Sort x-values into bins defined by *bin_edges*, then for all the
4806- corresponding y-values in each bin use *reduce_c_function* to
4807- compute the bin value.
4808- """
4809- nbins = len (bin_edges ) - 1
4810- # Sort x-values into bins
4811- bin_idxs = np .searchsorted (bin_edges , x ) - 1
4812- mus = np .zeros (nbins ) * np .nan
4785+ if zscale == "log" :
4786+ bin_edges = np .geomspace (zmin , zmax , nbins + 1 )
4787+ else :
4788+ bin_edges = np .linspace (zmin , zmax , nbins + 1 )
4789+
4790+ verts = np .empty ((nbins , 4 , 2 ))
4791+ verts [:, 0 , 0 ] = verts [:, 1 , 0 ] = bin_edges [:- 1 ]
4792+ verts [:, 2 , 0 ] = verts [:, 3 , 0 ] = bin_edges [1 :]
4793+ verts [:, 0 , 1 ] = verts [:, 3 , 1 ] = .00
4794+ verts [:, 1 , 1 ] = verts [:, 2 , 1 ] = .05
4795+ if zname == "y" :
4796+ verts = verts [:, :, ::- 1 ] # Swap x and y.
4797+
4798+ # Sort z-values into bins defined by bin_edges.
4799+ bin_idxs = np .searchsorted (bin_edges , z ) - 1
4800+ values = np .empty (nbins )
48134801 for i in range (nbins ):
4814- # Get y-values for each bin
4815- yi = y [bin_idxs == i ]
4816- if len (yi ) > 0 :
4817- mus [i ] = reduce_C_function (yi )
4818- return mus
4819-
4820- if xscale == 'log' :
4821- bin_edges = np .geomspace (xmin , xmax , nx + 1 )
4822- else :
4823- bin_edges = np .linspace (xmin , xmax , nx + 1 )
4824- xcoarse = coarse_bin (xorig , C , bin_edges )
4825-
4826- verts , values = [], []
4827- for bin_left , bin_right , val in zip (
4828- bin_edges [:- 1 ], bin_edges [1 :], xcoarse ):
4829- if np .isnan (val ):
4830- continue
4831- verts .append ([(bin_left , 0 ),
4832- (bin_left , 0.05 ),
4833- (bin_right , 0.05 ),
4834- (bin_right , 0 )])
4835- values .append (val )
4836-
4837- values = np .array (values )
4838- trans = self .get_xaxis_transform (which = 'grid' )
4839-
4840- hbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4841-
4842- hbar .set_array (values )
4843- hbar .set_cmap (cmap )
4844- hbar .set_norm (norm )
4845- hbar .set_alpha (alpha )
4846- hbar .update (kwargs )
4847- self .add_collection (hbar , autolim = False )
4848-
4849- if yscale == 'log' :
4850- bin_edges = np .geomspace (ymin , ymax , 2 * ny + 1 )
4851- else :
4852- bin_edges = np .linspace (ymin , ymax , 2 * ny + 1 )
4853- ycoarse = coarse_bin (yorig , C , bin_edges )
4854-
4855- verts , values = [], []
4856- for bin_bottom , bin_top , val in zip (
4857- bin_edges [:- 1 ], bin_edges [1 :], ycoarse ):
4858- if np .isnan (val ):
4859- continue
4860- verts .append ([(0 , bin_bottom ),
4861- (0 , bin_top ),
4862- (0.05 , bin_top ),
4863- (0.05 , bin_bottom )])
4864- values .append (val )
4865-
4866- values = np .array (values )
4867-
4868- trans = self .get_yaxis_transform (which = 'grid' )
4869-
4870- vbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4871- vbar .set_array (values )
4872- vbar .set_cmap (cmap )
4873- vbar .set_norm (norm )
4874- vbar .set_alpha (alpha )
4875- vbar .update (kwargs )
4876- self .add_collection (vbar , autolim = False )
4877-
4802+ # Get C-values for each bin, and compute bin value with
4803+ # reduce_C_function.
4804+ ci = C [bin_idxs == i ]
4805+ values [i ] = reduce_C_function (ci ) if len (ci ) > 0 else np .nan
4806+
4807+ mask = ~ np .isnan (values )
4808+ verts = verts [mask ]
4809+ values = values [mask ]
4810+
4811+ trans = getattr (self , f"get_{ zname } axis_transform" )(which = "grid" )
4812+ bar = mcoll .PolyCollection (
4813+ verts , transform = trans , edgecolors = "face" )
4814+ bar .set_array (values )
4815+ bar .set_cmap (cmap )
4816+ bar .set_norm (norm )
4817+ bar .set_alpha (alpha )
4818+ bar .update (kwargs )
4819+ self .add_collection (bar , autolim = False )
4820+
4821+ hbar , vbar = self .collections [- 2 :]
48784822 collection .hbar = hbar
48794823 collection .vbar = vbar
48804824
0 commit comments