-
-
Notifications
You must be signed in to change notification settings - Fork 26k
ENH add from_cv_results
in PrecisionRecallDisplay
(single Display)
#30508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot to mention, I think I would like to decide on the order parameters for these display classes and their methods. They seem to have a lot of overlap and it would be great if they could be consistent.
I know that this would not matter when using the methods but it would be nice for the documentation API page if they were consistent?
|
||
estimator_name : str, default=None | ||
Name of estimator. If None, then the estimator name is not shown. | ||
curve_name : str or list of str, default=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought curve_name
is a more generalizable term (vs estimator_name
), especially with cv multi curves where we want to name each curve by the fold number.
Changing this name will mean that we must change _validate_plot_params
and thus all other classes that use _BinaryClassifierCurveDisplayMixin
I note that the parameter is named differently here (PrecisionRecallDisplay
init) vs in the from_prediction
and from_estimator
methods (where it's called name
). I'm not sure if this was accidental or to distinguish it from the method parameter 'name's?
# If multi-curve, ensure all args are of the right length | ||
req_multi = [ | ||
input for input in (self.precision, self.recall) if isinstance(input, list) | ||
] | ||
if req_multi and ((len(req_multi) != 2) or len({len(arg) for arg in req_multi}) > 1): | ||
raise ValueError( | ||
"When plotting multiple precision-recall curves, `self.precision` " | ||
"and `self.recall` should both be lists of the same length." | ||
) | ||
elif self.average_precision is not None: | ||
default_line_kwargs["label"] = f"AP = {self.average_precision:0.2f}" | ||
elif name is not None: | ||
default_line_kwargs["label"] = name | ||
n_multi = len(self.precision) if req_multi else None | ||
if req_multi: | ||
for name, param in zip( | ||
["self.average_precision", "`name` or `self.curve_name`"], | ||
(self.average_precision, name_) | ||
): | ||
if not((isinstance(param, list) and len(param) != n_multi) or param is not None): | ||
raise ValueError( | ||
f"For multi precision-recall curves, {name} must either be " | ||
"a list of the same length as `self.precision` and " | ||
"`self.recall`, or None." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I struggled to come up with a nice way to do this. The checks we need are:
precision
andrecall
, both need to be lists of the same length or both need to be single ndarray- for multi curve,
average_precision
andname
can either be a list of the same length or None.
This latter point is important, as previously I simply checked that all 4 parameters are of the same length if they were lists. I didn't check that 2 optional parameters needed to be None
if they were not a list, for the multi-curve situation.
Suggestions welcome for making this nicer.
The good part though is that this is easily factorized out and can be generalised for all similar displays.
name_ = [name_] * n_multi if name_ is None else name_ | ||
average_precision_ = ( | ||
[None] * n_multi if self.average_precision is None else self.average_precision |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like this, but could not immediately think of a better way to do it
) | ||
# Note `pos_label` cannot be `None` (default=1), unlike other metrics | ||
# such as roc_auc | ||
average_precision = average_precision_score( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note pos_label
cannot be None
here (default=1), unlike other metrics as roc_auc
precision_all.append(precision) | ||
recall_all.append(recall) | ||
ap_all.append(average_precision) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't like this but not sure on the zip suggested in #30399 (comment) as you've got to unpack at the end 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some notes on review suggestions. Namely to make all the multi class params (precisions
, recalls
etc) list of ndarrays.
Also realised we did not need separate plot_single_curve
function, as most of the complexity was in _get_line_kwargs
if fold_line_kws is None: | ||
fold_line_kws = [ | ||
{"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Decided that we should not specify single colour because indeed the the legend would be useless.
names : str, default=None | ||
Names of each precision-recall curve for labeling. If `None`, use | ||
name provided at `PrecisionRecallDisplay` initialization. If not | ||
provided at initialization, no labeling is shown. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems reasonable that if we change the name
parameter in the class init, we should change it here to, especially as we don't advocate people to use plot
directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Discussed this with @glemaitre and decided that it is okay to change to names
. We should however make it clear what this is setting - the label of the curve in the legend.
The problem use case we thought about was if someone created a plot and display object, then wanted to add one curve to it using plot
, names
would not make sense in this case. However, it would be difficult for us to manage the legend in such a case, so decided that it would be up to the user to manage the legend in such a case.
if len(self.line_) == 1: | ||
self.line_ = self.line_[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should line_
always be a list or should we do this to be backwards compatible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We decided that we should deprecated line_
and add lines_
.
We'll add a getter such that if you try to access line_
you get a warning and the first item of lines_
, which will be removed in 2 releases.
Just wanted to document here that we discussed a potential enhancement for comparing between estimators, where you have cv results from several estimators (so several fold curves for each estimator). Potentially this could be added as a separate function, where you pass the display object, and estimators desired. Not planned, just a potential additional in future. |
Hey, I think that you can revive this PR now that the roc curve is merged. Let's try to reuse code from the other PR if possible :) |
Thanks @jeremiedbb ! I think @glemaitre mentioned there was some discussion about what to do with the 'chance' level (average precision). In the current PR I have calculated a single average precision (AP) for all the data. I think others suggested that we should calculate average precision for each fold, which I can see is more accurate but I am concerned about the visualization appearance. Here I have used 5 cv splits and plotting chance for each, and colouring each pair of precision-recall curve/chance line the same colour: Some concerns about the visualization:
I will have more of a think of a better solution for this. |
So for the "Chance level" I would consider all lines to have the same color (and a lower alpha) and in the legend to have a single entry showing the mean + std. I would think it is enough. Also it is easy to link a chance level line with its PR curve because they meet when the recall is 1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is what the plot looks like with defaults, and plot chance set to True:
Code
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import PrecisionRecallDisplay
import matplotlib.pyplot as plt
Generate sample data
X, y = make_classification(n_samples=1000, n_classes=2, n_informative=5, random_state=42)
clf = RandomForestClassifier(random_state=42)
cv_results = cross_validate(
clf, X, y, cv=5,
return_estimator=True,
return_indices=True,
)
# Plot Precision-Recall curve using from_cv_results
disp = PrecisionRecallDisplay.from_cv_results(
cv_results, X, y, name="RandomForest", plot_chance_level=True
)
plt.show()
The alpha for chance line is 0.3. Prevalence seems to be pretty much the same for all cvs (which may not be unusual?) so they mostly over-lap.
@@ -135,6 +135,8 @@ def _validate_curve_kwargs( | |||
legend_metric, | |||
legend_metric_name, | |||
curve_kwargs, | |||
default_curve_kwargs=None, | |||
removed_version="1.9", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mostly added a default here because I wanted this parameter to be last. Happy to change
_validate_style_kwargs({"label": label}, curve_kwargs[fold_idx]) | ||
_validate_style_kwargs( | ||
{"label": label, **default_curve_kwargs_}, curve_kwargs[fold_idx] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Argh github lost my comment.
Previously, if you passed any curve_kwargs
, it would over-ride all default kwargs.
Now, if an individual curve_kwargs
that changes the same parameter as a default kwarg is passed, only that parameter will be over-ridden. All other default kwargs will still be used. E.g., if the user set color to red in curve_kwargs
, only the defualt color parameter will be over-ridden. The other parameters (e.g., "alpha": 0.5, "linestyle": "--") will still be used.
I initially wanted to implement this in RocCurveDisplay, but just went with 'over-ride all defaults' because it was easier.
I think it is more likely that if a user e.g., sets the curve color to be red, they still want the other default kwargs (i.e., they only want to change the color).
In particular, I changed this because it is necessary for precision recall as I think we always want the default "drawstyle": "steps-post" (to prevent interpolation), unless the user specifically changes it.
(if we decide we are happy with this change, I should probably add a whats new entry for RocCurveDisplay)
precision_folds, recall_folds, ap_folds, prevalence_pos_label_folds = ( | ||
[], | ||
[], | ||
[], | ||
[], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lint did this and I don't think it looks great but I have no suggestions on how to fix 🤷
|
def test_precision_recall_display_string_labels(pyplot): |
When y
is composed of string labels:
from_predictions
raises an error ifpos_label
is not explicitly passed (via_check_pos_label_consistency
). This makes sense, as we cannot guess whatpos_label
should be.from_estimator
does not raise an error because we default toestimator.classes_[1]
(_get_response_values_binary
does this).
I think it is reasonable for from_cv_results
to also default to estimator.classes_[-1]
(this is indeed what we have in the docstring, but it is NOT what are doing in main). This case is a bit more complicated than from_estimator
because we have the problem where it is possible that not every class is present in each split (see #29558) - thus we could end up with different pos_labels
. Still thinking through this, but I think I would be happy to check that if pos_label
is not explicitly passed, it has been inferred to be the same for every split. WDYT @glemaitre ?
Edit: Actually, I think all estimators would raise an error if there are less than 2 classes, so we can just leave it to the estimator.
# y_multi[y_multi == 1] = 2 | ||
# with pytest.raises(ValueError, match=r"y takes value in \{0, 2\}"): | ||
# display_class.from_cv_results(cv_results, X, y_multi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've realised this is a very weird, edge case that I was testing. The estimator in cv_results
has been fitted on a different y
than what we have passed (y_multi
). i.e., estimator.classes_
will have different classes than np.unique(y)
Then when we pass values to the metric, pos_label
is not present in y_true
. Interestingly, average_precision_score
checks this in:
scikit-learn/sklearn/metrics/_ranking.py
Lines 244 to 250 in 8792943
present_labels = np.unique(y_true).tolist() | |
if y_type == "binary": | |
if len(present_labels) == 2 and pos_label not in present_labels: | |
raise ValueError( | |
f"pos_label={pos_label} is not a valid label. It should be " | |
f"one of {present_labels}" |
but none of precision_recall_curve
, roc_curve
and auc
, check this.
I am not sure if this is something we should be checking, and if so should it be left to the metric functions (to also avoid duplication of checking)...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realise that both roc_curve
and precision_recall_curve
will give a warning if pos_label
is not in y_true
, which gets ignored in tests. I can simply update this test to check that the correct warning is raised.
Reference Issues/PRs
Follows on from #30399
What does this implement/fix? Explain your changes.
Proof of concept of adding multi displays to
PrecisionRecallDisplay
from_cv_results
inRocCurveDisplay
(singleRocCurveDisplay
) #30399, so we can definitely factorize out, though small intricacies may make it complexplot
method is complex due to handling both single and multi curve and doing a lot more checking, as user is able to use it outside of thefrom_estimator
andfrom_predictions
methods.Detailed discussions of problems in review comments.
Any other comments?
cc @glemaitre @jeremiedbb