Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e1a8fee

Browse files
committed
Basic support for units on ScalarMappable
1 parent f007886 commit e1a8fee

File tree

6 files changed

+99
-3
lines changed

6 files changed

+99
-3
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Unit converters can now support units in images
2+
-----------------------------------------------
3+
4+
`~.cm.ScalarMappable` can now contain data with units. This adds support for
5+
unit-ful data to be plotted using - `~.axes.Axes.imshow`, `~.axes.Axes.pcolor`,
6+
and `~.axes.Axes.pcolormesh`
7+
8+
For this to be supported by third-party `~.units.ConversionInterface`,
9+
the `~.units.ConversionInterface.default_units` and
10+
`~.units.ConversionInterface.convert` methods must allow for the *axis*
11+
argument to be ``None``, and `~.units.ConversionInterface.convert` must be able to
12+
convert data of more than one dimension (e.g. when plotting images the data is 2D).
13+
14+
If a conversion interface raises an error when given ``None`` or 2D data as described
15+
above, this error will be re-raised when a user tries to use one of the newly supported
16+
plotting methods with unit-ful data.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Unit support for images
2+
-----------------------
3+
This adds support for image data with units has been added to the following plotting
4+
methods:
5+
6+
- `~.axes.Axes.imshow`
7+
- `~.axes.Axes.pcolor`
8+
- `~.axes.Axes.pcolormesh`

lib/matplotlib/cm.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from matplotlib import _api, colors, cbook, scale
2525
from matplotlib._cm import datad
2626
from matplotlib._cm_listed import cmaps as cmaps_listed
27+
import matplotlib.units as munits
2728

2829

2930
_LUTSIZE = mpl.rcParams['image.lut']
@@ -283,6 +284,8 @@ def __init__(self, norm=None, cmap=None):
283284
The colormap used to map normalized data values to RGBA colors.
284285
"""
285286
self._A = None
287+
self._units = None
288+
self._converter = None
286289
self._norm = None # So that the setter knows we're initializing.
287290
self.set_norm(norm) # The Normalize instance of this ScalarMappable.
288291
self.cmap = None # So that the setter knows we're initializing.
@@ -393,6 +396,35 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
393396
rgba = self.cmap(x, alpha=alpha, bytes=bytes)
394397
return rgba
395398

399+
def _strip_units(self, A):
400+
"""
401+
Remove units from A, and save the units and converter used to do the conversion.
402+
"""
403+
self._converter = munits.registry.get_converter(A)
404+
if self._converter is None:
405+
self._units = None
406+
return A
407+
408+
try:
409+
self._units = self._converter.default_units(A, None)
410+
except Exception as e:
411+
raise RuntimeError(
412+
f'{self._converter} failed when trying to return the default units for '
413+
'this image. This may be because support has not been '
414+
'implemented for `axis=None` in the default_units() method.'
415+
) from e
416+
417+
try:
418+
A = self._converter.convert(A, self._units, None)
419+
except Exception as e:
420+
raise munits.ConversionError(
421+
f'{self._converter} failed when trying to convert the units for this '
422+
'image. This may be because support has not been implemented '
423+
'for `axis=None` in the convert() method.'
424+
) from e
425+
426+
return A
427+
396428
def set_array(self, A):
397429
"""
398430
Set the value array from array-like *A*.
@@ -408,7 +440,7 @@ def set_array(self, A):
408440
if A is None:
409441
self._A = None
410442
return
411-
443+
A = self._strip_units(A)
412444
A = cbook.safe_masked_invalid(A, copy=True)
413445
if not np.can_cast(A.dtype, float, "same_kind"):
414446
raise TypeError(f"Image data of dtype {A.dtype} cannot be "

lib/matplotlib/image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,8 @@ def set_data(self, A):
726726
"""
727727
if isinstance(A, PIL.Image.Image):
728728
A = pil_to_array(A) # Needed e.g. to apply png palette.
729+
730+
A = self._strip_units(A)
729731
self._A = self._normalize_image_array(A)
730732
self._imcache = None
731733
self.stale = True
@@ -1140,6 +1142,7 @@ def set_data(self, x, y, A):
11401142
(M, N) `~numpy.ndarray` or masked array of values to be
11411143
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11421144
"""
1145+
A = self._strip_units(A)
11431146
A = self._normalize_image_array(A)
11441147
x = np.array(x, np.float32)
11451148
y = np.array(y, np.float32)
@@ -1300,6 +1303,7 @@ def set_data(self, x, y, A):
13001303
- (M, N, 3): RGB array
13011304
- (M, N, 4): RGBA array
13021305
"""
1306+
A = self._strip_units(A)
13031307
A = self._normalize_image_array(A)
13041308
x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel()
13051309
y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel()

lib/matplotlib/tests/test_units.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def __getitem__(self, item):
4141
def __array__(self):
4242
return np.asarray(self.magnitude)
4343

44+
def __len__(self):
45+
return len(self.__array__())
46+
4447

4548
@pytest.fixture
4649
def quantity_converter():
@@ -302,3 +305,30 @@ def test_plot_kernel():
302305
# just a smoketest that fail
303306
kernel = Kernel([1, 2, 3, 4, 5])
304307
plt.plot(kernel)
308+
309+
310+
@image_comparison(['mappable_units.png'], style="mpl20")
311+
def test_mappable_units(quantity_converter):
312+
# Check that showing an image with units works
313+
munits.registry[Quantity] = quantity_converter
314+
x, y = np.meshgrid([0, 1], [0, 1])
315+
data = Quantity(np.arange(4).reshape(2, 2), 'hours')
316+
317+
fig, axs = plt.subplots(nrows=2, ncols=2)
318+
319+
# imshow
320+
ax = axs[0, 0]
321+
mappable = ax.imshow(data, origin='lower')
322+
cbar = fig.colorbar(mappable, ax=ax)
323+
324+
# pcolor
325+
ax = axs[0, 1]
326+
mappable = ax.pcolor(x, y, data)
327+
fig.colorbar(mappable, ax=ax)
328+
329+
# pcolormesh + horizontal colorbar
330+
ax = axs[1, 0]
331+
mappable = ax.pcolormesh(x, y, data)
332+
fig.colorbar(mappable, ax=ax, orientation="horizontal")
333+
334+
axs[1, 1].axis("off")

lib/matplotlib/units.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,22 @@ def axisinfo(unit, axis):
118118

119119
@staticmethod
120120
def default_units(x, axis):
121-
"""Return the default unit for *x* or ``None`` for the given axis."""
121+
"""
122+
Return the default unit for *x*.
123+
124+
*axis* may be an `~.axis.Axis` or ``None``.
125+
"""
122126
return None
123127

124128
@staticmethod
125129
def convert(obj, unit, axis):
126130
"""
127-
Convert *obj* using *unit* for the specified *axis*.
131+
Convert *obj* using *unit*.
128132
129133
If *obj* is a sequence, return the converted sequence. The output must
130134
be a sequence of scalars that can be used by the numpy array layer.
135+
136+
*axis* may be an `~.axis.Axis` or ``None``.
131137
"""
132138
return obj
133139

0 commit comments

Comments
 (0)