Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ee206a1

Browse files
committed
Refactor hexbin().
- Avoid having to copy() `x` and `y`, by not overwriting the original `x` and `y` variables but instead storing the transformed values in `trfx`/`trfy`. - Directly construct lattice1 and lattice2 as flat arrays (they are flattened at the end anyways), which allows using flat indices: the `C is None` case, becomes a simple `bincount`, the `C is not None` case can use a list-of-lists instead of an object array. - Factor out the x/y marginals handling into a for-loop, which additionally allows inlining coarse_bin. - Make the factor of 2 between nx and ny clearer (in the `for zname...` loop setup). - Construct marginals `verts` in a vectorized fashion.
1 parent 8a8dd90 commit ee206a1

File tree

1 file changed

+100
-160
lines changed

1 file changed

+100
-160
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 100 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -4607,110 +4607,88 @@ def reduce_C_function(C: array) -> float
46074607
nx = gridsize
46084608
ny = int(nx / math.sqrt(3))
46094609
# Count the number of data in each hexagon
4610-
x = np.array(x, float)
4611-
y = np.array(y, float)
4610+
x = np.asarray(x, float)
4611+
y = np.asarray(y, float)
46124612

4613-
if marginals:
4614-
xorig = x.copy()
4615-
yorig = y.copy()
4613+
# Will be log()'d if necessary, and then rescaled.
4614+
tx = x
4615+
ty = y
46164616

46174617
if xscale == 'log':
46184618
if np.any(x <= 0.0):
4619-
raise ValueError("x contains non-positive values, so can not"
4620-
" be log-scaled")
4621-
x = np.log10(x)
4619+
raise ValueError("x contains non-positive values, so can not "
4620+
"be log-scaled")
4621+
tx = np.log10(tx)
46224622
if yscale == 'log':
46234623
if np.any(y <= 0.0):
4624-
raise ValueError("y contains non-positive values, so can not"
4625-
" be log-scaled")
4626-
y = np.log10(y)
4624+
raise ValueError("y contains non-positive values, so can not "
4625+
"be log-scaled")
4626+
ty = np.log10(ty)
46274627
if extent is not None:
46284628
xmin, xmax, ymin, ymax = extent
46294629
else:
4630-
xmin, xmax = (np.min(x), np.max(x)) if len(x) else (0, 1)
4631-
ymin, ymax = (np.min(y), np.max(y)) if len(y) else (0, 1)
4630+
xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
4631+
ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)
46324632

46334633
# to avoid issues with singular data, expand the min/max pairs
46344634
xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
46354635
ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)
46364636

4637+
nx1 = nx + 1
4638+
ny1 = ny + 1
4639+
nx2 = nx
4640+
ny2 = ny
4641+
n = nx1 * ny1 + nx2 * ny2
4642+
46374643
# In the x-direction, the hexagons exactly cover the region from
46384644
# xmin to xmax. Need some padding to avoid roundoff errors.
46394645
padding = 1.e-9 * (xmax - xmin)
46404646
xmin -= padding
46414647
xmax += padding
46424648
sx = (xmax - xmin) / nx
46434649
sy = (ymax - ymin) / ny
4644-
4645-
x = (x - xmin) / sx
4646-
y = (y - ymin) / sy
4647-
ix1 = np.round(x).astype(int)
4648-
iy1 = np.round(y).astype(int)
4649-
ix2 = np.floor(x).astype(int)
4650-
iy2 = np.floor(y).astype(int)
4651-
4652-
nx1 = nx + 1
4653-
ny1 = ny + 1
4654-
nx2 = nx
4655-
ny2 = ny
4656-
n = nx1 * ny1 + nx2 * ny2
4657-
4658-
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
4659-
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
4650+
# Positions in hexagon index coordinates.
4651+
ix = (tx - xmin) / sx
4652+
iy = (ty - ymin) / sy
4653+
ix1 = np.round(ix).astype(int)
4654+
iy1 = np.round(iy).astype(int)
4655+
ix2 = np.floor(ix).astype(int)
4656+
iy2 = np.floor(iy).astype(int)
4657+
# flat indices, plus one so that out-of-range points go to position 0.
4658+
i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
4659+
ix1 * ny1 + iy1 + 1, 0)
4660+
i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
4661+
ix2 * ny2 + iy2 + 1, 0)
4662+
4663+
d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
4664+
d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
46604665
bdist = (d1 < d2)
4661-
if C is None:
4662-
lattice1 = np.zeros((nx1, ny1))
4663-
lattice2 = np.zeros((nx2, ny2))
4664-
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
4665-
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
4666-
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
4667-
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
4668-
if mincnt is not None:
4669-
lattice1[lattice1 < mincnt] = np.nan
4670-
lattice2[lattice2 < mincnt] = np.nan
4671-
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
4672-
good_idxs = ~np.isnan(accum)
46734666

