Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Colorbar inherit from Axes #20350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions doc/api/next_api_changes/behavior/20XXX-GL.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Colorbars are now an instance of Axes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`.colorbar.Colorbar` class now inherits from `.axes.Axes`,
meaning that all of the standard methods of ``Axes`` can be used
directly on the colorbar object itself rather than having to access the
``ax`` attribute. For example, ::

cbar.set_yticks()

rather than ::

cbar.ax.set_yticks()

We are leaving the ``cbar.ax`` attribute in place as a pass-through for now,
which just maps back to the colorbar object.
4 changes: 2 additions & 2 deletions doc/api/prev_api_changes/api_changes_2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ The ``shading`` kwarg to `~matplotlib.axes.Axes.pcolor` has been
removed. Set ``edgecolors`` appropriately instead.


Functions removed from the `.lines` module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Functions removed from the `matplotlib.lines` module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :mod:`matplotlib.lines` module no longer imports the
``pts_to_prestep``, ``pts_to_midstep`` and ``pts_to_poststep``
Expand Down
2 changes: 1 addition & 1 deletion examples/axes_grid1/demo_axes_divider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def demo_simple_image(ax):

im = ax.imshow(Z, extent=extent)
cb = plt.colorbar(im)
plt.setp(cb.ax.get_yticklabels(), visible=False)
plt.setp(cb.get_yticklabels(), visible=False)


def demo_locatable_axes_hard(fig):
Expand Down
4 changes: 2 additions & 2 deletions examples/images_contours_and_fields/contour_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@
# so let's improve its position.

l, b, w, h = ax.get_position().bounds
ll, bb, ww, hh = CB.ax.get_position().bounds
CB.ax.set_position([ll, b + 0.1*h, ww, h*0.8])
ll, bb, ww, hh = CB.get_position().bounds
CB.set_position([ll, b + 0.1*h, ww, h*0.8])

plt.show()

Expand Down
2 changes: 1 addition & 1 deletion examples/images_contours_and_fields/contourf_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

