diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index bc28708d7c488..98c6d792622d3 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -221,60 +221,76 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) else: # self.response.ndim == 3 n_responses = self.response.shape[-1] - if ( - isinstance(self.multiclass_colors, str) - or self.multiclass_colors is None - ): - if isinstance(self.multiclass_colors, str): - cmap = self.multiclass_colors - else: - if n_responses <= 10: - cmap = "tab10" + if self.multiclass_colors is None: + if "cmap" in kwargs and "colors" in kwargs: + raise ValueError( + "Cannot specify both 'cmap' and 'colors' in kwargs. " + "Please use only one of them." + ) + if "cmap" in kwargs: + cmap = plt.get_cmap(kwargs.pop("cmap"), n_responses) + if not hasattr(cmap, "colors"): + # For LinearSegmentedColormap + colors = cmap(np.linspace(0, 1, n_responses)) else: - cmap = "gist_rainbow" - - # Special case for the tab10 and tab20 colormaps that encode a - # discrete set of colors that are easily distinguishable - # contrary to other colormaps that are continuous. 
- if cmap == "tab10" and n_responses <= 10: - colors = plt.get_cmap("tab10", 10).colors[:n_responses] - elif cmap == "tab20" and n_responses <= 20: - colors = plt.get_cmap("tab20", 20).colors[:n_responses] + colors = cmap.colors + elif "colors" in kwargs: + if isinstance(kwargs["colors"], str): + colors = mpl.colors.to_rgba(kwargs.pop("colors")) + colors = [colors for _ in range(n_responses)] + else: + colors = kwargs.pop("colors") + colors = [mpl.colors.to_rgba(color) for color in colors] else: + cmap = "tab10" if n_responses <= 10 else "gist_rainbow" colors = plt.get_cmap(cmap, n_responses).colors - elif isinstance(self.multiclass_colors, str): - colors = colors = plt.get_cmap( - self.multiclass_colors, n_responses - ).colors else: - colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors] + if "cmap" in kwargs: + warnings.warn("'cmap' is ignored when 'multiclass_colors' is set.") + del kwargs["cmap"] + if "colors" in kwargs: + warnings.warn( + "'colors' is ignored when 'multiclass_colors' is set." 
+ ) + del kwargs["colors"] + if isinstance(self.multiclass_colors, str): + cmap = plt.get_cmap(self.multiclass_colors, n_responses) + if not hasattr(cmap, "colors"): + # For LinearSegmentedColormap + colors = cmap(np.linspace(0, 1, n_responses)) + else: + colors = cmap.colors + elif isinstance(self.multiclass_colors, list): + colors = [ + mpl.colors.to_rgba(color) for color in self.multiclass_colors + ] + else: + raise ValueError("'multiclass_colors' must be a list or a str.") self.multiclass_colors_ = colors - multiclass_cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + if plot_method == "contour": + # Plot only argmax map for contour + class_map = self.response.argmax(axis=2) + self.surface_ = plot_func( + self.xx0, self.xx1, class_map, colors=colors, **kwargs ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] - - self.surface_ = [] - for class_idx, cmap in enumerate(multiclass_cmaps): - response = np.ma.array( - self.response[:, :, class_idx], - mask=~(self.response.argmax(axis=2) == class_idx), - ) - # `cmap` should not be in kwargs - safe_kwargs = kwargs.copy() - if "cmap" in safe_kwargs: - del safe_kwargs["cmap"] - warnings.warn( - "Plotting max class of multiclass 'decision_function' or " - "'predict_proba', thus 'multiclass_colors' used and " - "'cmap' kwarg ignored." 
+ else: + multiclass_cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + + self.surface_ = [] + for class_idx, cmap in enumerate(multiclass_cmaps): + response = np.ma.array( + self.response[:, :, class_idx], + mask=~(self.response.argmax(axis=2) == class_idx), + ) + self.surface_.append( + plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs) ) - self.surface_.append( - plot_func(self.xx0, self.xx1, response, cmap=cmap, **safe_kwargs) - ) if xlabel is not None or not ax.get_xlabel(): xlabel = self.xlabel if xlabel is None else xlabel diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 3284f42241fa5..109b8a66e2114 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -645,28 +645,127 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): else: colors = [mpl.colors.to_rgba(color) for color in multiclass_colors] - cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) + + +def test_cmap_and_colors_logic(pyplot): + """Check the handling logic for `cmap` and `colors`.""" + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + with pytest.raises( + ValueError, + match="Cannot specify both 'cmap' and 'colors' in kwargs.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + 
colors="black", + cmap="Blues", ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] - for idx, quad in enumerate(disp.surface_): - assert quad.cmap == cmaps[idx] + with pytest.warns( + UserWarning, + match="'cmap' is ignored when 'multiclass_colors' is set.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + multiclass_colors="plasma", + cmap="Blues", + ) + with pytest.warns( + UserWarning, + match="'colors' is ignored when 'multiclass_colors' is set.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + multiclass_colors="plasma", + colors="blue", + ) + + +@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) +@pytest.mark.parametrize("kwargs", [{"cmap": "tab10"}, {"cmap": "Blues"}]) +def test_multiclass_cmap(pyplot, plot_method, kwargs): + """Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" + import matplotlib as mpl + import matplotlib.pyplot as plt -def test_multiclass_plot_max_class_cmap_kwarg(pyplot): - """Check `cmap` kwarg ignored when using plotting max multiclass class.""" X, y = load_iris_2d_scaled() clf = LogisticRegression().fit(X, y) - msg = ( - "Plotting max class of multiclass 'decision_function' or 'predict_proba', " - "thus 'multiclass_colors' used and 'cmap' kwarg ignored." 
+ disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=kwargs["cmap"], + plot_method=plot_method, + ) + + cmap = plt.get_cmap(kwargs["cmap"], len(clf.classes_)) + if not hasattr(cmap, "colors"): + colors = cmap(np.linspace(0, 1, len(clf.classes_))) + else: + colors = cmap.colors + + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) + + +@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) +@pytest.mark.parametrize("kwargs", [{"colors": "black"}, {"colors": ["r", "g", "b"]}]) +def test_multiclass_colors(pyplot, plot_method, kwargs): + """Check that `colors` is correctly applied to DecisionBoundaryDisplay.""" + import matplotlib as mpl + + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + colors=kwargs["colors"], + plot_method=plot_method, ) - with pytest.warns(UserWarning, match=msg): - DecisionBoundaryDisplay.from_estimator(clf, X, cmap="viridis") + + if isinstance(kwargs["colors"], str): + colors = mpl.colors.to_rgba(kwargs["colors"]) + colors = [colors for _ in range(len(clf.classes_))] + else: + colors = [mpl.colors.to_rgba(color) for color in kwargs["colors"]] + + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) def test_subclass_named_constructors_return_type_is_subclass(pyplot):