Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 65986ff

Browse files
authored
Merge pull request #18308 from l-johnston/issue_10304
Synchronize units change in Axis.set_units for shared axis
2 parents e78aee9 + 332d31d commit 65986ff

File tree

2 files changed

+68
-12
lines changed

2 files changed

+68
-12
lines changed

lib/matplotlib/axis.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,14 +1530,31 @@ def set_units(self, u):
15301530
Parameters
15311531
----------
15321532
u : units tag
1533+
1534+
Notes
1535+
-----
1536+
The units of any shared axis will also be updated.
15331537
"""
15341538
if u == self.units:
15351539
return
1536-
self.units = u
1537-
self._update_axisinfo()
1538-
self.callbacks.process('units')
1539-
self.callbacks.process('units finalize')
1540-
self.stale = True
1540+
if self is self.axes.xaxis:
1541+
shared = [
1542+
ax.xaxis
1543+
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
1544+
]
1545+
elif self is self.axes.yaxis:
1546+
shared = [
1547+
ax.yaxis
1548+
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
1549+
]
1550+
else:
1551+
shared = [self]
1552+
for axis in shared:
1553+
axis.units = u
1554+
axis._update_axisinfo()
1555+
axis.callbacks.process('units')
1556+
axis.callbacks.process('units finalize')
1557+
axis.stale = True
15411558

15421559
def get_units(self):
15431560
"""Return the units for axis."""

lib/matplotlib/tests/test_units.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from datetime import datetime
1+
from datetime import datetime, timezone, timedelta
22
import platform
33
from unittest.mock import MagicMock
44

55
import matplotlib.pyplot as plt
66
from matplotlib.testing.decorators import check_figures_equal, image_comparison
77
import matplotlib.units as munits
8+
from matplotlib.category import UnitData
89
import numpy as np
910
import pytest
1011

@@ -127,12 +128,12 @@ def test_jpl_bar_units():
127128
units.register()
128129

129130
day = units.Duration("ET", 24.0 * 60.0 * 60.0)
130-
x = [0*units.km, 1*units.km, 2*units.km]
131-
w = [1*day, 2*day, 3*day]
131+
x = [0 * units.km, 1 * units.km, 2 * units.km]
132+
w = [1 * day, 2 * day, 3 * day]
132133
b = units.Epoch("ET", dt=datetime(2009, 4, 25))
133134
fig, ax = plt.subplots()
134135
ax.bar(x, w, bottom=b)
135-
ax.set_ylim([b-1*day, b+w[-1]+(1.001)*day])
136+
ax.set_ylim([b - 1 * day, b + w[-1] + (1.001) * day])
136137

137138

138139
@image_comparison(['jpl_barh_units.png'],
@@ -142,13 +143,13 @@ def test_jpl_barh_units():
142143
units.register()
143144

144145
day = units.Duration("ET", 24.0 * 60.0 * 60.0)
145-
x = [0*units.km, 1*units.km, 2*units.km]
146-
w = [1*day, 2*day, 3*day]
146+
x = [0 * units.km, 1 * units.km, 2 * units.km]
147+
w = [1 * day, 2 * day, 3 * day]
147148
b = units.Epoch("ET", dt=datetime(2009, 4, 25))
148149

149150
fig, ax = plt.subplots()
150151
ax.barh(x, w, left=b)
151-
ax.set_xlim([b-1*day, b+w[-1]+(1.001)*day])
152+
ax.set_xlim([b - 1 * day, b + w[-1] + (1.001) * day])
152153

153154

154155
def test_empty_arrays():
@@ -172,3 +173,41 @@ class subdate(datetime):
172173

173174
fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o")
174175
fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o")
176+
177+
178+
def test_shared_axis_quantity(quantity_converter):
179+
munits.registry[Quantity] = quantity_converter
180+
x = Quantity(np.linspace(0, 1, 10), "hours")
181+
y1 = Quantity(np.linspace(1, 2, 10), "feet")
182+
y2 = Quantity(np.linspace(3, 4, 10), "feet")
183+
fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all')
184+
ax1.plot(x, y1)
185+
ax2.plot(x, y2)
186+
assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours"
187+
assert ax2.yaxis.get_units() == ax2.yaxis.get_units() == "feet"
188+
ax1.xaxis.set_units("seconds")
189+
ax2.yaxis.set_units("inches")
190+
assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds"
191+
assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches"
192+
193+
194+
def test_shared_axis_datetime():
195+
# datetime uses dates.DateConverter
196+
y1 = [datetime(2020, i, 1, tzinfo=timezone.utc) for i in range(1, 13)]
197+
y2 = [datetime(2021, i, 1, tzinfo=timezone.utc) for i in range(1, 13)]
198+
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
199+
ax1.plot(y1)
200+
ax2.plot(y2)
201+
ax1.yaxis.set_units(timezone(timedelta(hours=5)))
202+
assert ax2.yaxis.units == timezone(timedelta(hours=5))
203+
204+
205+
def test_shared_axis_categorical():
206+
# str uses category.StrCategoryConverter
207+
d1 = {"a": 1, "b": 2}
208+
d2 = {"a": 3, "b": 4}
209+
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
210+
ax1.plot(d1.keys(), d1.values())
211+
ax2.plot(d2.keys(), d2.values())
212+
ax1.xaxis.set_units(UnitData(["c", "d"]))
213+
assert "c" in ax2.xaxis.get_units()._mapping.keys()

0 commit comments

Comments
 (0)