Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b786a55

Browse files
committed
Basic support for units on ScalarMappable
1 parent e4e6840 commit b786a55

File tree

6 files changed

+99
-3
lines changed

6 files changed

+99
-3
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Unit converters can now support units in images
2+
-----------------------------------------------
3+
4+
`~.cm.ScalarMappable` can now contain data with units. This adds support for
5+
unit-ful data to be plotted using - `~.axes.Axes.imshow`, `~.axes.Axes.pcolor`,
6+
and `~.axes.Axes.pcolormesh`
7+
8+
For this to be supported by third-party `~.units.ConversionInterface`,
9+
the `~.units.ConversionInterface.default_units` and
10+
`~.units.ConversionInterface.convert` methods must allow for the *axis*
11+
argument to be ``None``, and `~.units.ConversionInterface.convert` must be able to
12+
convert data of more than one dimension (e.g. when plotting images the data is 2D).
13+
14+
If a conversion interface raises an error when given ``None`` or 2D data as described
15+
above, this error will be re-raised when a user tries to use one of the newly supported
16+
plotting methods with unit-ful data.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Unit support for images
2+
-----------------------
3+
This adds support for image data with units has been added to the following plotting
4+
methods:
5+
6+
- `~.axes.Axes.imshow`
7+
- `~.axes.Axes.pcolor`
8+
- `~.axes.Axes.pcolormesh`

lib/matplotlib/cm.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from matplotlib import _api, colors, cbook, scale
2525
from matplotlib._cm import datad
2626
from matplotlib._cm_listed import cmaps as cmaps_listed
27+
import matplotlib.units as munits
2728

2829

2930
_LUTSIZE = mpl.rcParams['image.lut']
@@ -283,6 +284,8 @@ def __init__(self, norm=None, cmap=None):
283284
The colormap used to map normalized data values to RGBA colors.
284285
"""
285286
self._A = None
287+
self._units = None
288+
self._converter = None
286289
self._norm = None # So that the setter knows we're initializing.
287290
self.set_norm(norm) # The Normalize instance of this ScalarMappable.
288291
self.cmap = None # So that the setter knows we're initializing.
@@ -387,6 +390,35 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
387390
rgba = self.cmap(x, alpha=alpha, bytes=bytes)
388391
return rgba
389392

393+
def _strip_units(self, A):
394+
"""
395+
Remove units from A, and save the units and converter used to do the conversion.
396+
"""
397+
self._converter = munits.registry.get_converter(A)
398+
if self._converter is None:
399+
self._units = None
400+
return A
401+
402+
try:
403+
self._units = self._converter.default_units(A, None)
404+
except Exception as e:
405+
raise RuntimeError(
406+
f'{self._converter} failed when trying to return the default units for '
407+
'this image. This may be because support has not been '
408+
'implemented for `axis=None` in the default_units() method.'
409+
) from e
410+
411+
try:
412+
A = self._converter.convert(A, self._units, None)
413+
except Exception as e:
414+
raise munits.ConversionError(
415+
f'{self._converter} failed when trying to convert the units for this '
416+
'image. This may be because support has not been implemented '
417+
'for `axis=None` in the convert() method.'
418+
) from e
419+
420+
return A
421+
390422
def set_array(self, A):
391423
"""
392424
Set the value array from array-like *A*.
@@ -402,7 +434,7 @@ def set_array(self, A):
402434
if A is None:
403435
self._A = None
404436
return
405-
437+
A = self._strip_units(A)
406438
A = cbook.safe_masked_invalid(A, copy=True)
407439
if not np.can_cast(A.dtype, float, "same_kind"):
408440
raise TypeError(f"Image data of dtype {A.dtype} cannot be "

lib/matplotlib/image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,8 @@ def set_data(self, A):
728728
"""
729729
if isinstance(A, PIL.Image.Image):
730730
A = pil_to_array(A) # Needed e.g. to apply png palette.
731+
732+
A = self._strip_units(A)
731733
self._A = self._normalize_image_array(A)
732734
self._imcache = None
733735
self.stale = True
@@ -1142,6 +1144,7 @@ def set_data(self, x, y, A):
11421144
(M, N) `~numpy.ndarray` or masked array of values to be
11431145
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11441146
"""
1147+
A = self._strip_units(A)
11451148
A = self._normalize_image_array(A)
11461149
x = np.array(x, np.float32)
11471150
y = np.array(y, np.float32)
@@ -1302,6 +1305,7 @@ def set_data(self, x, y, A):
13021305
- (M, N, 3): RGB array
13031306
- (M, N, 4): RGBA array
13041307
"""
1308+
A = self._strip_units(A)
13051309
A = self._normalize_image_array(A)
13061310
x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel()
13071311
y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel()

lib/matplotlib/tests/test_units.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def __getitem__(self, item):
4141
def __array__(self):
4242
return np.asarray(self.magnitude)
4343

44+
def __len__(self):
45+
return len(self.__array__())
46+
4447

4548
@pytest.fixture
4649
def quantity_converter():
@@ -294,3 +297,30 @@ def test_plot_kernel():
294297
# just a smoketest that fail
295298
kernel = Kernel([1, 2, 3, 4, 5])
296299
plt.plot(kernel)
300+
301+
302+
@image_comparison(['mappable_units.png'], style="mpl20")
303+
def test_mappable_units(quantity_converter):
304+
# Check that showing an image with units works
305+
munits.registry[Quantity] = quantity_converter
306+
x, y = np.meshgrid([0, 1], [0, 1])
307+
data = Quantity(np.arange(4).reshape(2, 2), 'hours')
308+
309+
fig, axs = plt.subplots(nrows=2, ncols=2)
310+
311+
# imshow
312+
ax = axs[0, 0]
313+
mappable = ax.imshow(data, origin='lower')
314+
cbar = fig.colorbar(mappable, ax=ax)
315+
316+
# pcolor
317+
ax = axs[0, 1]
318+
mappable = ax.pcolor(x, y, data)
319+
fig.colorbar(mappable, ax=ax)
320+
321+
# pcolormesh + horizontal colorbar
322+
ax = axs[1, 0]
323+
mappable = ax.pcolormesh(x, y, data)
324+
fig.colorbar(mappable, ax=ax, orientation="horizontal")
325+
326+
axs[1, 1].axis("off")

lib/matplotlib/units.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,22 @@ def axisinfo(unit, axis):
118118

119119
@staticmethod
120120
def default_units(x, axis):
121-
"""Return the default unit for *x* or ``None`` for the given axis."""
121+
"""
122+
Return the default unit for *x*.
123+
124+
*axis* may be an `~.axis.Axis` or ``None``.
125+
"""
122126
return None
123127

124128
@staticmethod
125129
def convert(obj, unit, axis):
126130
"""
127-
Convert *obj* using *unit* for the specified *axis*.
131+
Convert *obj* using *unit*.
128132
129133
If *obj* is a sequence, return the converted sequence. The output must
130134
be a sequence of scalars that can be used by the numpy array layer.
135+
136+
*axis* may be an `~.axis.Axis` or ``None``.
131137
"""
132138
return obj
133139

0 commit comments

Comments
 (0)