Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a2d4d9a

Browse files
committed
ENH/WIP: make it work recursively
1 parent 63e5ba3 commit a2d4d9a

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

lib/matplotlib/figure.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from matplotlib import rcParams
1919
from matplotlib import docstring, projections
2020
from matplotlib import __version__ as _mpl_version
21-
import matplotlib.gridspec as gridspec
2221

2322
import matplotlib.artist as martist
2423
from matplotlib.artist import Artist, allow_rasterization
@@ -30,7 +29,7 @@
3029

3130
from matplotlib.axes import Axes, SubplotBase, subplot_class_factory
3231
from matplotlib.blocking_input import BlockingMouseInput, BlockingKeyMouseInput
33-
from matplotlib.gridspec import GridSpec
32+
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
3433
import matplotlib.legend as mlegend
3534
from matplotlib.patches import Rectangle
3635
from matplotlib.text import Text
@@ -1617,25 +1616,47 @@ def build_grid(self, layout, subplot_kw=None, gridspec_kw=None):
16171616
"""
16181617
subplot_kw = subplot_kw or {}
16191618
gridspec_kw = gridspec_kw or {}
1619+
1620+
def _process_layout(layout):
1621+
unique_ids = set()
1622+
nested = {}
1623+
for j, row in enumerate(layout):
1624+
for k, v in enumerate(row):
1625+
if isinstance(v, str):
1626+
unique_ids.add(v)
1627+
else:
1628+
nested[(j, k)] = v
1629+
1630+
return unique_ids, nested
1631+
1632+
def _do_layout(gs, layout, unique_ids, nested):
1633+
rows, cols = layout.shape
1634+
covered = np.zeros((rows, cols), dtype=bool)
1635+
output = dict()
1636+
for name in unique_ids:
1637+
indx = np.stack(np.where(layout == name)).T
1638+
start_row, start_col = np.min(indx, axis=0)
1639+
end_row, end_col = np.max(indx, axis=0) + 1
1640+
slc = (slice(start_row, end_row), slice(start_col, end_col))
1641+
if not np.all(covered[slc] == False):
1642+
raise ValueError
1643+
covered[slc] = True
1644+
output[name] = self.add_subplot(gs[slc], **subplot_kw)
1645+
1646+
for (j, k), layout in nested.items():
1647+
layout = np.asarray(layout)
1648+
rows, cols = layout.shape
1649+
gs_n = GridSpecFromSubplotSpec(rows, cols, gs[j, k])
1650+
nested_output = _do_layout(gs_n, layout, *_process_layout(layout))
1651+
1652+
output.update(nested_output)
1653+
return output
1654+
16201655
layout = np.asarray(layout)
16211656
rows, cols = layout.shape
1622-
unique_ids = set(layout.ravel())
1623-
1624-
gs = gridspec.GridSpec(rows, cols, figure=self, **gridspec_kw)
1625-
1626-
covered = np.zeros((rows, cols), dtype=bool)
1627-
output = dict()
1628-
for name in unique_ids:
1629-
indx = np.stack(np.where(layout == name)).T
1630-
start_row, start_col = np.min(indx, axis=0)
1631-
end_row, end_col = np.max(indx, axis=0) + 1
1632-
slc = (slice(start_row, end_row), slice(start_col, end_col))
1633-
if not np.all(covered[slc] == False):
1634-
raise ValueError
1635-
covered[slc] = True
1636-
output[name] = self.add_subplot(gs[slc], **subplot_kw)
1637-
1638-
return output
1657+
gs = GridSpec(rows, cols, figure=self, **gridspec_kw)
1658+
return _do_layout(gs, layout, *_process_layout(layout))
1659+
16391660

16401661
def delaxes(self, ax):
16411662
"""

lib/matplotlib/tests/test_build_grid.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
@check_figures_equal()
6-
def test_label_loc_vertical(fig_test, fig_ref):
6+
def test_basic(fig_test, fig_ref):
77
x = [["A", "A", "B"], ["C", "D", "B"]]
88
grid_axes = fig_test.build_grid(x)
99

@@ -22,3 +22,36 @@ def test_label_loc_vertical(fig_test, fig_ref):
2222

2323
axD = fig_ref.add_subplot(gs[1, 1])
2424
axD.set_title("D")
25+
26+
27+
@check_figures_equal(tol=0.005)
28+
def test_recursive(fig_test, fig_ref):
29+
x = [["A", "B", "B"], ["C", "C", "D"]]
30+
31+
y = [["F"], [x]]
32+
33+
grid_axes = fig_test.build_grid(y)
34+
35+
for k, ax in grid_axes.items():
36+
ax.set_title(k)
37+
38+
gs = gridspec.GridSpec(2, 1, figure=fig_ref)
39+
axF = fig_ref.add_subplot(gs[0, 0])
40+
axF.set_title("F")
41+
42+
gs_n = gridspec.GridSpecFromSubplotSpec(2, 3, gs[1, 0])
43+
44+
axA = fig_ref.add_subplot(gs_n[0, 0])
45+
axA.set_title("A")
46+
47+
axB = fig_ref.add_subplot(gs_n[0, 1:])
48+
axB.set_title("B")
49+
50+
axC = fig_ref.add_subplot(gs_n[1, :2])
51+
axC.set_title("C")
52+
53+
axD = fig_ref.add_subplot(gs_n[1, 2])
54+
axD.set_title("D")
55+
56+
fig_test.tight_layout()
57+
fig_ref.tight_layout()

0 commit comments

Comments
 (0)