Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ec76fe3

Browse files
committed
Allow Mappables to have units
Add test image And unit setter and getter
1 parent 3361895 commit ec76fe3

File tree

14 files changed

+215
-34
lines changed

14 files changed

+215
-34
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Unit converters can now support units in `~.cm.ScalarMappable`s
2+
---------------------------------------------------------------
3+
4+
`~.cm.ScalarMappable` can now contain data with units.
5+
For this to be supported by third-party `~.units.ConversionInterface`s,
6+
the `~.units.ConversionInterface.default_units` and
7+
`~.units.ConversionInterface.convert` methods must allow for the *axis*
8+
argument to be a `~.cm.ScalarMappable` object, and
9+
`~.units.ConversionInterface.convert` must be able to convert data of more than
10+
one dimension (e.g. when plotting images the data is 2D).
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Colorbars and mappables with units
2+
----------------------------------
3+
If a colorbar is created with a mappable that has data with units, the
4+
mappable converter will be queried to automatically set the major formatter for
5+
the colorbar axis.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Unit support for ScalarMappables
2+
--------------------------------
3+
ScalarMappables can now have data with units set. This adds support for data
4+
with units to the following plotting methods:
5+
6+
- `~.axes.Axes.imshow`
7+
- `~.axes.Axes.contour` and - `~.axes.Axes.contourf`
8+
- `~.axes.Axes.pcolor` and `~.axes.Axes.pcolormesh`

examples/units/basic_units.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,28 @@ def convert(val, unit, axis):
346346
if np.iterable(val):
347347
if isinstance(val, np.ma.MaskedArray):
348348
val = val.astype(float).filled(np.nan)
349-
out = np.empty(len(val))
350-
for i, thisval in enumerate(val):
351-
if np.ma.is_masked(thisval):
352-
out[i] = np.nan
353-
else:
354-
try:
355-
out[i] = thisval.convert_to(unit).get_value()
356-
except AttributeError:
357-
out[i] = thisval
349+
out = np.empty(np.shape(val))
350+
if isinstance(val, np.ndarray) and isinstance(val, TaggedValue):
351+
masked_mask = np.ma.getmaskarray(val)
352+
out[masked_mask] = np.nan
353+
converted = val[~masked_mask].convert_to(unit).get_value()
354+
out[~masked_mask] = converted
355+
else:
356+
for i, thisval in enumerate(val):
357+
if np.ma.is_masked(thisval):
358+
out[i] = np.nan
359+
else:
360+
try:
361+
out[i] = thisval.convert_to(unit).get_value()
362+
except AttributeError:
363+
out[i] = thisval
364+
358365
return out
366+
359367
if np.ma.is_masked(val):
360368
return np.nan
361369
else:
370+
# Scalar
362371
return val.convert_to(unit).get_value()
363372

364373
@staticmethod

examples/units/units_image.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
=================
3+
Images with units
4+
=================
5+
Plotting images with units.
6+
7+
.. only:: builder_html
8+
9+
This example requires :download:`basic_units.py <basic_units.py>`
10+
"""
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
from basic_units import secs
14+
15+
data = np.array([[1, 2],
16+
[3, 4]]) * secs
17+
18+
fig, ax = plt.subplots()
19+
image = ax.imshow(data)
20+
fig.colorbar(image)
21+
plt.show()

lib/matplotlib/cm.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import matplotlib as mpl
2424
from matplotlib import _api, colors, cbook
25+
import matplotlib.units as munits
2526
from matplotlib._cm import datad
2627
from matplotlib._cm_listed import cmaps as cmaps_listed
2728

@@ -344,6 +345,8 @@ def __init__(self, norm=None, cmap=None):
344345
#: The last colorbar associated with this ScalarMappable. May be None.
345346
self.colorbar = None
346347
self.callbacks = cbook.CallbackRegistry()
348+
self.units = None
349+
self.converter = None
347350

348351
callbacksSM = _api.deprecated("3.5", alternative="callbacks")(
349352
property(lambda self: self.callbacks))
@@ -456,6 +459,7 @@ def set_array(self, A):
456459
self._A = None
457460
return
458461

462+
A = self._convert_mappable_units(A)
459463
A = cbook.safe_masked_invalid(A, copy=True)
460464
if not np.can_cast(A.dtype, float, "same_kind"):
461465
raise TypeError(f"Image data of dtype {A.dtype} cannot be "
@@ -601,3 +605,70 @@ def changed(self):
601605
"""
602606
self.callbacks.process('changed', self)
603607
self.stale = True
608+
609+
@property
610+
def units(self):
611+
return self._units
612+
613+
@units.setter
614+
def units(self, unit):
615+
self._units = unit
616+
617+
# {get, set}_units are included to mimic the unit API of Axis
618+
def set_units(self, u):
619+
"""
620+
Set the units.
621+
622+
Parameters
623+
----------
624+
u : units tag
625+
"""
626+
self.units = u
627+
628+
def get_units(self):
629+
"""Return the units for axis."""
630+
return self.units
631+
632+
@property
633+
def converter(self):
634+
return self._converter
635+
636+
@converter.setter
637+
def converter(self, converter):
638+
if (converter is not None and
639+
not isinstance(converter, munits.ConversionInterface)):
640+
raise ValueError('converter must be None or an instance of '
641+
'ConversionInterface')
642+
self._converter = converter
643+
644+
def _convert_mappable_units(self, A):
645+
# If A is natively supported by Matplotlib, doesn't need converting
646+
if munits._is_natively_supported(A):
647+
return A
648+
649+
if self.converter is None:
650+
self.converter = munits.registry.get_converter(A)
651+
652+
if self.converter is None:
653+
return A
654+
655+
if self.units is None:
656+
try:
657+
self.units = self.converter.default_units(A, self)
658+
except Exception as e:
659+
raise RuntimeError(
660+
f'{self.converter} failed when trying to return the '
661+
'default units for this image. This may be because '
662+
f'{self.converter} has not implemented support for '
663+
'`ScalarMappable`s in the default_units() method.'
664+
) from e
665+
666+
try:
667+
return self.converter.convert(A, self.units, self)
668+
except Exception as e:
669+
raise munits.ConversionError(
670+
f'{self.converter} failed when trying to convert the units '
671+
f'for this image. This may be because {self.converter} has '
672+
'not implemented support for `ScalarMappable`s in the '
673+
'convert() method.'
674+
) from e

