Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a896c42

Browse files
QuLogicmeeseeksmachine
authored andcommitted
Backport PR #19964: FIX: add subplot_mosaic axes in the order the user gave them to us
1 parent a14a8fe commit a896c42

File tree

2 files changed

+83
-24
lines changed

2 files changed

+83
-24
lines changed

lib/matplotlib/figure.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,11 +1791,12 @@ def _identify_keys_and_nested(layout):
17911791
17921792
Returns
17931793
-------
1794-
unique_ids : set
1794+
unique_ids : tuple
17951795
The unique non-sub layout entries in this layout
17961796
nested : dict[tuple[int, int]], 2D object array
17971797
"""
1798-
unique_ids = set()
1798+
# make sure we preserve the user supplied order
1799+
unique_ids = cbook._OrderedSet()
17991800
nested = {}
18001801
for j, row in enumerate(layout):
18011802
for k, v in enumerate(row):
@@ -1806,7 +1807,7 @@ def _identify_keys_and_nested(layout):
18061807
else:
18071808
unique_ids.add(v)
18081809

1809-
return unique_ids, nested
1810+
return tuple(unique_ids), nested
18101811

18111812
def _do_layout(gs, layout, unique_ids, nested):
18121813
"""
@@ -1817,7 +1818,7 @@ def _do_layout(gs, layout, unique_ids, nested):
18171818
gs : GridSpec
18181819
layout : 2D object array
18191820
The input converted to a 2D numpy array for this level.
1820-
unique_ids : set
1821+
unique_ids : tuple
18211822
The identified scalar labels at this level of nesting.
18221823
nested : dict[tuple[int, int]], 2D object array
18231824
The identified nested layouts, if any.
@@ -1830,38 +1831,74 @@ def _do_layout(gs, layout, unique_ids, nested):
18301831
rows, cols = layout.shape
18311832
output = dict()
18321833

1833-
# create the Axes at this level of nesting
1834+
# we need to merge together the Axes at this level and the axes
1835+
# in the (recursively) nested sub-layouts so that we can add
1836+
# them to the figure in the "natural" order if you were to
1837+
# ravel in c-order all of the Axes that will be created
1838+
#
1839+
# This will stash the upper left index of each object (axes or
1840+
# nested layout) at this level
1841+
this_level = dict()
1842+
1843+
# go through the unique keys,
18341844
for name in unique_ids:
1845+
# sort out where each axes starts/ends
18351846
indx = np.argwhere(layout == name)
18361847
start_row, start_col = np.min(indx, axis=0)
18371848
end_row, end_col = np.max(indx, axis=0) + 1
1849+
# and construct the slice object
18381850
slc = (slice(start_row, end_row), slice(start_col, end_col))
1839-
1851+
# some light error checking
18401852
if (layout[slc] != name).any():
18411853
raise ValueError(
18421854
f"While trying to layout\n{layout!r}\n"
18431855
f"we found that the label {name!r} specifies a "
18441856
"non-rectangular or non-contiguous area.")
1857+
# and stash this slice for later
1858+
this_level[(start_row, start_col)] = (name, slc, 'axes')
18451859

1846-
ax = self.add_subplot(
1847-
gs[slc], **{'label': str(name), **subplot_kw}
1848-
)
1849-
output[name] = ax
1850-
1851-
# do any sub-layouts
1860+
# do the same thing for the nested layouts (simpler because these
1861+
# can not be spans yet!)
18521862
for (j, k), nested_layout in nested.items():
1853-
rows, cols = nested_layout.shape
1854-
nested_output = _do_layout(
1855-
gs[j, k].subgridspec(rows, cols, **gridspec_kw),
1856-
nested_layout,
1857-
*_identify_keys_and_nested(nested_layout)
1858-
)
1859-
overlap = set(output) & set(nested_output)
1860-
if overlap:
1861-
raise ValueError(f"There are duplicate keys {overlap} "
1862-
f"between the outer layout\n{layout!r}\n"
1863-
f"and the nested layout\n{nested_layout}")
1864-
output.update(nested_output)
1863+
this_level[(j, k)] = (None, nested_layout, 'nested')
1864+
1865+
# now go through the things in this level and add them
1866+
# in order left-to-right top-to-bottom
1867+
for key in sorted(this_level):
1868+
name, arg, method = this_level[key]
1869+
# we are doing some hokey function dispatch here based
1870+
# on the 'method' string stashed above to sort out if this
1871+
# element is an axes or a nested layout.
1872+
if method == 'axes':
1873+
slc = arg
1874+
# add a single axes
1875+
if name in output:
1876+
raise ValueError(f"There are duplicate keys {name} "
1877+
f"in the layout\n{layout!r}")
1878+
ax = self.add_subplot(
1879+
gs[slc], **{'label': str(name), **subplot_kw}
1880+
)
1881+
output[name] = ax
1882+
elif method == 'nested':
1883+
nested_layout = arg
1884+
j, k = key
1885+
# recursively add the nested layout
1886+
rows, cols = nested_layout.shape
1887+
nested_output = _do_layout(
1888+
gs[j, k].subgridspec(rows, cols, **gridspec_kw),
1889+
nested_layout,
1890+
*_identify_keys_and_nested(nested_layout)
1891+
)
1892+
overlap = set(output) & set(nested_output)
1893+
if overlap:
1894+
raise ValueError(
1895+
f"There are duplicate keys {overlap} "
1896+
f"between the outer layout\n{layout!r}\n"
1897+
f"and the nested layout\n{nested_layout}"
1898+
)
1899+
output.update(nested_output)
1900+
else:
1901+
raise RuntimeError("This should never happen")
18651902
return output
18661903

18671904
layout = _make_array(layout)

lib/matplotlib/tests/test_figure.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,28 @@ def test_hashable_keys(self, fig_test, fig_ref):
861861
fig_test.subplot_mosaic([[object(), object()]])
862862
fig_ref.subplot_mosaic([["A", "B"]])
863863

864+
@pytest.mark.parametrize('str_pattern',
865+
['abc', 'cab', 'bca', 'cba', 'acb', 'bac'])
866+
def test_user_order(self, str_pattern):
867+
fig = plt.figure()
868+
ax_dict = fig.subplot_mosaic(str_pattern)
869+
assert list(str_pattern) == list(ax_dict)
870+
assert list(fig.axes) == list(ax_dict.values())
871+
872+
def test_nested_user_order(self):
873+
layout = [
874+
["A", [["B", "C"],
875+
["D", "E"]]],
876+
["F", "G"],
877+
[".", [["H", [["I"],
878+
["."]]]]]
879+
]
880+
881+
fig = plt.figure()
882+
ax_dict = fig.subplot_mosaic(layout)
883+
assert list(ax_dict) == list("ABCDEFGHI")
884+
assert list(fig.axes) == list(ax_dict.values())
885+
864886

865887
def test_reused_gridspec():
866888
"""Test that these all use the same gridspec"""

0 commit comments

Comments
 (0)