Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cdc782d

Browse files
committed
minor fixes for MultivarColormap and BivarColormap
removal of ColormapBase Addition of _repr_png_() for MultivarColormap addition of _get_rgba_and_mask() for Colormap to clean up __call__()
1 parent c775047 commit cdc782d

File tree

3 files changed

+77
-44
lines changed

3 files changed

+77
-44
lines changed

lib/matplotlib/colors.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -675,16 +675,7 @@ def _create_lookup_table(N, data, gamma=1.0):
675675
return np.clip(lut, 0.0, 1.0)
676676

677677

678-
class ColormapBase:
679-
"""
680-
Base class for all colormaps, both scalar, bivariate and multivariate.
681-
682-
This class is used for type checking, and cannot be initialized.
683-
"""
684-
...
685-
686-
687-
class Colormap(ColormapBase):
678+
class Colormap:
688679
"""
689680
Baseclass for all scalar to RGBA mappings.
690681
@@ -721,7 +712,7 @@ def __init__(self, name, N=256):
721712
#: `matplotlib.colorbar.Colorbar` constructor.
722713
self.colorbar_extend = False
723714

724-
def __call__(self, X, alpha=None, bytes=False, return_mask_bad=False):
715+
def __call__(self, X, alpha=None, bytes=False):
725716
r"""
726717
Parameters
727718
----------
@@ -737,15 +728,40 @@ def __call__(self, X, alpha=None, bytes=False, return_mask_bad=False):
737728
bytes : bool
738729
If False (default), the returned RGBA values will be floats in the
739730
interval ``[0, 1]`` otherwise they will be `numpy.uint8`\s in the
740-
interval ``[0, 255]``.
741-
return_mask_bad : bool
742-
If true, also return a mask of bad values.
731+
interval ``[0, 255]``
743732
744733
Returns
745734
-------
746735
Tuple of RGBA values if X is scalar, otherwise an array of
747736
RGBA values with a shape of ``X.shape + (4, )``.
748737
"""
738+
rgba, mask = self._get_rgba_and_mask(X, alpha=alpha, bytes=bytes)
739+
return rgba
740+
741+
def _get_rgba_and_mask(self, X, alpha=None, bytes=False):
742+
r"""
743+
Parameters
744+
----------
745+
X : float or int, `~numpy.ndarray` or scalar
746+
The data value(s) to convert to RGBA.
747+
For floats, *X* should be in the interval ``[0.0, 1.0]`` to
748+
return the RGBA values ``X*100`` percent along the Colormap line.
749+
For integers, *X* should be in the interval ``[0, Colormap.N)`` to
750+
return RGBA values *indexed* from the Colormap with index ``X``.
751+
alpha : float or array-like or None
752+
Alpha must be a scalar between 0 and 1, a sequence of such
753+
floats with shape matching X, or None.
754+
bytes : bool
755+
If False (default), the returned RGBA values will be floats in the
756+
interval ``[0, 1]`` otherwise they will be `numpy.uint8`\s in the
757+
interval ``[0, 255]``.
758+
759+
Returns
760+
-------
761+
(colors, mask), where color is a tuple of RGBA values if X is scalar,
762+
otherwise an array of RGBA values with a shape of ``X.shape + (4, )``,
763+
and mask is a boolean array.
764+
"""
749765
if not self._isinit:
750766
self._init()
751767

@@ -791,9 +807,7 @@ def __call__(self, X, alpha=None, bytes=False, return_mask_bad=False):
791807

792808
if not np.iterable(X):
793809
rgba = tuple(rgba)
794-
if return_mask_bad:
795-
return rgba, mask_bad
796-
return rgba
810+
return rgba, mask_bad
797811

798812
def __copy__(self):
799813
cls = self.__class__
@@ -1240,13 +1254,10 @@ def reversed(self, name=None):
12401254
return new_cmap
12411255

12421256

1243-
class MultivarColormap(ColormapBase):
1257+
class MultivarColormap:
12441258
"""
12451259
Class for holding multiple `~matplotlib.colors.Colormap` for use in a
12461260
`~matplotlib.cm.VectorMappable` object
1247-
1248-
MultivarColormap does not support alpha in the constituent
1249-
look up tables (ignored).
12501261
"""
12511262
def __init__(self, name, colormaps, combination_mode):
12521263
"""
@@ -1258,20 +1269,26 @@ def __init__(self, name, colormaps, combination_mode):
12581269
The individual colormaps that are combined
12591270
combination_mode: str, 'Add' or 'Sub'
12601271
Describe how colormaps are combined in sRGB space
1261-
1272+
12621273
- If 'Add' -> Mixing produces brighter colors
12631274
`sRGB = cmap[0][X[0]] + cmap[1][x[1]] + ... + cmap[n-1][x[n-1]]`
12641275
- If 'Sub' -> Mixing produces darker colors
12651276
`sRGB = cmap[0][X[0]] + cmap[1][x[1]] + ... + cmap[n-1][x[n-1]] - n + 1`
12661277
"""
12671278
self.name = name
12681279

1269-
if not np.iterable(colormaps) or len(colormaps) == 1:
1280+
if not np.iterable(colormaps) \
1281+
or len(colormaps) == 1 \
1282+
or isinstance(colormaps, str):
12701283
raise ValueError("A MultivarColormap must have more than one colormap.")
1271-
for cmap in colormaps:
1284+
colormaps = list(colormaps) # ensure cmaps is a list, i.e. not a tuple
1285+
for i, cmap in enumerate(colormaps):
12721286
if not issubclass(type(cmap), Colormap):
1273-
raise ValueError("colormaps must be a list of objects that subclass"
1274-
" Colormap, not strings or list of strings")
1287+
if isinstance(cmap, str):
1288+
colormaps[i] = mpl.colormaps[cmap]
1289+
else:
1290+
raise ValueError("colormaps must be a list of objects that subclass"
1291+
" Colormap or valid strings.")
12751292

12761293
self.colormaps = colormaps
12771294
self.combination_mode = combination_mode
@@ -1309,10 +1326,10 @@ def __call__(self, X, alpha=None, bytes=False):
13091326
raise ValueError(
13101327
f'For the selected colormap the data must have a first dimension '
13111328
f'{len(self)}, not {len(X)}')
1312-
rgba, mask_bad = self[0](X[0], bytes=False, return_mask_bad=True)
1329+
rgba, mask_bad = self[0]._get_rgba_and_mask(X[0], bytes=False)
13131330
rgba = np.asarray(rgba)
13141331
for c, xx in zip(self[1:], X[1:]):
1315-
sub_rgba, sub_mask_bad = c(xx, bytes=False, return_mask_bad=True)
1332+
sub_rgba, sub_mask_bad = c._get_rgba_and_mask(xx, bytes=False)
13161333
sub_rgba = np.asarray(sub_rgba)
13171334
rgba[..., :3] += sub_rgba[..., :3] # add colors
13181335
rgba[..., 3] *= sub_rgba[..., 3] # multiply alpha
@@ -1400,16 +1417,30 @@ def combination_mode(self, mode):
14001417
self._combination_mode = mode
14011418

14021419
def _repr_png_(self):
1403-
raise NotImplementedError("no png representation of MultivarColormap"
1404-
" but you may access png repreesntations of the"
1405-
" individual colorbars.")
1420+
"""Generate a PNG representation of the Colormap."""
1421+
X = np.tile(np.linspace(0, 1, _REPR_PNG_SIZE[0]),
1422+
(_REPR_PNG_SIZE[1], 1))
1423+
pixels = np.zeros((_REPR_PNG_SIZE[1]*len(self), _REPR_PNG_SIZE[0], 4),
1424+
dtype=np.uint8)
1425+
for i, c in enumerate(self):
1426+
pixels[i*_REPR_PNG_SIZE[1]:(i+1)*_REPR_PNG_SIZE[1], :] = c(X, bytes=True)
1427+
png_bytes = io.BytesIO()
1428+
title = self.name + ' multivariate colormap'
1429+
author = f'Matplotlib v{mpl.__version__}, https://matplotlib.org'
1430+
pnginfo = PngInfo()
1431+
pnginfo.add_text('Title', title)
1432+
pnginfo.add_text('Description', title)
1433+
pnginfo.add_text('Author', author)
1434+
pnginfo.add_text('Software', author)
1435+
Image.fromarray(pixels).save(png_bytes, format='png', pnginfo=pnginfo)
1436+
return png_bytes.getvalue()
14061437

14071438
def _repr_html_(self):
14081439
"""Generate an HTML representation of the MultivarColormap."""
14091440
return ''.join([c._repr_html_() for c in self.colormaps])
14101441

14111442

1412-
class BivarColormap(ColormapBase):
1443+
class BivarColormap:
14131444
"""
14141445
Baseclass for all bivarate to RGBA mappings.
14151446

