-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[Validation] First implementation of @strict
from huggingface_hub
#36534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for giving it a try! I've left a few comments with some first thoughts
src/transformers/modeling_utils.py
Outdated
if hasattr(config, "validate"): # e.g. in @strict_dataclass | ||
config.validate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not done yet but yes, could be a solution for "class-wide validation" in addition to "per-attribute validation"
def __post_init__(self): | ||
self.validate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typically something we should move to @strict_dataclass
definition.
The validate
method itself would have to be implemented by each class though
src/transformers/validators.py
Outdated
def interval(min: Optional[int | float] = None, max: Optional[int | float] = None) -> Callable: | ||
""" | ||
Parameterized validator that ensures that `value` is within the defined interval. | ||
Expected usage: `validated_field(interval(min=0), default=8)` | ||
""" | ||
error_message = "Value must be" | ||
if min is not None: | ||
error_message += f" at least {min}" | ||
if min is not None and max is not None: | ||
error_message += " and" | ||
if max is not None: | ||
error_message += f" at most {max}" | ||
error_message += ", got {value}." | ||
|
||
min = min or float("-inf") | ||
max = max or float("inf") | ||
|
||
def _inner(value: int | float): | ||
if not min <= value <= max: | ||
raise ValueError(error_message.format(value=value)) | ||
|
||
return _inner | ||
|
||
|
||
def probability(value: float): | ||
"""Ensures that `value` is a valid probability number, i.e. [0,1].""" | ||
if not 0 <= value <= 1: | ||
raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pydantic defines conint
and confloat
which does similar things but in a more generic way. I'd be up to reuse the same naming/interface since that's what people are used to.
- https://docs.pydantic.dev/1.10/usage/types/#arguments-to-conint
- https://docs.pydantic.dev/1.10/usage/types/#arguments-to-confloat
For instance, interval(min=0)
is unclear compared to conint(gt=0)
or conint(ge=0)
. A probability would be confloat(ge=0.0, le=0.0)
, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm open to suggestions, but let me share my thought process :D
Given that we can write arbitrary validators, I would like to avoid that particular interface for two reasons:
- It heavily relies on non-obvious shortened names, which are a source of confusion -- and against our philosophy in transformers. My goal here was to write something a user without knowledge of
strict_dataclass
could read: for instance,interval
with amin
argument on anint
should be immediately obvious,conint
withgt
requires some prior knowledge (con
?gt
?). - Partial redundancy can result in better UX: a probability is the same as a
float
between0.0
and1.0
, but we can write a more precise error message for the user. Take the case ofdropout
-related variables -- while it is technically a probability, it is (was?) a bad practice to set it to a value larger than 0.5, and we could be interested in throwing a warning in that case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good points here! Ok to keep more explicit and redundant APIS then. I just want things to be precise when it comes to intervals (does "between 0.0 and 1.0" means "0.0" and "1.0" included or not?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
➕ for @gante 's comment
@Wauplin notwithstanding the syntax for validation in discussion above, the current commit is a nearly working version. If we pass standard arguments, everything happens as expected 🙌 However, I've hit one limitation, and I have one further request:
config = AlbertConfig(foo="bar")
print(config.foo) # should print "bar" I could keep the original
|
Definitely possible yes!
So in the example above, you would have config = AlbertConfig(foo="bar")
print(config.kwargs["foo"]) # should print "bar" ?
Feature request accepted 🤗 |
To be fully BC, What I had in mind, to avoid specifying an
and either (option A)
or (option B)
|
@Wauplin: Reviving this project. The current issue is that hub would then hold a technically sound class for everyone to use 🤗 |
Potentially related: vllm-project/vllm#14764 -> add shape checkers |
@gante I reviewed and updated my previous PR. The main change is that instead of the previous
This is now doable yes! @strict(accept_kwargs=True)
@dataclass
class ConfigWithKwargs:
model_type: str
vocab_size: int = 16
config = ConfigWithKwargs(model_type="bert", vocab_size=30000, extra_field="extra_value")
print(config) # ConfigWithKwargs(model_type='bert', vocab_size=30000, *extra_field='extra_value') Check out documentation on https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses |
def __post_init__(self): | ||
"""Called after `__init__`: validates the instance.""" | ||
self.validate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Wauplin as you mentioned in a comment in a prior commit: this could be moved to strict
(but it's a minor thing, happy to keep in transformers)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's have it in huggingface_hub
directly!
@strict_dataclass
from huggingface_hub
@strict
from huggingface_hub
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`) | ||
if hasattr(config, "validate"): | ||
config.validate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
validates the config at model init time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
love it!
WOuld be nice to have:
vocab_size: int = check(int, default=...)
something even more minimal
# Token validation | ||
for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: | ||
token_id = getattr(self, token_name) | ||
if token_id is not None and not 0 <= token_id < self.vocab_size: | ||
raise ValueError( | ||
f"{token_name} must be in the vocabulary with size {self.vocab_size}, i.e. between 0 and " | ||
f"{self.vocab_size - 1}, got {token_id}." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these can be move the general PreTrainedConfig
's validate function for the common args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed.
And they would have to be lifted to warnings, I'm sure there are config files out there with negative tokens (this was a common trick in the past to e.g. manipulate generation into not stopping on EOS)
@@ -103,51 +109,51 @@ class AlbertConfig(PretrainedConfig): | |||
>>> configuration = model.config | |||
```""" | |||
|
|||
vocab_size: int = validated_field(interval(min=1), default=30000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a way to juste write validate
or check
instead of check_field
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(cc @Wauplin)
num_attention_heads: int = validated_field(interval(min=0), default=64) | ||
intermediate_size: int = validated_field(interval(min=1), default=16384) | ||
inner_group_num: int = validated_field(interval(min=0), default=1) | ||
hidden_act: str = validated_field(activation_fn_key, default="gelu_new") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if activation_fn_key
is a key, I would rather use ActivationFnKey
as a class
hidden_act: str = validated_field(activation_fn_key, default="gelu_new") | |
hidden_act: str = check_activation_function(default="gelu_new") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In python 3.11, we can define something like ActivationFnKey = Literal[possible values]
, and then simply have
hidden_act: ActivationFnKey = "gelu_new"
Until then, I think it's best to keep the same pattern as in other validated fields, i.e. validated_field(validation_fn, default)
. WDYT?
embedding_size: int = validated_field(interval(min=1), default=128) | ||
hidden_size: int = validated_field(interval(min=1), default=4096) | ||
num_hidden_layers: int = validated_field(interval(min=1), default=12) | ||
num_hidden_groups: int = validated_field(interval(min=1), default=1) | ||
num_attention_heads: int = validated_field(interval(min=0), default=64) | ||
intermediate_size: int = validated_field(interval(min=1), default=16384) | ||
inner_group_num: int = validated_field(interval(min=0), default=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO the interval are too much, we only need the default here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you prefer something like
embedding_size: int = validated_field(strictly_non_negative, default=128)
or simply
embedding_size: int = 128
?
IMO, since most people don't read the config classes, I think validation is a nice sanity check 🤗
What does this PR do?
Testbed for huggingface/huggingface_hub#2895
⚠️ this PR can only be merged after the PR above is merged and part of a
huggingface_hub
release!Core ideas:
__init__
and assignment time, using strict type checks and custom value validators;__init__
throughdataclass
's__post_init__
, or manually through.validate()
(.validate()
must be called to validate at a class-level). With this, we can check e.g. if a special token is within the vocabulary range.Minimal test script: