Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ed05335

Browse files
committed
Refactor hexbin().
- Avoid having to copy() `x` and `y`, by not overwriting the original `x` and `y` variables but instead storing the transformed values in `tx`/`ty`.
- Directly construct lattice1 and lattice2 as flat arrays (they are flattened at the end anyway), which allows using flat indices: the `C is None` case becomes a simple `bincount`, and the `C is not None` case can use a list-of-lists instead of an object array.
- Factor out the x/y marginals handling into a for-loop, which additionally allows inlining coarse_bin.
- Make the factor of 2 between nx and ny clearer (in the `for zname...` loop setup).
- Construct marginals `verts` in a vectorized fashion.
1 parent 97ba9d4 commit ed05335

File tree

1 file changed

+95
-152
lines changed

1 file changed

+95
-152
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 95 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -4606,110 +4606,90 @@ def reduce_C_function(C: array) -> float
46064606
nx = gridsize
46074607
ny = int(nx / math.sqrt(3))
46084608
# Count the number of data in each hexagon
4609-
x = np.array(x, float)
4610-
y = np.array(y, float)
4609+
x = np.asarray(x, float)
4610+
y = np.asarray(y, float)
46114611

4612-
if marginals:
4613-
xorig = x.copy()
4614-
yorig = y.copy()
4612+
# Will be log()'d if necessary, and then rescaled.
4613+
tx = x
4614+
ty = y
46154615

46164616
if xscale == 'log':
46174617
if np.any(x <= 0.0):
4618-
raise ValueError("x contains non-positive values, so can not"
4619-
" be log-scaled")
4620-
x = np.log10(x)
4618+
raise ValueError("x contains non-positive values, so can not "
4619+
"be log-scaled")
4620+
tx = np.log10(tx)
46214621
if yscale == 'log':
46224622
if np.any(y <= 0.0):
4623-
raise ValueError("y contains non-positive values, so can not"
4624-
" be log-scaled")
4625-
y = np.log10(y)
4623+
raise ValueError("y contains non-positive values, so can not "
4624+
"be log-scaled")
4625+
ty = np.log10(ty)
46264626
if extent is not None:
46274627
xmin, xmax, ymin, ymax = extent
46284628
else:
4629-
xmin, xmax = (np.min(x), np.max(x)) if len(x) else (0, 1)
4630-
ymin, ymax = (np.min(y), np.max(y)) if len(y) else (0, 1)
4629+
xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
4630+
ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)
46314631

46324632
# to avoid issues with singular data, expand the min/max pairs
46334633
xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
46344634
ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)
46354635

4636+
nx1 = nx + 1
4637+
ny1 = ny + 1
4638+
nx2 = nx
4639+
ny2 = ny
4640+
n = nx1 * ny1 + nx2 * ny2
4641+
46364642
# In the x-direction, the hexagons exactly cover the region from
46374643
# xmin to xmax. Need some padding to avoid roundoff errors.
46384644
padding = 1.e-9 * (xmax - xmin)
46394645
xmin -= padding
46404646
xmax += padding
46414647
sx = (xmax - xmin) / nx
46424648
sy = (ymax - ymin) / ny
4643-
4644-
x = (x - xmin) / sx
4645-
y = (y - ymin) / sy
4646-
ix1 = np.round(x).astype(int)
4647-
iy1 = np.round(y).astype(int)
4648-
ix2 = np.floor(x).astype(int)
4649-
iy2 = np.floor(y).astype(int)
4650-
4651-
nx1 = nx + 1
4652-
ny1 = ny + 1
4653-
nx2 = nx
4654-
ny2 = ny
4655-
n = nx1 * ny1 + nx2 * ny2
4656-
4657-
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
4658-
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
4649+
# Positions in hexagon index coordinates.
4650+
ix = (tx - xmin) / sx
4651+
iy = (ty - ymin) / sy
4652+
ix1 = np.round(ix).astype(int)
4653+
iy1 = np.round(iy).astype(int)
4654+
ix2 = np.floor(ix).astype(int)
4655+
iy2 = np.floor(iy).astype(int)
4656+
# flat indices, plus one so that out-of-range points go to position 0.
4657+
i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
4658+
ix1 * ny1 + iy1 + 1, 0)
4659+
i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
4660+
ix2 * ny2 + iy2 + 1, 0)
4661+
4662+
d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
4663+
d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
46594664
bdist = (d1 < d2)
4665+
46604666
if C is None:
4661-
lattice1 = np.zeros((nx1, ny1))
4662-
lattice2 = np.zeros((nx2, ny2))
4663-
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
4664-
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
4665-
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
4666-
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
4667+
lattice1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)
4668+
lattice2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)
4669+
accum = np.concatenate( # [1:] drops out-of-range points.
4670+
[lattice1.ravel()[1:], lattice2.ravel()[1:]]).astype(float)
46674671
if mincnt is not None:
4668-
lattice1[lattice1 < mincnt] = np.nan
4669-
lattice2[lattice2 < mincnt] = np.nan
4670-
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
4671-
good_idxs = ~np.isnan(accum)
4672+
accum[accum < mincnt] = np.nan
4673+
C = np.ones(len(x))
46724674

