diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 88713e54368..1530c13128b 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -54,6 +54,7 @@ def _infer_zmax_from_type(img): def imshow( img, + colormodel=None, zmin=None, zmax=None, origin=None, @@ -88,6 +89,11 @@ def imshow( - (M, N, 3): an image with RGB values. - (M, N, 4): an image with RGBA values, i.e. including transparency. + colormodel: str, 'rgb' or 'hsl' (default 'rgb') + colormodel used to map the numerical color components into colors. + It enables automatic change from rgb/hsl to rgba/hsla depending on + img dimensions. + zmin, zmax : scalar or iterable, optional zmin and zmax define the scalar range that the colormap covers. By default, zmin and zmax correspond to the min and max values of the datatype for integer @@ -389,7 +395,12 @@ def imshow( ) trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy) else: - colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" + # change colormodel from hsl/rgba to hsla/rgba depending on dimensions + if colormodel == "hsl": + if img.shape[-1] == 4: + colormodel = "hsla" + else: + colormodel = "rgb" if img.shape[-1] == 3 else "rgba" trace = go.Image( z=img, zmin=zmin, diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 313267aacbd..01c0d11a16d 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -346,3 +346,25 @@ def test_imshow_hovertemplate(binary_string): fig.data[0].hovertemplate == "x: %{x}
y: %{y}
color: %{z}" ) + + +@pytest.mark.parametrize("colormodel", ["rgb", "hsl"]) +def test_imshow_colormodel(colormodel): + img_no_alpha = np.array( + [[[147, 50, 47], [147, 50, 47]], [[147, 50, 47], [147, 50, 47]]], dtype=np.uint8 + ) + img_alpha = np.array( + [ + [[147, 50, 47, 0.1], [147, 50, 47, 0.2]], + [[147, 50, 47, 0.3], [147, 50, 47, 1]], + ], + dtype=np.uint8, + ) + for img in [img_no_alpha, img_alpha]: + fig = px.imshow(img, colormodel=colormodel) + if img.shape[2] == 3: + assert decode_image_string(fig.data[0].source).shape[2] == 3 + assert np.all(img_no_alpha == decode_image_string(fig.data[0].source)) + else: + assert decode_image_string(fig.data[0].source).shape[2] == 4 + assert np.all(img_alpha == decode_image_string(fig.data[0].source))