lib/matplotlib/colors.pyi

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,7 @@ class ColorConverter:
6464

6565
colorConverter: ColorConverter
6666

67-
class ColormapBase:
68-
...
69-
70-
class Colormap(ColormapBase):
67+
class Colormap:
7168
name: str
7269
N: int
7370
colorbar_extend: bool
@@ -141,7 +138,7 @@ class ListedColormap(Colormap):
141138
def resampled(self, lutsize: int) -> ListedColormap: ...
142139
def reversed(self, name: str | None = ...) -> ListedColormap: ...
143140

144-
class MultivarColormap(ColormapBase):
141+
class MultivarColormap:
145142
name: str
146143
colormaps: list[Colormap]
147144
combination_mode: str
@@ -166,8 +163,10 @@ class MultivarColormap(ColormapBase):
166163
def __len__(self) -> int: ...
167164
def get_bad(self) -> np.ndarray: ...
168165
def set_bad(self, color: ColorType = ..., alpha: float | None = ...) -> None: ...
166+
def _repr_html_(self) -> str: ...
167+
def _repr_png_(self) -> bytes: ...
169168

170-
class BivarColormap(ColormapBase):
169+
class BivarColormap:
171170
name: str
172171
N: int
173172
M: int
@@ -190,12 +189,15 @@ class BivarColormap(ColormapBase):
190189
@property
191190
def lut(self) -> np.ndarray: ...
192191
def __copy__(self) -> BivarColormap: ...
192+
def __getitem__(self, item: int) -> Colormap: ...
193193
def __eq__(self, other) -> bool: ...
194194
def get_bad(self) -> np.ndarray: ...
195195
def set_bad(self, color: ColorType = ..., alpha: float | None = ...) -> None: ...
196196
def get_outside(self) -> np.ndarray: ...
197197
def set_outside(self, color: ColorType = ..., alpha: float | None = ...) -> None: ...
198198
def copy(self) -> BivarColormap: ...
199+
def _repr_html_(self) -> str: ...
200+
def _repr_png_(self) -> bytes: ...
199201

