From f77c0a57fbeb3059046dea72b724e63954d6f424 Mon Sep 17 00:00:00 2001 From: Lee Johnston Date: Thu, 20 Aug 2020 10:47:09 -0500 Subject: [PATCH 1/4] Add set_xunits and set_yunits --- lib/matplotlib/axes/_base.py | 34 ++++++++++++++++++++++++++++++ lib/matplotlib/tests/test_units.py | 25 ++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 359adb297770..260d1d1cb698 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -4305,3 +4305,37 @@ def get_shared_x_axes(self): def get_shared_y_axes(self): """Return a reference to the shared axes Grouper object for y axes.""" return self._shared_y_axes + + def set_xunits(self, units, emit=True): + """ + Set the x-axis units. + + Parameters + ---------- + units : units tag + + emit : bool, default: True + Whether to notify observers of units change. + """ + if emit: + for ax in self._shared_x_axes.get_siblings(self): + ax.xaxis.set_units(units) + else: + self.xaxis.set_units(units) + + def set_yunits(self, units, emit=True): + """ + Set the y-axis units. + + Parameters + ---------- + units : units tag + + emit : bool, default: True + Whether to notify observers of units change. + """ + if emit: + for ax in self._shared_y_axes.get_siblings(self): + ax.yaxis.set_units(units) + else: + self.yaxis.set_units(units) diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 252136b4dfb4..3060b2dd98e6 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -172,3 +172,28 @@ class subdate(datetime): fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o") fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o") + + +def test_set_xyunits(quantity_converter): + munits.registry[Quantity] = quantity_converter + x = Quantity(np.linspace(0, 1, 10), "hours") + y1 = Quantity(np.linspace(1, 2, 10), "feet") + y2 = Quantity(np.linspace(3, 4, 10), "feet") + fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all') + ax1.plot(x, y1) + ax2.plot(x, y2) + assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours" + assert ax2.yaxis.get_units() == ax2.yaxis.get_units() == "feet" + ax1.set_xunits("seconds") + ax2.set_yunits("inches") + assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds" + assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches" + fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all') + ax1.plot(x, y1) + ax2.plot(x, y2) + ax1.set_xunits("seconds", emit=False) + ax2.set_yunits("inches", emit=False) + assert ax1.xaxis.get_units() == "seconds" + assert ax2.xaxis.get_units() == "hours" + assert ax1.yaxis.get_units() == "feet" + assert ax2.yaxis.get_units() == "inches" From 6d54d3a2d97f6124205c1ca49f74e02d0fcf59e5 Mon Sep 17 00:00:00 2001 From: Lee Johnston Date: Fri, 21 Aug 2020 07:32:15 -0500 Subject: [PATCH 2/4] Synchronize units change in Axis.set_units for shared axis --- lib/matplotlib/axes/_base.py | 34 ------------------------------ lib/matplotlib/axis.py | 27 +++++++++++++++++++----- lib/matplotlib/tests/test_units.py | 13 ++---------- 3 files changed, 24 insertions(+), 50 deletions(-) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 260d1d1cb698..359adb297770 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -4305,37 +4305,3 @@ def get_shared_x_axes(self): def get_shared_y_axes(self): """Return a reference to the shared axes Grouper object for y axes.""" return self._shared_y_axes - - def set_xunits(self, units, emit=True): - """ - Set the x-axis units. - - Parameters - ---------- - units : units tag - - emit : bool, default: True - Whether to notify observers of units change. - """ - if emit: - for ax in self._shared_x_axes.get_siblings(self): - ax.xaxis.set_units(units) - else: - self.xaxis.set_units(units) - - def set_yunits(self, units, emit=True): - """ - Set the y-axis units. - - Parameters - ---------- - units : units tag - - emit : bool, default: True - Whether to notify observers of units change. - """ - if emit: - for ax in self._shared_y_axes.get_siblings(self): - ax.yaxis.set_units(units) - else: - self.yaxis.set_units(units) diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 6d28fae6c2d9..0ce2f52acf54 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -1530,14 +1530,31 @@ def set_units(self, u): Parameters ---------- u : units tag + + Notes + ----- + The units of any shared axis will also be updated. """ if u == self.units: return - self.units = u - self._update_axisinfo() - self.callbacks.process('units') - self.callbacks.process('units finalize') - self.stale = True + if self is self.axes.xaxis: + shared = [ + ax.xaxis + for ax in self.axes.get_shared_x_axes().get_siblings(self.axes) + ] + elif self is self.axes.yaxis: + shared = [ + ax.yaxis + for ax in self.axes.get_shared_y_axes().get_siblings(self.axes) + ] + else: + shared = [self] + for axis in shared: + axis.units = u + axis._update_axisinfo() + axis.callbacks.process('units') + axis.callbacks.process('units finalize') + axis.stale = True def get_units(self): """Return the units for axis.""" diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 3060b2dd98e6..b349f03f2c2a 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -184,16 +184,7 @@ def test_set_xyunits(quantity_converter): ax2.plot(x, y2) assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours" assert ax2.yaxis.get_units() == ax2.yaxis.get_units() == "feet" - ax1.set_xunits("seconds") - ax2.set_yunits("inches") + ax1.xaxis.set_units("seconds") + ax2.yaxis.set_units("inches") assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds" assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches" - fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all') - ax1.plot(x, y1) - ax2.plot(x, y2) - ax1.set_xunits("seconds", emit=False) - ax2.set_yunits("inches", emit=False) - assert ax1.xaxis.get_units() == "seconds" - assert ax2.xaxis.get_units() == "hours" - assert ax1.yaxis.get_units() == "feet" - assert ax2.yaxis.get_units() == "inches" From fb76ed7464fc876a1492df52dd3f7dc07f397220 Mon Sep 17 00:00:00 2001 From: Lee Johnston Date: Fri, 21 Aug 2020 14:01:48 -0500 Subject: [PATCH 3/4] Add test for datetime and categorical --- lib/matplotlib/tests/test_units.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index b349f03f2c2a..1891fbbda436 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -1,10 +1,11 @@ -from datetime import datetime +from datetime import datetime, timezone, timedelta import platform from unittest.mock import MagicMock import matplotlib.pyplot as plt from matplotlib.testing.decorators import check_figures_equal, image_comparison import matplotlib.units as munits +from matplotlib.category import UnitData import numpy as np import pytest @@ -188,3 +189,17 @@ def test_set_xyunits(quantity_converter): ax2.yaxis.set_units("inches") assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds" assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches" + y1 = [datetime(2020, i, 1, tzinfo=timezone.utc) for i in range(1, 13)] + y2 = [datetime(2021, i, 1, tzinfo=timezone.utc) for i in range(1, 13)] + fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True) + ax1.plot(y1) + ax2.plot(y2) + ax1.yaxis.set_units(timezone(timedelta(hours=5))) + assert ax2.yaxis.units == timezone(timedelta(hours=5)) + d1 = {"a": 1, "b": 2} + d2 = {"a": 3, "b": 4} + fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True) + ax1.plot(d1.keys(), d1.values()) + ax2.plot(d2.keys(), d2.values()) + ax1.xaxis.set_units(UnitData(["c", "d"])) + assert "c" in ax2.xaxis.get_units()._mapping.keys() From 332d31d6d1b6a777635186562f92c9c9fa69b530 Mon Sep 17 00:00:00 2001 From: Lee Johnston Date: Mon, 24 Aug 2020 13:43:25 -0500 Subject: [PATCH 4/4] Separate shared_axis tests --- lib/matplotlib/tests/test_units.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 1891fbbda436..3f40a99a2f5a 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -128,12 +128,12 @@ def test_jpl_bar_units(): units.register() day = units.Duration("ET", 24.0 * 60.0 * 60.0) - x = [0*units.km, 1*units.km, 2*units.km] - w = [1*day, 2*day, 3*day] + x = [0 * units.km, 1 * units.km, 2 * units.km] + w = [1 * day, 2 * day, 3 * day] b = units.Epoch("ET", dt=datetime(2009, 4, 25)) fig, ax = plt.subplots() ax.bar(x, w, bottom=b) - ax.set_ylim([b-1*day, b+w[-1]+(1.001)*day]) + ax.set_ylim([b - 1 * day, b + w[-1] + (1.001) * day]) @image_comparison(['jpl_barh_units.png'], @@ -143,13 +143,13 @@ def test_jpl_barh_units(): units.register() day = units.Duration("ET", 24.0 * 60.0 * 60.0) - x = [0*units.km, 1*units.km, 2*units.km] - w = [1*day, 2*day, 3*day] + x = [0 * units.km, 1 * units.km, 2 * units.km] + w = [1 * day, 2 * day, 3 * day] b = units.Epoch("ET", dt=datetime(2009, 4, 25)) fig, ax = plt.subplots() ax.barh(x, w, left=b) - ax.set_xlim([b-1*day, b+w[-1]+(1.001)*day]) + ax.set_xlim([b - 1 * day, b + w[-1] + (1.001) * day]) def test_empty_arrays(): @@ -175,7 +175,7 @@ class subdate(datetime): fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o") -def test_set_xyunits(quantity_converter): +def test_shared_axis_quantity(quantity_converter): munits.registry[Quantity] = quantity_converter x = Quantity(np.linspace(0, 1, 10), "hours") y1 = Quantity(np.linspace(1, 2, 10), "feet") @@ -189,6 +189,10 @@ def test_set_xyunits(quantity_converter): ax2.yaxis.set_units("inches") assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds" assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches" + + +def test_shared_axis_datetime(): + # datetime uses dates.DateConverter y1 = [datetime(2020, i, 1, tzinfo=timezone.utc) for i in range(1, 13)] y2 = [datetime(2021, i, 1, tzinfo=timezone.utc) for i in range(1, 13)] fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True) @@ -196,6 +200,10 @@ def test_set_xyunits(quantity_converter): ax2.plot(y2) ax1.yaxis.set_units(timezone(timedelta(hours=5))) assert ax2.yaxis.units == timezone(timedelta(hours=5)) + + +def test_shared_axis_categorical(): + # str uses category.StrCategoryConverter d1 = {"a": 1, "b": 2} d2 = {"a": 3, "b": 4} fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)