# Make a colorbar for the ContourSet returned by the contourf call.
cbar = fig1.colorbar(CS)
cbar.ax.set_ylabel('verbosity coefficient')
cbar.set_ylabel('verbosity coefficient')
# Add the contour line levels to the colorbar
cbar.add_lines(CS2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def heatmap(data, row_labels, col_labels, ax=None,

# Create colorbar
cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
cbar.set_ylabel(cbarlabel, rotation=-90, va="bottom")

# We want to show all ticks...
ax.set_xticks(np.arange(data.shape[1]))
Expand Down
4 changes: 2 additions & 2 deletions examples/ticks_and_spines/colorbar_tick_labelling_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

# Add colorbar, make sure to specify tick locations to match desired ticklabels
cbar = fig.colorbar(cax, ticks=[-1, 0, 1])
cbar.ax.set_yticklabels(['< -1', '0', '> 1']) # vertically oriented colorbar
cbar.set_yticklabels(['< -1', '0', '> 1']) # vertically oriented colorbar

###############################################################################
# Make plot with horizontal colorbar
Expand All @@ -42,6 +42,6 @@
ax.set_title('Gaussian noise with horizontal colorbar')

cbar = fig.colorbar(cax, ticks=[-1, 0, 1], orientation='horizontal')
cbar.ax.set_xticklabels(['Low', 'Medium', 'High']) # horizontal colorbar
cbar.set_xticklabels(['Low', 'Medium', 'High']) # horizontal colorbar

plt.show()
102 changes: 49 additions & 53 deletions lib/matplotlib/colorbar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Colorbars are a visualization of the mapping from scalar values to colors.
In Matplotlib they are drawn into a dedicated `~.axes.Axes`.
In Matplotlib they are a separate dedicated `~.axes.Axes`.

.. note::
Colorbars are typically created through `.Figure.colorbar` or its pyplot
Expand Down Expand Up @@ -233,7 +233,7 @@ class ColorbarAxes(Axes):
over/under colors.

Users should not normally instantiate this class, but it is the class
returned by ``cbar = fig.colorbar(im); cax = cbar.ax``.
that the Colorbar inherits from by ``cbar = fig.colorbar(im);``.
"""
def __init__(self, parent, userax=True):
"""
Expand Down Expand Up @@ -310,7 +310,7 @@ def draw(self, renderer):
return ret


class ColorbarBase:
class ColorbarBase(ColorbarAxes):
r"""
Draw a colorbar in an existing axes.

Expand Down Expand Up @@ -339,8 +339,6 @@ class ColorbarBase:

Attributes
----------
ax : `~matplotlib.axes.Axes`
The `~.axes.Axes` instance in which the colorbar is drawn.
lines : list
A list of `.LineCollection` (empty if no lines were drawn).
dividers : `.LineCollection`
Expand Down Expand Up @@ -384,7 +382,7 @@ class ColorbarBase:
label : str

userax : boolean
Whether the user created the axes or not. Default True
Whether the user created the axes or not. Default False
"""

n_rasterize = 50 # rasterize solids if number of colors >= n_rasterize
Expand Down Expand Up @@ -417,9 +415,9 @@ def __init__(self, ax, *, cmap=None,
['uniform', 'proportional'], spacing=spacing)

# wrap the axes so that it can be positioned as an inset axes:
ax = ColorbarAxes(ax, userax=userax)
self.ax = ax
ax.set(navigate=False)
super().__init__(ax, userax=userax)
self.ax = self
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest making this a property instead: it would make room to add a warning when this "attribute" is deprecated.

Copy link
Member

@timhoffm timhoffm Aug 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can always transparently turn an attribute into a property later if we want to add a warning. That's one of the great features of properties.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right. I guess you could see this as premature :)

self.set(navigate=False)

if cmap is None:
cmap = cm.get_cmap()
Expand Down Expand Up @@ -448,24 +446,19 @@ def __init__(self, ax, *, cmap=None,
self.extendrect = extendrect
self.solids = None
self.solids_patches = []
self.lines = []
self._lines = []

for spine in self.ax.spines.values():
for spine in self.spines.values():
spine.set_visible(False)
for spine in self.ax.outer_ax.spines.values():
for spine in self.outer_ax.spines.values():
spine.set_visible(False)
self.outline = self.ax.spines['outline'] = _ColorbarSpine(self.ax)

self.patch = mpatches.Polygon(
np.empty((0, 2)),
color=mpl.rcParams['axes.facecolor'], linewidth=0.01, zorder=-1)
ax.add_artist(self.patch)
self.outline = self.spines['outline'] = _ColorbarSpine(self)

self.dividers = collections.LineCollection(
[],
colors=[mpl.rcParams['axes.edgecolor']],
linewidths=[0.5 * mpl.rcParams['axes.linewidth']])
self.ax.add_collection(self.dividers)
self.add_collection(self.dividers)

self.locator = None
self.formatter = None
Expand All @@ -489,6 +482,10 @@ def __init__(self, ax, *, cmap=None,
self.formatter = format # Assume it is a Formatter or None
self.draw_all()

@property
def lines(self):
return self._lines

def draw_all(self):
"""
Calculate any free parameters based on the current cmap and norm,
Expand Down Expand Up @@ -519,8 +516,8 @@ def draw_all(self):
# also adds the outline path to self.outline spine:
self._do_extends(extendlen)

self.ax.set_xlim(self.vmin, self.vmax)
self.ax.set_ylim(self.vmin, self.vmax)
self.set_xlim(self.vmin, self.vmax)
self.set_ylim(self.vmin, self.vmax)

# set up the tick locators and formatters. A bit complicated because
# boundary norms + uniform spacing requires a manual locator.
Expand Down Expand Up @@ -548,7 +545,7 @@ def _add_solids(self, X, Y, C):
and any(hatch is not None for hatch in mappable.hatches)):
self._add_solids_patches(X, Y, C, mappable)
else:
self.solids = self.ax.pcolormesh(
self.solids = self.pcolormesh(
X, Y, C, cmap=self.cmap, norm=self.norm, alpha=self.alpha,
edgecolors='none', shading='flat')
if not self.drawedges:
Expand All @@ -569,7 +566,7 @@ def _add_solids_patches(self, X, Y, C, mappable):
facecolor=self.cmap(self.norm(C[i][0])),
hatch=hatches[i], linewidth=0,
antialiased=False, alpha=self.alpha)
self.ax.add_patch(patch)
self.add_patch(patch)
patches.append(patch)
self.solids_patches = patches

Expand Down Expand Up @@ -605,7 +602,7 @@ def _do_extends(self, extendlen):
if self.orientation == 'horizontal':
bounds = bounds[[1, 0, 3, 2]]
xyout = xyout[:, ::-1]
self.ax._set_inner_bounds(bounds)
self._set_inner_bounds(bounds)

# xyout is the path for the spine:
self.outline.set_xy(xyout)
Expand Down Expand Up @@ -634,9 +631,9 @@ def _do_extends(self, extendlen):
color = self.cmap(self.norm(self._values[0]))
patch = mpatches.PathPatch(
mpath.Path(xy), facecolor=color, linewidth=0,
antialiased=False, transform=self.ax.outer_ax.transAxes,
antialiased=False, transform=self.outer_ax.transAxes,
hatch=hatches[0])
self.ax.outer_ax.add_patch(patch)
self.outer_ax.add_patch(patch)
if self._extend_upper():
if not self.extendrect:
# triangle
Expand All @@ -651,8 +648,8 @@ def _do_extends(self, extendlen):
patch = mpatches.PathPatch(
mpath.Path(xy), facecolor=color,
linewidth=0, antialiased=False,
transform=self.ax.outer_ax.transAxes, hatch=hatches[-1])
self.ax.outer_ax.add_patch(patch)
transform=self.outer_ax.transAxes, hatch=hatches[-1])
self.outer_ax.add_patch(patch)
return

def add_lines(self, levels, colors, linewidths, erase=True):
Expand Down Expand Up @@ -699,25 +696,24 @@ def add_lines(self, levels, colors, linewidths, erase=True):
# make a clip path that is just a linewidth bigger than the axes...
fac = np.max(linewidths) / 72
xy = np.array([[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]])
inches = self.ax.get_figure().dpi_scale_trans
inches = self.get_figure().dpi_scale_trans
# do in inches:
xy = inches.inverted().transform(self.ax.transAxes.transform(xy))
xy = inches.inverted().transform(self.transAxes.transform(xy))
xy[[0, 1, 4], 1] -= fac
xy[[2, 3], 1] += fac
# back to axes units...
xy = self.ax.transAxes.inverted().transform(inches.transform(xy))
xy = self.transAxes.inverted().transform(inches.transform(xy))
if self.orientation == 'horizontal':
xy = xy.T
col.set_clip_path(mpath.Path(xy, closed=True),
self.ax.transAxes)
self.ax.add_collection(col)
self.transAxes)
self.add_collection(col)
self.stale = True

def update_ticks(self):
"""
Setup the ticks and ticklabels. This should not be needed by users.
"""
ax = self.ax
# Get the locator and formatter; defaults to self.locator if not None.
self._get_ticker_locator_formatter()
self._long_axis().set_major_locator(self.locator)
Expand Down Expand Up @@ -832,7 +828,7 @@ def minorticks_on(self):
"""
Turn on colorbar minor ticks.
"""
self.ax.minorticks_on()
super().minorticks_on()
self.minorlocator = self._long_axis().get_minor_locator()
self._short_axis().set_minor_locator(ticker.NullLocator())

Expand Down Expand Up @@ -863,9 +859,9 @@ def set_label(self, label, *, loc=None, **kwargs):
Supported keywords are *labelpad* and `.Text` properties.
"""
if self.orientation == "vertical":
self.ax.set_ylabel(label, loc=loc, **kwargs)
self.set_ylabel(label, loc=loc, **kwargs)
else:
self.ax.set_xlabel(label, loc=loc, **kwargs)
self.set_xlabel(label, loc=loc, **kwargs)
self.stale = True

def set_alpha(self, alpha):
Expand All @@ -874,8 +870,8 @@ def set_alpha(self, alpha):

def remove(self):
"""Remove this colorbar from the figure."""
self.ax.inner_ax.remove()
self.ax.outer_ax.remove()
self.inner_ax.remove()
self.outer_ax.remove()

def _ticker(self, locator, formatter):
"""
Expand Down Expand Up @@ -1009,29 +1005,29 @@ def _reset_locator_formatter_scale(self):
isinstance(self.norm, colors.BoundaryNorm)):
if self.spacing == 'uniform':
funcs = (self._forward_boundaries, self._inverse_boundaries)
self.ax.set_xscale('function', functions=funcs)
self.ax.set_yscale('function', functions=funcs)
self.set_xscale('function', functions=funcs)
self.set_yscale('function', functions=funcs)
self.__scale = 'function'
elif self.spacing == 'proportional':
self.__scale = 'linear'
self.ax.set_xscale('linear')
self.ax.set_yscale('linear')
self.set_xscale('linear')
self.set_yscale('linear')
elif hasattr(self.norm, '_scale') and self.norm._scale is not None:
# use the norm's scale:
self.ax.set_xscale(self.norm._scale)
self.ax.set_yscale(self.norm._scale)
self.set_xscale(self.norm._scale)
self.set_yscale(self.norm._scale)
self.__scale = self.norm._scale.name
elif type(self.norm) is colors.Normalize:
# plain Normalize:
self.ax.set_xscale('linear')
self.ax.set_yscale('linear')
self.set_xscale('linear')
self.set_yscale('linear')
self.__scale = 'linear'
else:
# norm._scale is None or not an attr: derive the scale from
# the Norm:
funcs = (self.norm, self.norm.inverse)
self.ax.set_xscale('function', functions=funcs)
self.ax.set_yscale('function', functions=funcs)
self.set_xscale('function', functions=funcs)
self.set_yscale('function', functions=funcs)
self.__scale = 'function'

def _locate(self, x):
Expand Down Expand Up @@ -1140,14 +1136,14 @@ def _extend_upper(self):
def _long_axis(self):
"""Return the long axis"""
if self.orientation == 'vertical':
return self.ax.yaxis
return self.ax.xaxis
return self.yaxis
return self.xaxis

def _short_axis(self):
"""Return the short axis"""
if self.orientation == 'vertical':
return self.ax.xaxis
return self.ax.yaxis
return self.xaxis
return self.yaxis


class Colorbar(ColorbarBase):
Expand Down
Loading