From 08adb0c80cc370e0d39fd281c5dcae3a34a1a3c0 Mon Sep 17 00:00:00 2001 From: Peter Liu Date: Mon, 23 Nov 2020 23:14:59 -0600 Subject: [PATCH 1/3] add colormodel parameter into px.imshow and test --- .../python/plotly/plotly/express/_imshow.py | 13 +++++++++++- .../tests/test_core/test_px/test_imshow.py | 20 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 88713e54368..1530c13128b 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -54,6 +54,7 @@ def _infer_zmax_from_type(img): def imshow( img, + colormodel=None, zmin=None, zmax=None, origin=None, @@ -88,6 +89,11 @@ def imshow( - (M, N, 3): an image with RGB values. - (M, N, 4): an image with RGBA values, i.e. including transparency. + colormodel: str, 'rgb' or 'hsl' (default 'rgb') + colormodel used to map the numerical color components into colors. + It enables automatic change from rgb/hsl to rgba/hsla depending on + img dimensions. + zmin, zmax : scalar or iterable, optional zmin and zmax define the scalar range that the colormap covers. By default, zmin and zmax correspond to the min and max values of the datatype for integer @@ -389,7 +395,12 @@ def imshow( ) trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy) else: - colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" + # change colormodel from hsl/rgba to hsla/rgba depending on dimensions + if colormodel == "hsl": + if img.shape[-1] == 4: + colormodel = "hsla" + else: + colormodel = "rgb" if img.shape[-1] == 3 else "rgba" trace = go.Image( z=img, zmin=zmin, diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 313267aacbd..8bbe87b65fa 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -346,3 +346,23 @@ def test_imshow_hovertemplate(binary_string): fig.data[0].hovertemplate == "x: %{x}
y: %{y}
color: %{z}" ) + + +@pytest.mark.parametrize("colormodel", ["rgb", "hsl"]) +def test_imshow_colormodel(colormodel): + img_no_alpha = np.array([[[147, 50, 47], [147, 50, 47]], + [[147, 50, 47], [147, 50, 47]]], dtype=np.uint8) + img_alpha = np.array([[[147, 50, 47, 0.1], [147, 50, 47, 0.2]], + [[147, 50, 47, 0.3], [147, 50, 47, 1]]], dtype=np.uint8) + for img in [img_no_alpha, img_alpha]: + fig = px.imshow( + img, + colormodel=colormodel + ) + if img.shape[2] == 3: + assert decode_image_string(fig.data[0].source).shape[2] == 3 + assert np.all(img_no_alpha == decode_image_string(fig.data[0].source)) + else: + assert decode_image_string(fig.data[0].source).shape[2] == 4 + assert np.all(img_alpha == decode_image_string(fig.data[0].source)) + From de4826559039b427b067860787285ba09929ba2a Mon Sep 17 00:00:00 2001 From: Peter Liu Date: Tue, 24 Nov 2020 09:20:21 -0600 Subject: [PATCH 2/3] update format of test_imshow using black --- .../tests/test_core/test_px/test_imshow.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 8bbe87b65fa..acd3518d225 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -183,7 +183,9 @@ def test_imshow_xarray(): def test_imshow_labels_and_ranges(): - fig = px.imshow([[1, 2], [3, 4], [5, 6]],) + fig = px.imshow( + [[1, 2], [3, 4], [5, 6]], + ) assert fig.layout.xaxis.title.text is None assert fig.layout.yaxis.title.text is None assert fig.layout.coloraxis.colorbar.title.text is None @@ -350,19 +352,21 @@ def test_imshow_hovertemplate(binary_string): @pytest.mark.parametrize("colormodel", ["rgb", "hsl"]) def test_imshow_colormodel(colormodel): - img_no_alpha = np.array([[[147, 50, 47], [147, 50, 47]], - [[147, 50, 47], [147, 50, 47]]], dtype=np.uint8) - img_alpha = np.array([[[147, 50, 47, 0.1], [147, 50, 47, 0.2]], - [[147, 50, 47, 0.3], [147, 50, 47, 1]]], dtype=np.uint8) - for img in [img_no_alpha, img_alpha]: - fig = px.imshow( - img, - colormodel=colormodel - ) + img_no_alpha = np.array( + [[[147, 50, 47], [147, 50, 47]], [[147, 50, 47], [147, 50, 47]]], dtype=np.uint8 + ) + img_alpha = np.array( + [ + [[147, 50, 47, 0.1], [147, 50, 47, 0.2]], + [[147, 50, 47, 0.3], [147, 50, 47, 1]], + ], + dtype=np.uint8, + ) + for img in [img_no_alpha, img_alpha]: + fig = px.imshow(img, colormodel=colormodel) if img.shape[2] == 3: assert decode_image_string(fig.data[0].source).shape[2] == 3 assert np.all(img_no_alpha == decode_image_string(fig.data[0].source)) else: assert decode_image_string(fig.data[0].source).shape[2] == 4 assert np.all(img_alpha == decode_image_string(fig.data[0].source)) - From d1d8bbc3c817953fadb0ccbe0d709f8e7a07400a Mon Sep 17 00:00:00 2001 From: Peter Liu Date: Tue, 24 Nov 2020 09:36:08 -0600 Subject: [PATCH 3/3] use black==19.10b0 to reformat --- .../plotly/plotly/tests/test_core/test_px/test_imshow.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index acd3518d225..01c0d11a16d 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -183,9 +183,7 @@ def test_imshow_xarray(): def test_imshow_labels_and_ranges(): - fig = px.imshow( - [[1, 2], [3, 4], [5, 6]], - ) + fig = px.imshow([[1, 2], [3, 4], [5, 6]],) assert fig.layout.xaxis.title.text is None assert fig.layout.yaxis.title.text is None assert fig.layout.coloraxis.colorbar.title.text is None