4667+
if C is None: # [1:] drops out-of-range points.
4668+
counts1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)[1:]
4669+
counts2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)[1:]
4670+
accum = np.concatenate([counts1, counts2]).astype(float)
4671+
if mincnt is not None:
4672+
accum[accum < mincnt] = np.nan
4673+
C = np.ones(len(x))
46744674
else:
4675-
if mincnt is None:
4676-
mincnt = 0
4677-
4678-
# create accumulation arrays
4679-
lattice1 = np.empty((nx1, ny1), dtype=object)
4680-
for i in range(nx1):
4681-
for j in range(ny1):
4682-
lattice1[i, j] = []
4683-
lattice2 = np.empty((nx2, ny2), dtype=object)
4684-
for i in range(nx2):
4685-
for j in range(ny2):
4686-
lattice2[i, j] = []
4687-
4675+
# store the C values in a list per hexagon index
4676+
Cs_at_i1 = [[] for _ in range(1 + nx1 * ny1)]
4677+
Cs_at_i2 = [[] for _ in range(1 + nx2 * ny2)]
46884678
for i in range(len(x)):
46894679
if bdist[i]:
4690-
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
4691-
lattice1[ix1[i], iy1[i]].append(C[i])
4680+
Cs_at_i1[i1[i]].append(C[i])
46924681
else:
4693-
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
4694-
lattice2[ix2[i], iy2[i]].append(C[i])
4695-
4696-
for i in range(nx1):
4697-
for j in range(ny1):
4698-
vals = lattice1[i, j]
4699-
if len(vals) > mincnt:
4700-
lattice1[i, j] = reduce_C_function(vals)
4701-
else:
4702-
lattice1[i, j] = np.nan
4703-
for i in range(nx2):
4704-
for j in range(ny2):
4705-
vals = lattice2[i, j]
4706-
if len(vals) > mincnt:
4707-
lattice2[i, j] = reduce_C_function(vals)
4708-
else:
4709-
lattice2[i, j] = np.nan
4682+
Cs_at_i2[i2[i]].append(C[i])
4683+
if mincnt is None:
4684+
mincnt = 0
4685+
accum = np.array(
4686+
[reduce_C_function(acc) if len(acc) > mincnt else np.nan
4687+
for Cs_at_i in [Cs_at_i1, Cs_at_i2]
4688+
for acc in Cs_at_i[1:]], # [1:] drops out-of-range points.
4689+
float)
47104690

4711-
accum = np.concatenate([lattice1.astype(float).ravel(),
4712-
lattice2.astype(float).ravel()])
4713-
good_idxs = ~np.isnan(accum)
4691+
good_idxs = ~np.isnan(accum)
47144692

47154693
offsets = np.zeros((n, 2), float)
47164694
offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
@@ -4767,8 +4745,7 @@ def reduce_C_function(C: array) -> float
47674745
vmin = vmax = None
47684746
bins = None
47694747

4770-
# autoscale the norm with current accum values if it hasn't
4771-
# been set
4748+
# autoscale the norm with current accum values if it hasn't been set
47724749
if norm is not None:
47734750
if norm.vmin is None and norm.vmax is None:
47744751
norm.autoscale(accum)
@@ -4798,92 +4775,55 @@ def reduce_C_function(C: array) -> float
47984775
return collection
47994776

48004777
# Process marginals
4801-
if C is None:
4802-
C = np.ones(len(x))
4778+
bars = []
4779+
for zname, z, zmin, zmax, zscale, nbins in [
4780+
("x", x, xmin, xmax, xscale, nx),
4781+
("y", y, ymin, ymax, yscale, 2 * ny),
4782+
]:
48034783

