-
-
Notifications
You must be signed in to change notification settings - Fork 26.4k
FIX default metadata requests when set via class attributes #28435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Not sure if it's linked or not, but I observed another kind of strange behaviour with class inheritance. Expected result: Actual result: |
class Base_1(BaseEstimator):
__metadata_request__split = {"groups": metadata_routing.UNUSED}here you're trying to remove |
Makes sense 👍 |
thomasjpfan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tend to dislike walking through the MRO, but in this case the alternative is not great either:
diff --git a/sklearn/utils/_metadata_requests.py b/sklearn/utils/_metadata_requests.py
index a914addc7f..7cd0b12a0e 100644
--- a/sklearn/utils/_metadata_requests.py
+++ b/sklearn/utils/_metadata_requests.py
@@ -1403,6 +1403,24 @@ def _build_request_for_signature(cls, router, method):
)
return mmr
+ @classmethod
+ def _configure_metadata_request(cls, requests):
+ from contextlib import suppress
+
+ with suppress(AttributeError):
+ requests = super()._configure_metadata_request(requests)
+
+ substr = "__metadata_request__"
+ for attr, value in vars(cls).items():
+ if substr not in attr:
+ continue
+
+ method = attr[attr.index(substr) + len(substr) :]
+ for prop, alias in value.items():
+ getattr(requests, method).add_request(param=prop, alias=alias)
+
+ return requests
+
@classmethod
def _get_default_requests(cls):
"""Collect default request values.
@@ -1424,28 +1442,7 @@ class attributes, as well as determining request keys from method
# __metadata_request__* attributes. Defaults set in
# __metadata_request__* attributes take precedence over signature
# sniffing.
-
- # need to go through the MRO since this is a class attribute and
- # ``vars`` doesn't report the parent class attributes. We go through
- # the reverse of the MRO so that child classes have precedence over
- # their parents.
- substr = "__metadata_request__"
- for base_class in reversed(inspect.getmro(cls)):
- for attr, value in vars(base_class).items():
- if substr not in attr:
- continue
- # we don't check for attr.startswith() since python prefixes attrs
- # starting with __ with the `_ClassName`.
- method = attr[attr.index(substr) + len(substr) :]
- for prop, alias in value.items():
- # here we "add" request values specified via those class attributes
- # to the `MetadataRequest` object. Adding a request which already
- # exists will override the previous one. Since we go through the
- # MRO in reverse order, the one specified by the lowest most classes
- # in the inheritance tree are the ones which take effect.
- getattr(requests, method).add_request(param=prop, alias=alias)
-
- return requests
+ return cls._configure_metadata_request(requests)
def _get_metadata_request(self):
"""Get requested data properties.I am okay with the current approach in this PR.
glemaitre
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Co-authored-by: Thomas J. Fan <[email protected]>
|
Enable auto-merge. |
Fixes #28430
Closes #28431
This fixes an issue where previously we were going through the MRO, sorting values, then writing and overwriting defaults. Now we create the request values as we go through MRO, which avoids previous issues.
cc @YanisLalou, @thomasjpfan @glemaitre @OmarManzoor