46734675
else:
4674-
if mincnt is None:
4675-
mincnt = 0
4676-
4677-
# create accumulation arrays
4678-
lattice1 = np.empty((nx1, ny1), dtype=object)
4679-
for i in range(nx1):
4680-
for j in range(ny1):
4681-
lattice1[i, j] = []
4682-
lattice2 = np.empty((nx2, ny2), dtype=object)
4683-
for i in range(nx2):
4684-
for j in range(ny2):
4685-
lattice2[i, j] = []
4686-
4676+
# accumulation arrays
4677+
lattice1 = [[] for _ in range(1 + nx1 * ny1)]
4678+
lattice2 = [[] for _ in range(1 + nx2 * ny2)]
46874679
for i in range(len(x)):
46884680
if bdist[i]:
4689-
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
4690-
lattice1[ix1[i], iy1[i]].append(C[i])
4681+
lattice1[i1[i]].append(C[i])
46914682
else:
4692-
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
4693-
lattice2[ix2[i], iy2[i]].append(C[i])
4694-
4695-
for i in range(nx1):
4696-
for j in range(ny1):
4697-
vals = lattice1[i, j]
4698-
if len(vals) > mincnt:
4699-
lattice1[i, j] = reduce_C_function(vals)
4700-
else:
4701-
lattice1[i, j] = np.nan
4702-
for i in range(nx2):
4703-
for j in range(ny2):
4704-
vals = lattice2[i, j]
4705-
if len(vals) > mincnt:
4706-
lattice2[i, j] = reduce_C_function(vals)
4707-
else:
4708-
lattice2[i, j] = np.nan
4683+
lattice2[i2[i]].append(C[i])
4684+
if mincnt is None:
4685+
mincnt = 0
4686+
accum = np.array(
4687+
[reduce_C_function(acc) if len(acc) > mincnt else np.nan
4688+
for lattice in [lattice1, lattice2]
4689+
for acc in lattice[1:]], # [1:] drops out-of-range points.
4690+
float)
47094691

4710-
accum = np.concatenate([lattice1.astype(float).ravel(),
4711-
lattice2.astype(float).ravel()])
4712-
good_idxs = ~np.isnan(accum)
4692+
good_idxs = ~np.isnan(accum)
47134693

47144694
offsets = np.zeros((n, 2), float)
47154695
offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
@@ -4766,8 +4746,7 @@ def reduce_C_function(C: array) -> float
47664746
vmin = vmax = None
47674747
bins = None
47684748

4769-
# autoscale the norm with current accum values if it hasn't
4770-
# been set
4749+
# autoscale the norm with current accum values if it hasn't been set
47714750
if norm is not None:
47724751
if norm.vmin is None and norm.vmax is None:
47734752
norm.autoscale(accum)
@@ -4797,84 +4776,48 @@ def reduce_C_function(C: array) -> float
47974776
return collection
47984777

47994778
# Process marginals
4800-
if C is None:
4801-
C = np.ones(len(x))
4779+
for zname, z, zmin, zmax, zscale, nbins in [
4780+
("x", x, xmin, xmax, xscale, nx),
4781+
("y", y, ymin, ymax, yscale, 2 * ny),
4782+
]:
48024783