lib/matplotlib/collections.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,7 @@ def __init__(self, widths, heights, angles, units='points', **kwargs):
17341734
self._widths = 0.5 * np.asarray(widths).ravel()
17351735
self._heights = 0.5 * np.asarray(heights).ravel()
17361736
self._angles = np.deg2rad(angles).ravel()
1737-
self._units = units
1737+
self._size_units = units
17381738
self.set_transform(transforms.IdentityTransform())
17391739
self._transforms = np.empty((0, 3, 3))
17401740
self._paths = [mpath.Path.unit_circle()]
@@ -1745,24 +1745,24 @@ def _set_transforms(self):
17451745
ax = self.axes
17461746
fig = self.figure
17471747

1748-
if self._units == 'xy':
1748+
if self._size_units == 'xy':
17491749
sc = 1
1750-
elif self._units == 'x':
1750+
elif self._size_units == 'x':
17511751
sc = ax.bbox.width / ax.viewLim.width
1752-
elif self._units == 'y':
1752+
elif self._size_units == 'y':
17531753
sc = ax.bbox.height / ax.viewLim.height
1754-
elif self._units == 'inches':
1754+
elif self._size_units == 'inches':
17551755
sc = fig.dpi
1756-
elif self._units == 'points':
1756+
elif self._size_units == 'points':
17571757
sc = fig.dpi / 72.0
1758-
elif self._units == 'width':
1758+
elif self._size_units == 'width':
17591759
sc = ax.bbox.width
1760-
elif self._units == 'height':
1760+
elif self._size_units == 'height':
17611761
sc = ax.bbox.height
1762-
elif self._units == 'dots':
1762+
elif self._size_units == 'dots':
17631763
sc = 1.0
17641764
else:
1765-
raise ValueError('unrecognized units: %s' % self._units)
1765+
raise ValueError('unrecognized units: %s' % self._size_units)
17661766

17671767
self._transforms = np.zeros((len(self._widths), 3, 3))
17681768
widths = self._widths * sc
@@ -1776,7 +1776,7 @@ def _set_transforms(self):
17761776
self._transforms[:, 2, 2] = 1.0
17771777

17781778
_affine = transforms.Affine2D
1779-
if self._units == 'xy':
1779+
if self._size_units == 'xy':
17801780
m = ax.transData.get_affine().get_matrix().copy()
17811781
m[:2, 2:] = 0
17821782
self.set_transform(_affine(m))

lib/matplotlib/colorbar.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@
9797
If None, ticks are determined automatically from the
9898
input.
9999
*format* None or str or Formatter
100-
If None, `~.ticker.ScalarFormatter` is used.
100+
If None, extracted from the mappable ``converter`` attribute
101+
if units are present, otherwise `~.ticker.ScalarFormatter` is
102+
used.
101103
If a format string is given, e.g., '%.3f', that is used.
102104
An alternative `~.ticker.Formatter` may be given instead.
103105
*drawedges* bool
@@ -470,9 +472,6 @@ def __init__(self, ax, mappable=None, *, cmap=None,
470472
linewidths=[0.5 * mpl.rcParams['axes.linewidth']])
471473
self.ax.add_collection(self.dividers)
472474

