Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

FIX: Regression in DecisionBoundaryDisplay.from_estimator with colors and plot_method='contour' #31553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 62 additions & 46 deletions sklearn/inspection/_plot/decision_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,60 +221,76 @@
self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
else: # self.response.ndim == 3
n_responses = self.response.shape[-1]
if (
isinstance(self.multiclass_colors, str)
or self.multiclass_colors is None
):
if isinstance(self.multiclass_colors, str):
cmap = self.multiclass_colors
else:
if n_responses <= 10:
cmap = "tab10"
if self.multiclass_colors is None:
if "cmap" in kwargs and "colors" in kwargs:
raise ValueError(
"Cannot specify both 'cmap' and 'colors' in kwargs. "
"Please use only one of them."
Comment on lines +227 to +228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't one of them deprecated upstream? We should recommend the new way.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!
To clarify: matplotlib hasn't deprecated either cmap or colors — both are still valid and serve different purposes.
That said, I agree it's better to recommend cmap as the preferred option, especially for continuous or gradient-based plots. I’ll update the error message accordingly to reflect that.

)
if "cmap" in kwargs:
cmap = plt.get_cmap(kwargs.pop("cmap"), n_responses)
if not hasattr(cmap, "colors"):
# For LinearSegmentedColormap
colors = cmap(np.linspace(0, 1, n_responses))
else:
cmap = "gist_rainbow"

# Special case for the tab10 and tab20 colormaps that encode a
# discrete set of colors that are easily distinguishable
# contrary to other colormaps that are continuous.
if cmap == "tab10" and n_responses <= 10:
colors = plt.get_cmap("tab10", 10).colors[:n_responses]
elif cmap == "tab20" and n_responses <= 20:
colors = plt.get_cmap("tab20", 20).colors[:n_responses]
colors = cmap.colors
elif "colors" in kwargs:
if isinstance(kwargs["colors"], str):
colors = mpl.colors.to_rgba(kwargs.pop("colors"))
colors = [colors for _ in range(n_responses)]
else:
colors = kwargs.pop("colors")
colors = [mpl.colors.to_rgba(color) for color in colors]
else:
cmap = "tab10" if n_responses <= 10 else "gist_rainbow"
colors = plt.get_cmap(cmap, n_responses).colors
elif isinstance(self.multiclass_colors, str):
colors = colors = plt.get_cmap(
self.multiclass_colors, n_responses
).colors
else:
colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors]
if "cmap" in kwargs:
warnings.warn("'cmap' is ignored when 'multiclass_colors' is set.")
del kwargs["cmap"]
if "colors" in kwargs:
warnings.warn(
"'colors' is ignored when 'multiclass_colors' is set."
)
del kwargs["colors"]
if isinstance(self.multiclass_colors, str):
cmap = plt.get_cmap(self.multiclass_colors, n_responses)
if not hasattr(cmap, "colors"):
# For LinearSegmentedColormap
colors = cmap(np.linspace(0, 1, n_responses))

Check warning on line 260 in sklearn/inspection/_plot/decision_boundary.py

View check run for this annotation

Codecov / codecov/patch

sklearn/inspection/_plot/decision_boundary.py#L260

Added line #L260 was not covered by tests
else:
colors = cmap.colors
elif isinstance(self.multiclass_colors, list):
colors = [
mpl.colors.to_rgba(color) for color in self.multiclass_colors
]
else:
raise ValueError("'multiclass_colors' must be a list or a str.")

Check warning on line 268 in sklearn/inspection/_plot/decision_boundary.py

View check run for this annotation

Codecov / codecov/patch

sklearn/inspection/_plot/decision_boundary.py#L268

Added line #L268 was not covered by tests

self.multiclass_colors_ = colors
multiclass_cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
if plot_method == "contour":
# Plot only argmax map for contour
class_map = self.response.argmax(axis=2)
self.surface_ = plot_func(
self.xx0, self.xx1, class_map, colors=colors, **kwargs
)
for class_idx, (r, g, b, _) in enumerate(colors)
]

self.surface_ = []
for class_idx, cmap in enumerate(multiclass_cmaps):
response = np.ma.array(
self.response[:, :, class_idx],
mask=~(self.response.argmax(axis=2) == class_idx),
)
# `cmap` should not be in kwargs
safe_kwargs = kwargs.copy()
if "cmap" in safe_kwargs:
del safe_kwargs["cmap"]
warnings.warn(
"Plotting max class of multiclass 'decision_function' or "
"'predict_proba', thus 'multiclass_colors' used and "
"'cmap' kwarg ignored."
else:
multiclass_cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]

