diff --git a/lib/matplotlib/__init__.py b/lib/matplotlib/__init__.py index 9193d179e57c..071ffe94b0d6 100644 --- a/lib/matplotlib/__init__.py +++ b/lib/matplotlib/__init__.py @@ -1442,6 +1442,7 @@ def tk_window_focus(): 'matplotlib.tests.test_widgets', 'mpl_toolkits.tests.test_mplot3d', 'mpl_toolkits.tests.test_axes_grid1', + 'mpl_toolkits.tests.test_axes_grid', ] diff --git a/lib/mpl_toolkits/axes_grid1/axes_grid.py b/lib/mpl_toolkits/axes_grid1/axes_grid.py index 1b996acca04f..6af68528be38 100644 --- a/lib/mpl_toolkits/axes_grid1/axes_grid.py +++ b/lib/mpl_toolkits/axes_grid1/axes_grid.py @@ -640,6 +640,13 @@ def __init__(self, fig, if self._colorbar_mode == "single": for ax in self.axes_all: ax.cax = self.cbar_axes[0] + elif self._colorbar_mode == "edge": + for index, ax in enumerate(self.axes_all): + col, row = self._get_col_row(index) + if self._colorbar_location in ("left", "right"): + ax.cax = self.cbar_axes[row] + else: + ax.cax = self.cbar_axes[col] else: for ax, cax in zip(self.axes_all, self.cbar_axes): ax.cax = cax diff --git a/lib/mpl_toolkits/tests/baseline_images/test_axes_grid/imagegrid_cbar_mode.png b/lib/mpl_toolkits/tests/baseline_images/test_axes_grid/imagegrid_cbar_mode.png new file mode 100644 index 000000000000..09dfd7ddbbaa Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_axes_grid/imagegrid_cbar_mode.png differ diff --git a/lib/mpl_toolkits/tests/test_axes_grid.py b/lib/mpl_toolkits/tests/test_axes_grid.py new file mode 100644 index 000000000000..80db9e3d1314 --- /dev/null +++ b/lib/mpl_toolkits/tests/test_axes_grid.py @@ -0,0 +1,42 @@ + +from matplotlib.testing.decorators import image_comparison +from mpl_toolkits.axes_grid1 import ImageGrid +import numpy as np +import matplotlib.pyplot as plt + + +@image_comparison(baseline_images=['imagegrid_cbar_mode'], + extensions=['png'], + remove_text=True) +def test_imagegrid_cbar_mode_edge(): + X, Y = np.meshgrid(np.linspace(0, 6, 30), np.linspace(0, 6, 30)) + arr = np.sin(X) * np.cos(Y) + 1j*(np.sin(3*Y) * np.cos(Y/2.)) + + fig = plt.figure(figsize=(18, 9)) + + positions = (241, 242, 243, 244, 245, 246, 247, 248) + directions = ['row']*4 + ['column']*4 + cbar_locations = ['left', 'right', 'top', 'bottom']*2 + + for position, direction, location in zip(positions, + directions, + cbar_locations): + grid = ImageGrid(fig, position, + nrows_ncols=(2, 2), + direction=direction, + cbar_location=location, + cbar_size='20%', + cbar_mode='edge') + ax1, ax2, ax3, ax4, = grid + + im1 = ax1.imshow(arr.real, cmap='spectral') + im2 = ax2.imshow(arr.imag, cmap='hot') + im3 = ax3.imshow(np.abs(arr), cmap='jet') + im4 = ax4.imshow(np.arctan2(arr.imag, arr.real), cmap='hsv') + + # Some of these colorbars will be overridden by later ones, + # depending on the direction and cbar_location + ax1.cax.colorbar(im1) + ax2.cax.colorbar(im2) + ax3.cax.colorbar(im3) + ax4.cax.colorbar(im4)