From 1afc804b71ae5f68e6d81f4ca6ac4ed51f01a08e Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Mon, 16 Jun 2025 13:20:40 +0800 Subject: [PATCH 1/5] Fix: DecisionBoundaryDisplay for `contour` and `cmap` (issue scikit-learn#31546 --- sklearn/inspection/_plot/decision_boundary.py | 104 ++++++++++-------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index bc28708d7c488..7922907854a11 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -221,60 +221,72 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) else: # self.response.ndim == 3 n_responses = self.response.shape[-1] - if ( - isinstance(self.multiclass_colors, str) - or self.multiclass_colors is None - ): - if isinstance(self.multiclass_colors, str): - cmap = self.multiclass_colors - else: - if n_responses <= 10: - cmap = "tab10" + if self.multiclass_colors is None: + if "cmap" in kwargs and "colors" in kwargs: + raise ValueError( + "Cannot specify both 'cmap' and 'colors' in kwargs. " + "Please use only one of them." + ) + if "cmap" in kwargs: + cmap = plt.get_cmap(kwargs.pop("cmap"), n_responses) + if not hasattr(cmap, "colors"): + # For LinearSegmentedColormap + colors = cmap(np.linspace(0, 1, n_responses)) else: - cmap = "gist_rainbow" - - # Special case for the tab10 and tab20 colormaps that encode a - # discrete set of colors that are easily distinguishable - # contrary to other colormaps that are continuous. - if cmap == "tab10" and n_responses <= 10: - colors = plt.get_cmap("tab10", 10).colors[:n_responses] - elif cmap == "tab20" and n_responses <= 20: - colors = plt.get_cmap("tab20", 20).colors[:n_responses] + colors = cmap.colors + elif "colors" in kwargs: + colors = mpl.colors.to_rgba(kwargs.pop("colors")) + colors = [colors for _ in range(n_responses)] else: + cmap = "tab10" if n_responses <= 10 else "gist_rainbow" colors = plt.get_cmap(cmap, n_responses).colors - elif isinstance(self.multiclass_colors, str): - colors = colors = plt.get_cmap( - self.multiclass_colors, n_responses - ).colors else: - colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors] + if "cmap" in kwargs: + warnings.warn("'cmap' is ignored when 'multiclass_colors' is set.") + del kwargs["cmap"] + if "colors" in kwargs: + warnings.warn( + "'colors' is ignored when 'multiclass_colors' is set." + ) + del kwargs["colors"] + if isinstance(self.multiclass_colors, str): + cmap = plt.get_cmap(self.multiclass_colors, n_responses) + if not hasattr(cmap, "colors"): + # For LinearSegmentedColormap + colors = cmap(np.linspace(0, 1, n_responses)) + else: + colors = cmap.colors + elif isinstance(self.multiclass_colors, list): + colors = [ + mpl.colors.to_rgba(color) for color in self.multiclass_colors + ] + else: + raise ValueError("'multiclass_colors' must be a list or a str.") self.multiclass_colors_ = colors - multiclass_cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] - ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] - - self.surface_ = [] - for class_idx, cmap in enumerate(multiclass_cmaps): - response = np.ma.array( - self.response[:, :, class_idx], - mask=~(self.response.argmax(axis=2) == class_idx), + if plot_method == "contour": + # Plot only argmax map for contour + class_map = self.response.argmax(axis=2) + self.surface_ = plot_func( + self.xx0, self.xx1, class_map, colors=colors, **kwargs ) - # `cmap` should not be in kwargs - safe_kwargs = kwargs.copy() - if "cmap" in safe_kwargs: - del safe_kwargs["cmap"] - warnings.warn( - "Plotting max class of multiclass 'decision_function' or " - "'predict_proba', thus 'multiclass_colors' used and " - "'cmap' kwarg ignored." + else: + multiclass_cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + + self.surface_ = [] + for class_idx, cmap in enumerate(multiclass_cmaps): + response = np.ma.array( + self.response[:, :, class_idx], + mask=~(self.response.argmax(axis=2) == class_idx), + ) + self.surface_.append( + plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs) ) - self.surface_.append( - plot_func(self.xx0, self.xx1, response, cmap=cmap, **safe_kwargs) - ) if xlabel is not None or not ax.get_xlabel(): xlabel = self.xlabel if xlabel is None else xlabel From a80ef4c62530229bc4686953a312ab3f4fd5e428 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Mon, 16 Jun 2025 13:21:40 +0800 Subject: [PATCH 2/5] Update corresponding tests --- .../tests/test_boundary_decision_display.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 3284f42241fa5..06943d41bd4ec 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -652,21 +652,9 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): for class_idx, (r, g, b, _) in enumerate(colors) ] - for idx, quad in enumerate(disp.surface_): - assert quad.cmap == cmaps[idx] - - -def test_multiclass_plot_max_class_cmap_kwarg(pyplot): - """Check `cmap` kwarg ignored when using plotting max multiclass class.""" - X, y = load_iris_2d_scaled() - clf = LogisticRegression().fit(X, y) - - msg = ( - "Plotting max class of multiclass 'decision_function' or 'predict_proba', " - "thus 'multiclass_colors' used and 'cmap' kwarg ignored." - ) - with pytest.warns(UserWarning, match=msg): - DecisionBoundaryDisplay.from_estimator(clf, X, cmap="viridis") + if plot_method != 'contour': + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] def test_subclass_named_constructors_return_type_is_subclass(pyplot): From c72e425d3d3b0e6f4d4090dad4cf4df9080fcfd9 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Mon, 16 Jun 2025 13:55:36 +0800 Subject: [PATCH 3/5] Reformat with ruff --- .../inspection/_plot/tests/test_boundary_decision_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 06943d41bd4ec..9e4b7bf45c12e 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -652,7 +652,7 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): for class_idx, (r, g, b, _) in enumerate(colors) ] - if plot_method != 'contour': + if plot_method != "contour": for idx, quad in enumerate(disp.surface_): assert quad.cmap == cmaps[idx] From 1be1a98845dd162853e9bb2183cb5d1b4438c274 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Sat, 21 Jun 2025 22:54:24 +0800 Subject: [PATCH 4/5] Fix: add handling logic when `colors` is a list --- sklearn/inspection/_plot/decision_boundary.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 7922907854a11..98c6d792622d3 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -235,8 +235,12 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar else: colors = cmap.colors elif "colors" in kwargs: - colors = mpl.colors.to_rgba(kwargs.pop("colors")) - colors = [colors for _ in range(n_responses)] + if isinstance(kwargs["colors"], str): + colors = mpl.colors.to_rgba(kwargs.pop("colors")) + colors = [colors for _ in range(n_responses)] + else: + colors = kwargs.pop("colors") + colors = [mpl.colors.to_rgba(color) for color in colors] else: cmap = "tab10" if n_responses <= 10 else "gist_rainbow" colors = plt.get_cmap(cmap, n_responses).colors From 7ca868f67a6fb23f35d33554f177f98af80a4f53 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Sat, 21 Jun 2025 22:56:35 +0800 Subject: [PATCH 5/5] ENH: add tests for cmap and colors handling in DecisionBoundaryDisplay --- .../tests/test_boundary_decision_display.py | 121 +++++++++++++++++- 1 file changed, 116 insertions(+), 5 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 9e4b7bf45c12e..109b8a66e2114 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -645,16 +645,127 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): else: colors = [mpl.colors.to_rgba(color) for color in multiclass_colors] - cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) + + +def test_cmap_and_colors_logic(pyplot): + """Check the handling logic for `cmap` and `colors`.""" + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + with pytest.raises( + ValueError, + match="Cannot specify both 'cmap' and 'colors' in kwargs.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + colors="black", + cmap="Blues", + ) + + with pytest.warns( + UserWarning, + match="'cmap' is ignored when 'multiclass_colors' is set.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + multiclass_colors="plasma", + cmap="Blues", ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] + + with pytest.warns( + UserWarning, + match="'colors' is ignored when 'multiclass_colors' is set.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + multiclass_colors="plasma", + colors="blue", + ) + + +@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) +@pytest.mark.parametrize("kwargs", [{"cmap": "tab10"}, {"cmap": "Blues"}]) +def test_multiclass_cmap(pyplot, plot_method, kwargs): + """Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" + import matplotlib as mpl + import matplotlib.pyplot as plt + + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=kwargs["cmap"], + plot_method=plot_method, + ) + + cmap = plt.get_cmap(kwargs["cmap"], len(clf.classes_)) + if not hasattr(cmap, "colors"): + colors = cmap(np.linspace(0, 1, len(clf.classes_))) + else: + colors = cmap.colors if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] for idx, quad in enumerate(disp.surface_): assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) + + +@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) +@pytest.mark.parametrize("kwargs", [{"colors": "black"}, {"colors": ["r", "g", "b"]}]) +def test_multiclass_colors(pyplot, plot_method, kwargs): + """Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" + import matplotlib as mpl + + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + colors=kwargs["colors"], + plot_method=plot_method, + ) + + if isinstance(kwargs["colors"], str): + colors = mpl.colors.to_rgba(kwargs["colors"]) + colors = [colors for _ in range(len(clf.classes_))] + else: + colors = [mpl.colors.to_rgba(color) for color in kwargs["colors"]] + + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) def test_subclass_named_constructors_return_type_is_subclass(pyplot):