[Validation] First implementation of @strict from huggingface_hub #36534

Open
wants to merge 7 commits into
base: main

Conversation

Member

@gante gante commented Mar 4, 2025

What does this PR do?

Testbed for huggingface/huggingface_hub#2895
⚠️ this PR can only be merged after the PR above is merged and part of a huggingface_hub release!

Core ideas:

  • All config arguments are individually validated at __init__ and assignment time, using strict type checks and custom value validators;
  • Class-level validation happens after __init__ through dataclass's __post_init__, or manually through .validate() (⚠️ when mutating the config after initialization, .validate() must be called manually to re-run class-level validation). With this, we can check e.g. if a special token is within the vocabulary range.

Minimal test script:

from transformers import AlbertConfig

# These should work
config = AlbertConfig()
config = AlbertConfig(eos_token_id=5)
config = AlbertConfig(foo="bar")  # Ensures backward compatibility

# Manual specification, traveling through an invalid config, should be allowed
config.eos_token_id = 99
config.vocab_size = 10  # the config is impossible here
config.eos_token_id = 9
config.validate()

# These will raise an exception
config = AlbertConfig(vocab_size=10.0)  # vocab_size is an int
config = AlbertConfig(num_hidden_layers=None)  # num_hidden_layers is an int
config = AlbertConfig(position_embedding_type="foo")  # position_embedding_type is a Literal, foo is not one of the options
config = AlbertConfig(vocab_size=10, eos_token_id=99)  # eos_token_id must be in the vocab
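
The two validation levels from the core ideas above can be sketched with the standard library alone. This is an illustrative stand-in, not the huggingface_hub implementation: `ToyConfig` and its checks are hypothetical, and the exact-type comparison is an assumption about how strict the real type check is.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    """Hypothetical config; a stand-in for illustration only."""

    vocab_size: int = 30000
    eos_token_id: int = 2

    def __setattr__(self, name, value):
        # Attribute-level validation: runs during __init__ and on every later assignment.
        expected = {f.name: f.type for f in fields(self)}.get(name)
        if expected is not None and type(value) is not expected:
            raise TypeError(f"{name} must be {expected.__name__}, got {type(value).__name__}.")
        super().__setattr__(name, value)

    def __post_init__(self):
        # Class-level validation: runs once, after all fields are set.
        self.validate()

    def validate(self):
        if not 0 <= self.eos_token_id < self.vocab_size:
            raise ValueError(
                f"eos_token_id must be between 0 and {self.vocab_size - 1}, got {self.eos_token_id}."
            )


config = ToyConfig()
config.eos_token_id = 99  # the type is fine, so the assignment succeeds
config.vocab_size = 10    # temporarily inconsistent: no class-level check on assignment
config.eos_token_id = 9
config.validate()         # consistent again, so this passes
```

Keeping the class-level check out of `__setattr__` is what allows a config to travel through an invalid intermediate state, at the cost of requiring an explicit `.validate()` call afterwards.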

Contributor

@Wauplin Wauplin left a comment

Thanks for giving it a try! I've left a few comments with some first thoughts

Comment on lines 1410 to 1411
if hasattr(config, "validate"):  # e.g. in @strict_dataclass
    config.validate()
Contributor

Not done yet but yes, could be a solution for "class-wide validation" in addition to "per-attribute validation"

Comment on lines 149 to 150
def __post_init__(self):
    self.validate()
Contributor

Typically something we should move to @strict_dataclass definition.
The validate method itself would have to be implemented by each class though

Comment on lines 28 to 55
def interval(min: Optional[int | float] = None, max: Optional[int | float] = None) -> Callable:
    """
    Parameterized validator that ensures that `value` is within the defined interval.
    Expected usage: `validated_field(interval(min=0), default=8)`
    """
    error_message = "Value must be"
    if min is not None:
        error_message += f" at least {min}"
    if min is not None and max is not None:
        error_message += " and"
    if max is not None:
        error_message += f" at most {max}"
    error_message += ", got {value}."

    min = min or float("-inf")
    max = max or float("inf")

    def _inner(value: int | float):
        if not min <= value <= max:
            raise ValueError(error_message.format(value=value))

    return _inner


def probability(value: float):
    """Ensures that `value` is a valid probability number, i.e. [0,1]."""
    if not 0 <= value <= 1:
        raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.")
Contributor

pydantic defines conint and confloat, which do similar things but in a more generic way. I'd be up to reuse the same naming/interface since that's what people are used to.

For instance, interval(min=0) is unclear compared to conint(gt=0) or conint(ge=0). A probability would be confloat(ge=0.0, le=1.0), etc.

Member Author

@gante gante Mar 10, 2025

I'm open to suggestions, but let me share my thought process :D

Given that we can write arbitrary validators, I would like to avoid that particular interface for two reasons:

  1. It heavily relies on non-obvious shortened names, which are a source of confusion -- and against our philosophy in transformers. My goal here was to write something a user without knowledge of strict_dataclass could read: for instance, interval with a min argument on an int should be immediately obvious, conint with gt requires some prior knowledge (con? gt?).
  2. Partial redundancy can result in better UX: a probability is the same as a float between 0.0 and 1.0, but we can write a more precise error message for the user. Take the case of dropout-related variables -- while it is technically a probability, it is (was?) a bad practice to set it to a value larger than 0.5, and we could be interested in throwing a warning in that case.
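
The second point can be sketched as follows. `dropout_probability` is a hypothetical name, and the 0.5 threshold is taken from the comment above rather than from any existing validator.

```python
import warnings


def probability(value: float) -> None:
    """Raises if `value` is outside the closed interval [0.0, 1.0]."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"Value must be a probability between 0.0 and 1.0 (inclusive), got {value}.")


def dropout_probability(value: float) -> None:
    """Hypothetical validator: a probability, plus a warning for unusually high dropout."""
    probability(value)
    if value > 0.5:
        warnings.warn(f"Dropout of {value} is unusually high; values above 0.5 are rarely intended.")
```

The redundant validator reuses the generic check but can attach domain-specific guidance that a plain interval check could not express.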

Contributor

Good points here! Ok to keep more explicit and redundant APIs then. I just want things to be precise when it comes to intervals (does "between 0.0 and 1.0" mean "0.0" and "1.0" included or not?)
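
One way to make the inclusive/exclusive question explicit while keeping readable names is sketched below. The `min_inclusive` and `max_inclusive` parameters are hypothetical and not part of the PR. The sketch also compares bounds against `None`, because `min = min or float("-inf")` in the quoted draft treats a bound of `0` as unset.

```python
from typing import Callable, Optional, Union

Number = Union[int, float]


def interval(
    min: Optional[Number] = None,
    max: Optional[Number] = None,
    *,
    min_inclusive: bool = True,
    max_inclusive: bool = True,
) -> Callable[[Number], None]:
    """Hypothetical variant of `interval` with explicit open/closed bounds."""
    # Render the bounds in interval notation, e.g. "[0, +inf)" or "(0, 1]".
    left = "(-inf" if min is None else f"{'[' if min_inclusive else '('}{min}"
    right = "+inf)" if max is None else f"{max}{']' if max_inclusive else ')'}"
    description = f"{left}, {right}"

    def _inner(value: Number) -> None:
        # Compare against None explicitly so that a bound of 0 is not treated as "unset".
        below = min is not None and (value < min if min_inclusive else value <= min)
        above = max is not None and (value > max if max_inclusive else value >= max)
        if below or above:
            raise ValueError(f"Value must be in {description}, got {value}.")

    return _inner
```

With this shape, `interval(min=0)` accepts `0` and rejects `-1`, while `interval(min=0, min_inclusive=False)` rejects `0`, and the error message states the bounds unambiguously.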

Collaborator

➕ for @gante's comment

Member Author

gante commented Mar 10, 2025

@Wauplin notwithstanding the syntax for validation in discussion above, the current commit is a nearly working version. If we pass standard arguments, everything happens as expected 🙌

However, I've hit one limitation, and I have one further request:

  1. limitation: our config classes accept arbitrary kwargs 😢 These are stored in the config as attributes. E.g. on main we can do this:
config = AlbertConfig(foo="bar")
print(config.foo)  # should print "bar"

I could keep the original __init__ method in AlbertConfig, which accepts **kwargs. However, that would add a lot of boilerplate to each config class, resulting in less readable code than without @strict_dataclass. Do you think it would be possible for strict_dataclass to overload the __init__ creating method in dataclass, so as to have an option to accept arbitrary kwargs? On each class, we could then do @strict_dataclass(accept_kwargs=True), storing the **kwargs on self.kwargs, and on __post_init__ we would handle them on our side.

  2. wishlist: can we pass dataclass kwargs into strict_dataclass? I would like to keep the original __repr__ here for now, i.e. @strict_dataclass(repr=False) 🤗

Contributor

Wauplin commented Mar 12, 2025

Do you think it would be possible for strict_dataclass to overload the __init__ creating method in dataclass, so as to have an option to accept arbitrary kwargs?

Definitely possible yes!

we could then do @strict_dataclass(accept_kwargs=True), storing the **kwargs on self.kwargs

So in the example above, you would have

config = AlbertConfig(foo="bar")
print(config.kwargs["foo"])  # should print "bar"

?
(just want to be sure we are aligned since it's not exactly the example you've mentioned. Both solutions are ok to me).

wishlist: can we pass dataclass kwargs into strict_dataclass? I would like to keep the original __repr__ here for now, i.e. @strict_dataclass(repr=False) 🤗

Feature request accepted 🤗

Member Author

gante commented Mar 17, 2025

@Wauplin

So in the example above, you would have
config = AlbertConfig(foo="bar")
print(config.kwargs["foo"]) # should print "bar"
?
(just want to be sure we are aligned since it's not exactly the example you've mentioned. Both solutions are ok to me).

To be fully BC, config.foo (!= config.kwargs["foo"]) would have to be valid in this case 😢

What I had in mind, to avoid specifying an __init__ in each class to ensure BC:

  • decorate with @strict_dataclass(accept_kwargs=True)

and either (option A)

  • strict_dataclass took care of mapping kwargs into a self.kwargs attribute
  • __post_init__ in each class took care of setting attributes from self.kwargs

or (option B)

  • strict_dataclass would set the extra kwargs as instance attributes without validation, e.g. with a loop over kwargs with setattr
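
Option B can be sketched with a plain decorator over a standard dataclass. `accept_extra_kwargs` and `ToyConfig` are hypothetical names used for illustration; this is not how strict_dataclass is implemented.

```python
import functools
from dataclasses import dataclass, fields


def accept_extra_kwargs(cls):
    """Hypothetical decorator: unknown kwargs become unvalidated instance attributes."""
    field_names = {f.name for f in fields(cls)}
    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        known = {k: v for k, v in kwargs.items() if k in field_names}
        extra = {k: v for k, v in kwargs.items() if k not in field_names}
        # Declared fields go through the generated dataclass __init__ as usual.
        original_init(self, *args, **known)
        # Extra kwargs are attached afterwards, without any validation.
        for name, value in extra.items():
            setattr(self, name, value)

    cls.__init__ = __init__
    return cls


@accept_extra_kwargs
@dataclass
class ToyConfig:
    vocab_size: int = 30000


config = ToyConfig(vocab_size=10, foo="bar")
print(config.foo)  # bar
```

This keeps `config.foo` working for backward compatibility without a hand-written `__init__` in each config class.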

Member Author

gante commented Apr 29, 2025

@Wauplin: Reviving this project. The current issue is that transformers needs to accept (unvalidated) kwargs for BC. WDYT of having the hub class as-is, and we in transformers expanding the hub class to meet our unorthodox needs?

hub would then hold a technically sound class for everyone to use 🤗

Member Author

gante commented Apr 29, 2025

Potentially related: vllm-project/vllm#14764 -> add shape checkers

Contributor

Wauplin commented Apr 30, 2025

@gante I reviewed and updated my previous PR. The main change is that instead of the previous @strict_dataclass you must use both @strict and @dataclass. Without this I couldn't have perfect type hinting + autocompletion in IDEs, which I think is super useful to have.

The current issue is that transformers needs to accept (unvalidated) kwargs for BC. WDYT of having the hub class as-is, and we in transformers expanding the hub class to meet our unorthodox needs?

This is now doable yes!

from dataclasses import dataclass

from huggingface_hub.dataclasses import strict

@strict(accept_kwargs=True)
@dataclass
class ConfigWithKwargs:
    model_type: str
    vocab_size: int = 16

config = ConfigWithKwargs(model_type="bert", vocab_size=30000, extra_field="extra_value")
print(config)  # ConfigWithKwargs(model_type='bert', vocab_size=30000, *extra_field='extra_value')

Check out documentation on https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses

@gante gante marked this pull request as ready for review May 2, 2025 11:30
Comment on lines +136 to +138
def __post_init__(self):
    """Called after `__init__`: validates the instance."""
    self.validate()
Member Author

@gante gante May 2, 2025

@Wauplin as you mentioned in a comment in a prior commit: this could be moved to strict (but it's a minor thing, happy to keep in transformers)

Contributor

let's have it in huggingface_hub directly!

@gante gante requested a review from ArthurZucker May 2, 2025 11:34
@gante gante changed the title from "[Validation] First implementation of @strict_dataclass from huggingface_hub" to "[Validation] First implementation of @strict from huggingface_hub" May 2, 2025
Comment on lines +1855 to +1857
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
    config.validate()
Member Author

validates the config at model init time

Collaborator

@ArthurZucker ArthurZucker left a comment

love it!
Would be nice to have:

    vocab_size: int = check(int, default=...)

something even more minimal

Comment on lines +149 to +156
# Token validation
for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]:
    token_id = getattr(self, token_name)
    if token_id is not None and not 0 <= token_id < self.vocab_size:
        raise ValueError(
            f"{token_name} must be in the vocabulary with size {self.vocab_size}, i.e. between 0 and "
            f"{self.vocab_size - 1}, got {token_id}."
        )
Collaborator

These can be moved to the general PreTrainedConfig's validate function for the common args

Member Author

@gante gante May 8, 2025

Agreed.

And they would have to be lifted to warnings, I'm sure there are config files out there with negative tokens (this was a common trick in the past to e.g. manipulate generation into not stopping on EOS)
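
A warning-based version of the token check could look like the sketch below. `check_special_tokens` is a hypothetical helper, not a function from transformers; the message mirrors the one in the quoted diff.

```python
import warnings
from typing import Optional


def check_special_tokens(vocab_size: int, **token_ids: Optional[int]) -> None:
    """Hypothetical helper: warns (rather than raises) on out-of-vocabulary special tokens."""
    for token_name, token_id in token_ids.items():
        if token_id is not None and not 0 <= token_id < vocab_size:
            warnings.warn(
                f"{token_name} should be in the vocabulary with size {vocab_size}, i.e. between 0 and "
                f"{vocab_size - 1}, got {token_id}."
            )


# Only eos_token_id triggers a warning here; existing configs with negative tokens still load.
check_special_tokens(30000, pad_token_id=0, bos_token_id=2, eos_token_id=-1)
```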

@@ -103,51 +109,51 @@ class AlbertConfig(PretrainedConfig):
>>> configuration = model.config
```"""

vocab_size: int = validated_field(interval(min=1), default=30000)
Collaborator

is there a way to just write validate or check instead of check_field?

Member Author

(cc @Wauplin)

num_attention_heads: int = validated_field(interval(min=0), default=64)
intermediate_size: int = validated_field(interval(min=1), default=16384)
inner_group_num: int = validated_field(interval(min=0), default=1)
hidden_act: str = validated_field(activation_fn_key, default="gelu_new")
Collaborator

if activation_fn_key is a key, I would rather use ActivationFnKey as a class

Suggested change
- hidden_act: str = validated_field(activation_fn_key, default="gelu_new")
+ hidden_act: str = check_activation_function(default="gelu_new")

Member Author

@gante gante May 8, 2025

In Python 3.11, we can define something like ActivationFnKey = Literal[possible values], and then simply have

hidden_act: ActivationFnKey = "gelu_new"

Until then, I think it's best to keep the same pattern as in other validated fields, i.e. validated_field(validation_fn, default). WDYT?
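
As a rough sketch of how the Literal idea and the validator pattern relate, the options can be declared once and checked with `typing.get_args`. The list of names below is an illustrative subset chosen for the example, and `activation_fn_key` here is a hypothetical re-implementation, not the function from the PR.

```python
from typing import Literal, get_args

# Illustrative subset only; the real list of activation names lives in transformers.
ActivationFnKey = Literal["gelu", "gelu_new", "relu", "silu"]


def activation_fn_key(value: str) -> None:
    """Hypothetical validator: checks `value` against the options declared in the Literal."""
    options = get_args(ActivationFnKey)
    if value not in options:
        raise ValueError(f"Value must be one of {options}, got {value!r}.")
```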

Comment on lines +113 to +119
embedding_size: int = validated_field(interval(min=1), default=128)
hidden_size: int = validated_field(interval(min=1), default=4096)
num_hidden_layers: int = validated_field(interval(min=1), default=12)
num_hidden_groups: int = validated_field(interval(min=1), default=1)
num_attention_heads: int = validated_field(interval(min=0), default=64)
intermediate_size: int = validated_field(interval(min=1), default=16384)
inner_group_num: int = validated_field(interval(min=0), default=1)
Collaborator

IMO the intervals are too much, we only need the default here!

Member Author

Would you prefer something like

embedding_size: int = validated_field(strictly_non_negative, default=128)

or simply

embedding_size: int = 128

?

IMO, since most people don't read the config classes, I think validation is a nice sanity check 🤗

3 participants