self.surface_ = []
for class_idx, cmap in enumerate(multiclass_cmaps):
response = np.ma.array(
self.response[:, :, class_idx],
mask=~(self.response.argmax(axis=2) == class_idx),
)
self.surface_.append(
plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs)
)
self.surface_.append(
plot_func(self.xx0, self.xx1, response, cmap=cmap, **safe_kwargs)
)

if xlabel is not None or not ax.get_xlabel():
xlabel = self.xlabel if xlabel is None else xlabel
Expand Down
127 changes: 113 additions & 14 deletions sklearn/inspection/_plot/tests/test_boundary_decision_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,28 +645,127 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors):
else:
colors = [mpl.colors.to_rgba(color) for color in multiclass_colors]

cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
if plot_method != "contour":
cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]
for idx, quad in enumerate(disp.surface_):
assert quad.cmap == cmaps[idx]
else:
assert_allclose(disp.surface_.colors, colors)


def test_cmap_and_colors_logic(pyplot):
"""Check the handling logic for `cmap` and `colors`."""
X, y = load_iris_2d_scaled()
clf = LogisticRegression().fit(X, y)

with pytest.raises(
ValueError,
match="Cannot specify both 'cmap' and 'colors' in kwargs.",
):
DecisionBoundaryDisplay.from_estimator(
clf,
X,
colors="black",
cmap="Blues",
)
for class_idx, (r, g, b, _) in enumerate(colors)
]

for idx, quad in enumerate(disp.surface_):
assert quad.cmap == cmaps[idx]
with pytest.warns(
UserWarning,
match="'cmap' is ignored when 'multiclass_colors' is set.",
):
DecisionBoundaryDisplay.from_estimator(
clf,
X,
multiclass_colors="plasma",
cmap="Blues",
)

with pytest.warns(
UserWarning,
match="'colors' is ignored when 'multiclass_colors' is set.",
):
DecisionBoundaryDisplay.from_estimator(
clf,
X,
multiclass_colors="plasma",
colors="blue",
)


@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"])
@pytest.mark.parametrize("kwargs", [{"cmap": "tab10"}, {"cmap": "Blues"}])
def test_multiclass_cmap(pyplot, plot_method, kwargs):
"""Check that `cmap` is correctly applied to DecisionBoundaryDisplay."""
import matplotlib as mpl
import matplotlib.pyplot as plt

def test_multiclass_plot_max_class_cmap_kwarg(pyplot):
"""Check `cmap` kwarg ignored when using plotting max multiclass class."""
X, y = load_iris_2d_scaled()
clf = LogisticRegression().fit(X, y)

msg = (
"Plotting max class of multiclass 'decision_function' or 'predict_proba', "
"thus 'multiclass_colors' used and 'cmap' kwarg ignored."
disp = DecisionBoundaryDisplay.from_estimator(
clf,
X,
cmap=kwargs["cmap"],
plot_method=plot_method,
)

cmap = plt.get_cmap(kwargs["cmap"], len(clf.classes_))
if not hasattr(cmap, "colors"):
colors = cmap(np.linspace(0, 1, len(clf.classes_)))
else:
colors = cmap.colors

if plot_method != "contour":
cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]
for idx, quad in enumerate(disp.surface_):
assert quad.cmap == cmaps[idx]
else:
assert_allclose(disp.surface_.colors, colors)


@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"])
@pytest.mark.parametrize("kwargs", [{"colors": "black"}, {"colors": ["r", "g", "b"]}])
def test_multiclass_colors(pyplot, plot_method, kwargs):
"""Check that `cmap` is correctly applied to DecisionBoundaryDisplay."""
import matplotlib as mpl

X, y = load_iris_2d_scaled()
clf = LogisticRegression().fit(X, y)

disp = DecisionBoundaryDisplay.from_estimator(
clf,
X,
colors=kwargs["colors"],
plot_method=plot_method,
)
with pytest.warns(UserWarning, match=msg):
DecisionBoundaryDisplay.from_estimator(clf, X, cmap="viridis")

if isinstance(kwargs["colors"], str):
colors = mpl.colors.to_rgba(kwargs["colors"])
colors = [colors for _ in range(len(clf.classes_))]
else:
colors = [mpl.colors.to_rgba(color) for color in kwargs["colors"]]

if plot_method != "contour":
cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]
for idx, quad in enumerate(disp.surface_):
assert quad.cmap == cmaps[idx]
else:
assert_allclose(disp.surface_.colors, colors)


def test_subclass_named_constructors_return_type_is_subclass(pyplot):
Expand Down
Loading