diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 6d28fae6c2d9..0ce2f52acf54 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -1530,14 +1530,31 @@ def set_units(self, u): Parameters ---------- u : units tag + + Notes + ----- + The units of any shared axis will also be updated. """ if u == self.units: return - self.units = u - self._update_axisinfo() - self.callbacks.process('units') - self.callbacks.process('units finalize') - self.stale = True + if self is self.axes.xaxis: + shared = [ + ax.xaxis + for ax in self.axes.get_shared_x_axes().get_siblings(self.axes) + ] + elif self is self.axes.yaxis: + shared = [ + ax.yaxis + for ax in self.axes.get_shared_y_axes().get_siblings(self.axes) + ] + else: + shared = [self] + for axis in shared: + axis.units = u + axis._update_axisinfo() + axis.callbacks.process('units') + axis.callbacks.process('units finalize') + axis.stale = True def get_units(self): """Return the units for axis.""" diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 252136b4dfb4..3f40a99a2f5a 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -1,10 +1,11 @@ -from datetime import datetime +from datetime import datetime, timezone, timedelta import platform from unittest.mock import MagicMock import matplotlib.pyplot as plt from matplotlib.testing.decorators import check_figures_equal, image_comparison import matplotlib.units as munits +from matplotlib.category import UnitData import numpy as np import pytest @@ -127,12 +128,12 @@ def test_jpl_bar_units(): units.register() day = units.Duration("ET", 24.0 * 60.0 * 60.0) - x = [0*units.km, 1*units.km, 2*units.km] - w = [1*day, 2*day, 3*day] + x = [0 * units.km, 1 * units.km, 2 * units.km] + w = [1 * day, 2 * day, 3 * day] b = units.Epoch("ET", dt=datetime(2009, 4, 25)) fig, ax = plt.subplots() ax.bar(x, w, bottom=b) - ax.set_ylim([b-1*day, b+w[-1]+(1.001)*day]) + ax.set_ylim([b - 1 * day, b + w[-1] + (1.001) * day]) @image_comparison(['jpl_barh_units.png'], @@ -142,13 +143,13 @@ def test_jpl_barh_units(): units.register() day = units.Duration("ET", 24.0 * 60.0 * 60.0) - x = [0*units.km, 1*units.km, 2*units.km] - w = [1*day, 2*day, 3*day] + x = [0 * units.km, 1 * units.km, 2 * units.km] + w = [1 * day, 2 * day, 3 * day] b = units.Epoch("ET", dt=datetime(2009, 4, 25)) fig, ax = plt.subplots() ax.barh(x, w, left=b) - ax.set_xlim([b-1*day, b+w[-1]+(1.001)*day]) + ax.set_xlim([b - 1 * day, b + w[-1] + (1.001) * day]) def test_empty_arrays(): @@ -172,3 +173,41 @@ class subdate(datetime): fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o") fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o") + + +def test_shared_axis_quantity(quantity_converter): + munits.registry[Quantity] = quantity_converter + x = Quantity(np.linspace(0, 1, 10), "hours") + y1 = Quantity(np.linspace(1, 2, 10), "feet") + y2 = Quantity(np.linspace(3, 4, 10), "feet") + fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all') + ax1.plot(x, y1) + ax2.plot(x, y2) + assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours" + assert ax2.yaxis.get_units() == ax2.yaxis.get_units() == "feet" + ax1.xaxis.set_units("seconds") + ax2.yaxis.set_units("inches") + assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds" + assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches" + + +def test_shared_axis_datetime(): + # datetime uses dates.DateConverter + y1 = [datetime(2020, i, 1, tzinfo=timezone.utc) for i in range(1, 13)] + y2 = [datetime(2021, i, 1, tzinfo=timezone.utc) for i in range(1, 13)] + fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True) + ax1.plot(y1) + ax2.plot(y2) + ax1.yaxis.set_units(timezone(timedelta(hours=5))) + assert ax2.yaxis.units == timezone(timedelta(hours=5)) + + +def test_shared_axis_categorical(): + # str uses category.StrCategoryConverter + d1 = {"a": 1, "b": 2} + d2 = {"a": 3, "b": 4} + fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True) + ax1.plot(d1.keys(), d1.values()) + ax2.plot(d2.keys(), d2.values()) + ax1.xaxis.set_units(UnitData(["c", "d"])) + assert "c" in ax2.xaxis.get_units()._mapping.keys()