4804-
def coarse_bin(x, y, bin_edges):
4805-
"""
4806-
Sort x-values into bins defined by *bin_edges*, then for all the
4807-
corresponding y-values in each bin use *reduce_c_function* to
4808-
compute the bin value.
4809-
"""
4810-
nbins = len(bin_edges) - 1
4811-
# Sort x-values into bins
4812-
bin_idxs = np.searchsorted(bin_edges, x) - 1
4813-
mus = np.zeros(nbins) * np.nan
4784+
if zscale == "log":
4785+
bin_edges = np.geomspace(zmin, zmax, nbins + 1)
4786+
else:
4787+
bin_edges = np.linspace(zmin, zmax, nbins + 1)
4788+
4789+
verts = np.empty((nbins, 4, 2))
4790+
verts[:, 0, 0] = verts[:, 1, 0] = bin_edges[:-1]
4791+
verts[:, 2, 0] = verts[:, 3, 0] = bin_edges[1:]
4792+
verts[:, 0, 1] = verts[:, 3, 1] = .00
4793+
verts[:, 1, 1] = verts[:, 2, 1] = .05
4794+
if zname == "y":
4795+
verts = verts[:, :, ::-1] # Swap x and y.
4796+
4797+
# Sort z-values into bins defined by bin_edges.
4798+
bin_idxs = np.searchsorted(bin_edges, z) - 1
4799+
values = np.empty(nbins)
48144800
for i in range(nbins):
4815-
# Get y-values for each bin
4816-
yi = y[bin_idxs == i]
4817-
if len(yi) > 0:
4818-
mus[i] = reduce_C_function(yi)
4819-
return mus
4820-
4821-
if xscale == 'log':
4822-
bin_edges = np.geomspace(xmin, xmax, nx + 1)
4823-
else:
4824-
bin_edges = np.linspace(xmin, xmax, nx + 1)
4825-
xcoarse = coarse_bin(xorig, C, bin_edges)
4826-
4827-
verts, values = [], []
4828-
for bin_left, bin_right, val in zip(
4829-
bin_edges[:-1], bin_edges[1:], xcoarse):
4830-
if np.isnan(val):
4831-
continue
4832-
verts.append([(bin_left, 0),
4833-
(bin_left, 0.05),
4834-
(bin_right, 0.05),
4835-
(bin_right, 0)])
4836-
values.append(val)
4837-
4838-
values = np.array(values)
4839-
trans = self.get_xaxis_transform(which='grid')
4840-
4841-
hbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4842-
4843-
hbar.set_array(values)
4844-
hbar.set_cmap(cmap)
4845-
hbar.set_norm(norm)
4846-
hbar.set_alpha(alpha)
4847-
hbar.update(kwargs)
4848-
self.add_collection(hbar, autolim=False)
4849-
4850-
if yscale == 'log':
4851-
bin_edges = np.geomspace(ymin, ymax, 2 * ny + 1)
4852-
else:
4853-
bin_edges = np.linspace(ymin, ymax, 2 * ny + 1)
4854-
ycoarse = coarse_bin(yorig, C, bin_edges)
4855-
4856-
verts, values = [], []
4857-
for bin_bottom, bin_top, val in zip(
4858-
bin_edges[:-1], bin_edges[1:], ycoarse):
4859-
if np.isnan(val):
4860-
continue
4861-
verts.append([(0, bin_bottom),
4862-
(0, bin_top),
4863-
(0.05, bin_top),
4864-
(0.05, bin_bottom)])
4865-
values.append(val)
4866-
4867-
values = np.array(values)
4868-
4869-
trans = self.get_yaxis_transform(which='grid')
4870-
4871-
vbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4872-
vbar.set_array(values)
4873-
vbar.set_cmap(cmap)
4874-
vbar.set_norm(norm)
4875-
vbar.set_alpha(alpha)
4876-
vbar.update(kwargs)
4877-
self.add_collection(vbar, autolim=False)
4878-
4879-
collection.hbar = hbar
4880-
collection.vbar = vbar
4801+
# Get C-values for each bin, and compute bin value with
4802+
# reduce_C_function.
4803+
ci = C[bin_idxs == i]
4804+
values[i] = reduce_C_function(ci) if len(ci) > 0 else np.nan
4805+
4806+
mask = ~np.isnan(values)
4807+
verts = verts[mask]
4808+
values = values[mask]
4809+
4810+
trans = getattr(self, f"get_{zname}axis_transform")(which="grid")
4811+
bar = mcoll.PolyCollection(
4812+
verts, transform=trans, edgecolors="face")
4813+
bar.set_array(values)
4814+
bar.set_cmap(cmap)
4815+
bar.set_norm(norm)
4816+
bar.set_alpha(alpha)
4817+
bar.update(kwargs)
4818+
bars.append(self.add_collection(bar, autolim=False))
4819+
4820+
collection.hbar, collection.vbar = bars
48814821

48824822
def on_changed(collection):
4883-
hbar.set_cmap(collection.get_cmap())
4884-
hbar.set_clim(collection.get_clim())
4885-
vbar.set_cmap(collection.get_cmap())
4886-
vbar.set_clim(collection.get_clim())
4823+
collection.hbar.set_cmap(collection.get_cmap())
4824+
collection.hbar.set_clim(collection.get_clim())
4825+
collection.vbar.set_cmap(collection.get_cmap())
4826+
collection.vbar.set_clim(collection.get_clim())
48874827

48884828
collection.callbacks.connect('changed', on_changed)
48894829

0 commit comments

Comments
 (0)