4803-
def coarse_bin(x, y, bin_edges):
4804-
"""
4805-
Sort x-values into bins defined by *bin_edges*, then for all the
4806-
corresponding y-values in each bin use *reduce_c_function* to
4807-
compute the bin value.
4808-
"""
4809-
nbins = len(bin_edges) - 1
4810-
# Sort x-values into bins
4811-
bin_idxs = np.searchsorted(bin_edges, x) - 1
4812-
mus = np.zeros(nbins) * np.nan
4784+
if zscale == "log":
4785+
bin_edges = np.geomspace(zmin, zmax, nbins + 1)
4786+
else:
4787+
bin_edges = np.linspace(zmin, zmax, nbins + 1)
4788+
4789+
verts = np.empty((nbins, 4, 2))
4790+
verts[:, 0, 0] = verts[:, 1, 0] = bin_edges[:-1]
4791+
verts[:, 2, 0] = verts[:, 3, 0] = bin_edges[1:]
4792+
verts[:, 0, 1] = verts[:, 3, 1] = .00
4793+
verts[:, 1, 1] = verts[:, 2, 1] = .05
4794+
if zname == "y":
4795+
verts = verts[:, :, ::-1] # Swap x and y.
4796+
4797+
# Sort z-values into bins defined by bin_edges.
4798+
bin_idxs = np.searchsorted(bin_edges, z) - 1
4799+
values = np.empty(nbins)
48134800
for i in range(nbins):
4814-
# Get y-values for each bin
4815-
yi = y[bin_idxs == i]
4816-
if len(yi) > 0:
4817-
mus[i] = reduce_C_function(yi)
4818-
return mus
4819-
4820-
if xscale == 'log':
4821-
bin_edges = np.geomspace(xmin, xmax, nx + 1)
4822-
else:
4823-
bin_edges = np.linspace(xmin, xmax, nx + 1)
4824-
xcoarse = coarse_bin(xorig, C, bin_edges)
4825-
4826-
verts, values = [], []
4827-
for bin_left, bin_right, val in zip(
4828-
bin_edges[:-1], bin_edges[1:], xcoarse):
4829-
if np.isnan(val):
4830-
continue
4831-
verts.append([(bin_left, 0),
4832-
(bin_left, 0.05),
4833-
(bin_right, 0.05),
4834-
(bin_right, 0)])
4835-
values.append(val)
4836-
4837-
values = np.array(values)
4838-
trans = self.get_xaxis_transform(which='grid')
4839-
4840-
hbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4841-
4842-
hbar.set_array(values)
4843-
hbar.set_cmap(cmap)
4844-
hbar.set_norm(norm)
4845-
hbar.set_alpha(alpha)
4846-
hbar.update(kwargs)
4847-
self.add_collection(hbar, autolim=False)
4848-
4849-
if yscale == 'log':
4850-
bin_edges = np.geomspace(ymin, ymax, 2 * ny + 1)
4851-
else:
4852-
bin_edges = np.linspace(ymin, ymax, 2 * ny + 1)
4853-
ycoarse = coarse_bin(yorig, C, bin_edges)
4854-
4855-
verts, values = [], []
4856-
for bin_bottom, bin_top, val in zip(
4857-
bin_edges[:-1], bin_edges[1:], ycoarse):
4858-
if np.isnan(val):
4859-
continue
4860-
verts.append([(0, bin_bottom),
4861-
(0, bin_top),
4862-
(0.05, bin_top),
4863-
(0.05, bin_bottom)])
4864-
values.append(val)
4865-
4866-
values = np.array(values)
4867-
4868-
trans = self.get_yaxis_transform(which='grid')
4869-
4870-
vbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4871-
vbar.set_array(values)
4872-
vbar.set_cmap(cmap)
4873-
vbar.set_norm(norm)
4874-
vbar.set_alpha(alpha)
4875-
vbar.update(kwargs)
4876-
self.add_collection(vbar, autolim=False)
4877-
4801+
# Get C-values for each bin, and compute bin value with
4802+
# reduce_C_function.
4803+
ci = C[bin_idxs == i]
4804+
values[i] = reduce_C_function(ci) if len(ci) > 0 else np.nan
4805+
4806+
mask = ~np.isnan(values)
4807+
verts = verts[mask]
4808+
values = values[mask]
4809+
4810+
trans = getattr(self, f"get_{zname}axis_transform")(which="grid")
4811+
bar = mcoll.PolyCollection(
4812+
verts, transform=trans, edgecolors="face")
4813+
bar.set_array(values)
4814+
bar.set_cmap(cmap)
4815+
bar.set_norm(norm)
4816+
bar.set_alpha(alpha)
4817+
bar.update(kwargs)
4818+
self.add_collection(bar, autolim=False)
4819+
4820+
hbar, vbar = self.collections[-2:]
48784821
collection.hbar = hbar
48794822
collection.vbar = vbar
48804823

0 commit comments

Comments
 (0)