Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d7f73ab

Browse files
committed
Allow unit-ful image data
1 parent a142369 commit d7f73ab

File tree

7 files changed

+84
-14
lines changed

7 files changed

+84
-14
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Unit converters can now support units on images
2+
-----------------------------------------------
3+
4+
Matplotlib now supports using `~.axes.Axes.imshow` to plot data with units.
5+
For this to be supported by third-party `~.units.ConversionInterface`s,
6+
the `~.units.ConversionInterface.default_units` and
7+
`~.units.ConversionInterface.convert` must allow for the *axis* argument to be
8+
a ``matplotlib.images._ImageBase`` object.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
`~.axes.Axes.imshow` can now be used with unit-ful data
2+
-------------------------------------------------------
3+
4+
`~.axes.Axes.imshow` can now be used with data that has units attached to it.

examples/units/basic_units.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,16 @@ def convert(val, unit, axis):
346346
if np.iterable(val):
347347
if isinstance(val, np.ma.MaskedArray):
348348
val = val.astype(float).filled(np.nan)
349-
out = np.empty(len(val))
350-
for i, thisval in enumerate(val):
351-
if np.ma.is_masked(thisval):
352-
out[i] = np.nan
353-
else:
354-
try:
355-
out[i] = thisval.convert_to(unit).get_value()
356-
except AttributeError:
357-
out[i] = thisval
349+
out = np.empty(np.shape(val))
350+
masked_mask = np.ma.getmaskarray(val)
351+
out[masked_mask] = np.nan
352+
out[~masked_mask] = val[~masked_mask].convert_to(unit).get_value()
358353
return out
354+
359355
if np.ma.is_masked(val):
360356
return np.nan
361357
else:
358+
# Scalar
362359
return val.convert_to(unit).get_value()
363360

364361
@staticmethod

examples/units/units_image.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
=================
3+
Images with units
4+
=================
5+
Plotting images with units.
6+
7+
.. only:: builder_html
8+
9+
This example requires :download:`basic_units.py <basic_units.py>`
10+
"""
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
from basic_units import secs
14+
15+
data = np.array([[1, 2],
16+
[3, 4]]) * secs
17+
18+
fig, ax = plt.subplots()
19+
image = ax.imshow(data)
20+
fig.colorbar(image)
21+
plt.show()

lib/matplotlib/colorbar.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,6 @@ def __init__(self, ax, mappable=None, *, cmap=None,
470470
linewidths=[0.5 * mpl.rcParams['axes.linewidth']])
471471
self.ax.add_collection(self.dividers)
472472

473-
self.locator = None
474-
self.formatter = None
475473
self.__scale = None # linear, log10 for now. Hopefully more?
476474

477475
if ticklocation == 'auto':
@@ -481,15 +479,23 @@ def __init__(self, ax, mappable=None, *, cmap=None,
481479
self.set_label(label)
482480
self._reset_locator_formatter_scale()
483481

482+
self.locator = None
484483
if np.iterable(ticks):
485484
self.locator = ticker.FixedLocator(ticks, nbins=len(ticks))
486485
else:
487486
self.locator = ticks # Handle default in _ticker()
488487

488+
self.formatter = None
489489
if isinstance(format, str):
490490
self.formatter = ticker.FormatStrFormatter(format)
491-
else:
492-
self.formatter = format # Assume it is a Formatter or None
491+
elif format is not None:
492+
self.formatter = format # Assume it is a Formatter
493+
elif hasattr(mappable, 'converter') and hasattr(mappable, 'units'):
494+
# Set from mappable if it has a converter and units
495+
info = mappable.converter.axisinfo(mappable.units, self._long_axis)
496+
if info is not None and info.majfmt is not None:
497+
self.formatter = info.majfmt
498+
493499
self.draw_all()
494500

495501
if isinstance(mappable, contour.ContourSet) and not mappable.filled:

lib/matplotlib/image.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import matplotlib.colors as mcolors
1919
import matplotlib.cm as cm
2020
import matplotlib.cbook as cbook
21+
import matplotlib.units as munits
2122
# For clarity, names from _image are given explicitly in this module:
2223
import matplotlib._image as _image
2324
# For user convenience, the names from _image are also imported into
@@ -696,6 +697,7 @@ def set_data(self, A):
696697
"""
697698
if isinstance(A, PIL.Image.Image):
698699
A = pil_to_array(A) # Needed e.g. to apply png palette.
700+
A = self._convert_units(A)
699701
self._A = cbook.safe_masked_invalid(A, copy=True)
700702

701703
if (self._A.dtype != np.uint8 and
@@ -733,6 +735,30 @@ def set_data(self, A):
733735
self._rgbacache = None
734736
self.stale = True
735737

738+
def _convert_units(self, A):
739+
# Take the first element since units expects a 1D sequence, not 2D
740+
converter = munits.registry.get_converter(A[0])
741+
if converter is None:
742+
return A
743+
744+
try:
745+
units = converter.default_units(A, self)
746+
except Exception as e:
747+
raise RuntimeError(
748+
f'{converter} failed when trying to return the default units '
749+
f'for this image. This may be because {converter} has not '
750+
'implemented support for images in the default_units() method.'
751+
) from e
752+
753+
try:
754+
return converter.convert(A, units, self)
755+
except Exception as e:
756+
raise RuntimeError(
757+
f'{converter} failed when trying to convert the units '
758+
f'for this image. This may be because {converter} has not '
759+
'implemented support for images in the convert() method.'
760+
) from e
761+
736762
def set_array(self, A):
737763
"""
738764
Retained for backwards compatibility - use set_data instead.

lib/matplotlib/units.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,12 @@ def axisinfo(unit, axis):
118118

119119
@staticmethod
120120
def default_units(x, axis):
121-
"""Return the default unit for *x* or ``None`` for the given axis."""
121+
"""
122+
Return the default unit for *x* or ``None``.
123+
124+
*axis* can be either an `Axis` or an ``_ImageBase`` (if units of a 2D
125+
image are being converted).
126+
"""
122127
return None
123128

124129
@staticmethod
@@ -128,6 +133,9 @@ def convert(obj, unit, axis):
128133
129134
If *obj* is a sequence, return the converted sequence. The output must
130135
be a sequence of scalars that can be used by the numpy array layer.
136+
137+
*axis* can be either an `Axis` or an ``_ImageBase`` (if units of a 2D
138+
image are being converted).
131139
"""
132140
return obj
133141

0 commit comments

Comments
 (0)