200202
class SegmentedBivarColormap(BivarColormap):
201203
def __init__(

lib/matplotlib/tests/test_multivariate_colormaps.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ def test_bivariate_cmap_shapes():
4141
def test_multivar_creation():
4242
# test creation of a custom multivariate colorbar
4343
blues = mpl.colormaps['Blues']
44-
oranges = mpl.colormaps['Oranges']
45-
cmap = mpl.colors.MultivarColormap('custom', (blues, oranges), 'Sub')
44+
cmap = mpl.colors.MultivarColormap('custom', (blues, 'Oranges'), 'Sub')
4645
y, x = np.mgrid[0:3, 0:3]/2
4746
im = cmap((y, x))
4847
res = np.array([[[0.96862745, 0.94509804, 0.92156863, 1],
@@ -57,8 +56,8 @@ def test_multivar_creation():
5756
assert_allclose(im, res, atol=0.01)
5857

5958
with pytest.raises(ValueError, match="colormaps must be a list of"):
60-
cmap = mpl.colors.MultivarColormap('custom', (blues, 'Oranges'), 'Sub')
61-
with pytest.raises(ValueError, match="colormaps must be a list of"):
59+
cmap = mpl.colors.MultivarColormap('custom', (blues, [blues]), 'Sub')
60+
with pytest.raises(ValueError, match="A MultivarColormap must"):
6261
cmap = mpl.colors.MultivarColormap('custom', 'blues', 'Sub')
6362
with pytest.raises(ValueError, match="A MultivarColormap must"):
6463
cmap = mpl.colors.MultivarColormap('custom', (blues), 'Sub')
@@ -281,8 +280,9 @@ def test_bivar_cmap_call():
281280
match="only implemented for use with with floats"):
282281
cs = cmap([(0, 5, 9, 0, 0, 9), (0, 0, 0, 5, 11, 11)])
283282

283+
284284
def test_bivar_getitem():
285-
'''Test __getitem__ on BivarColormap'''
285+
"""Test __getitem__ on BivarColormap"""
286286
xA = ([.0, .25, .5, .75, 1., -1, 2], [.5]*7)
287287
xB = ([.5]*7, [.0, .25, .5, .75, 1., -1, 2])
288288

@@ -304,7 +304,7 @@ def test_bivar_getitem():
304304
assert_array_equal(cmaps(xA), cmaps[0](xA[0]))
305305
assert_array_equal(cmaps(xB), cmaps[1](xB[1]))
306306

307-
307+
308308
def test_bivar_cmap_bad_shape():
309309
"""
310310
Tests calling a bivariate colormap with integer values

0 commit comments

Comments
 (0)