473-
self.locator = None
474-
self.minorlocator = None
475-
self.formatter = None
476475
self.__scale = None # linear, log10 for now. Hopefully more?
477476

478477
if ticklocation == 'auto':
@@ -489,8 +488,15 @@ def __init__(self, ax, mappable=None, *, cmap=None,
489488

490489
if isinstance(format, str):
491490
self.formatter = ticker.FormatStrFormatter(format)
492-
else:
493-
self.formatter = format # Assume it is a Formatter or None
491+
elif format is not None:
492+
self.formatter = format # Assume it is a Formatter
493+
elif mappable.converter is not None and mappable.units is not None:
494+
# Set from mappable if it has a converter and units
495+
info = mappable.converter.axisinfo(
496+
mappable.units, self._long_axis)
497+
if info is not None and info.majfmt is not None:
498+
self.formatter = info.majfmt
499+
494500
self.draw_all()
495501

496502
if isinstance(mappable, contour.ContourSet) and not mappable.filled:
@@ -1135,10 +1141,10 @@ def _reset_locator_formatter_scale(self):
11351141
need to be re-entered if this gets called (either at init, or when
11361142
the mappable normal gets changed: Colorbar.update_normal)
11371143
"""
1138-
self._process_values()
11391144
self.locator = None
11401145
self.minorlocator = None
11411146
self.formatter = None
1147+
self._process_values()
11421148
if (self.boundaries is not None or
11431149
isinstance(self.norm, colors.BoundaryNorm)):
11441150
if self.spacing == 'uniform':

lib/matplotlib/image.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import matplotlib.colors as mcolors
1919
import matplotlib.cm as cm
2020
import matplotlib.cbook as cbook
21+
import matplotlib.units as munits
2122
# For clarity, names from _image are given explicitly in this module:
2223
import matplotlib._image as _image
2324
# For user convenience, the names from _image are also imported into
@@ -699,6 +700,7 @@ def set_data(self, A):
699700
"""
700701
if isinstance(A, PIL.Image.Image):
701702
A = pil_to_array(A) # Needed e.g. to apply png palette.
703+
A = self._convert_mappable_units(A)
702704
self._A = cbook.safe_masked_invalid(A, copy=True)
703705

704706
if (self._A.dtype != np.uint8 and

lib/matplotlib/quiver.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,6 @@ def __init__(self, ax, *args,
489489
self.headaxislength = headaxislength
490490
self.minshaft = minshaft
491491
self.minlength = minlength
492-
self.units = units
493492
self.scale_units = scale_units
494493
self.angles = angles
495494
self.width = width
@@ -504,6 +503,8 @@ def __init__(self, ax, *args,
504503
kw.setdefault('linewidths', (0,))
505504
super().__init__([], offsets=self.XY, transOffset=self.transform,
506505
closed=False, **kw)
506+
507+
self.units = units
507508
self.polykw = kw
508509
self.set_UVC(U, V, C)
509510
self._initialized = False
@@ -521,6 +522,14 @@ def on_dpi_change(fig):
521522

522523
self._cid = ax.figure.callbacks.connect('dpi_changed', on_dpi_change)
523524

525+
def _convert_mappable_units(self, A):
526+
"""
527+
Since Quiver already has a .units attribute for another purpose, it's
528+
not yet possible to support units on the ScalarMappable part, so
529+
override convert units to be a no-op.
530+
"""
531+
return A
532+
524533
def remove(self):
525534
# docstring inherited
526535
self.axes.figure.callbacks.disconnect(self._cid)
@@ -622,7 +631,7 @@ def _dots_per_unit(self, units):
622631
elif units == 'inches':
623632
dx = self.axes.figure.dpi
624633
else:
625-
raise ValueError('unrecognized units')
634+
raise ValueError(f'Unrecognized units: {units}')
626635
return dx
627636

628637
def _set_transform(self):

lib/matplotlib/tests/test_collections.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -684,10 +684,7 @@ def test_collection_set_array():
684684
# Test set_array with list
685685
c = Collection()
686686
c.set_array(vals)
687-
688-
# Test set_array with wrong dtype
689-
with pytest.raises(TypeError, match="^Image data of dtype"):
690-
c.set_array("wrong_input")
687+
c.set_array("categorical_input")
691688

692689
# Test if array kwarg is copied
693690
vals[5] = 45

0 commit comments

Comments
 (0)