[Validation] First implementation of @strict from huggingface_hub #36534


Open: gante wants to merge 18 commits into main from validate_config

Conversation

@gante gante (Member) commented Mar 4, 2025

What does this PR do?

Testbed for huggingface/huggingface_hub#2895, released recently.

Core ideas:

  • All config arguments are individually validated at __init__ and assignment time, using strict type checks and custom value validators;
  • Class-level validation happens after __init__ through dataclass's __post_init__, or manually through .validate() (⚠️ when mutating the config after initialization, .validate() must be called again to re-run class-level validation). With this, we can check e.g. whether a special token is within the vocabulary range;
  • Arbitrary validators can be added as validate_xxx methods: @strict adds them to its set of validation functions (see the sketch below).
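For illustration, a hypothetical validate_xxx hook might look like this (a sketch only; the method name and check mirror the vocabulary-range example above, and the validate_ prefix is what @strict keys on):

def validate_special_tokens(self):
    # Hypothetical class-level validator, picked up by @strict because of the `validate_` prefix
    if self.eos_token_id is not None and not 0 <= self.eos_token_id < self.vocab_size:
        raise ValueError(f"eos_token_id ({self.eos_token_id}) is outside the vocabulary range.")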

Minimal test script:

from transformers import AlbertConfig

# These should work
config = AlbertConfig()
config = AlbertConfig(eos_token_id=5)
config = AlbertConfig(foo="bar")  # Ensures backward compatibility

# Manual mutation that passes through an invalid intermediate state should be allowed
config.hidden_size = 65  # the config is invalid at this point (65 is not divisible by num_attention_heads)
config.num_attention_heads = 5
config.validate()  # class-level validation is re-run here

# These will raise an exception
config = AlbertConfig(vocab_size=10.0)  # vocab_size is an int
config = AlbertConfig(num_hidden_layers=None)  # num_hidden_layers is an int
config = AlbertConfig(position_embedding_type="foo")  # position_embedding_type is a Literal, foo is not one of the options
config = AlbertConfig(hidden_size=65)  # hidden_size % num_attention_heads must be 0

@Wauplin Wauplin (Contributor) left a comment


Thanks for giving it a try! I've left a few comments with some first thoughts

Comment on lines 1410 to 2030
if hasattr(config, "validate"):  # e.g. in @strict_dataclass
    config.validate()
Contributor

Not done yet but yes, could be a solution for "class-wide validation" in addition to "per-attribute validation"

Comment on lines 149 to 150
def __post_init__(self):
    self.validate()
Contributor

Typically something we should move to the @strict_dataclass definition.
The validate method itself would have to be implemented by each class, though.

Comment on lines 28 to 82
def interval(min: Optional[int | float] = None, max: Optional[int | float] = None) -> Callable:
    """
    Parameterized validator that ensures that `value` is within the defined interval.
    Expected usage: `validated_field(interval(min=0), default=8)`
    """
    error_message = "Value must be"
    if min is not None:
        error_message += f" at least {min}"
    if min is not None and max is not None:
        error_message += " and"
    if max is not None:
        error_message += f" at most {max}"
    error_message += ", got {value}."

    # `is None` checks (rather than `min or ...`) so that a bound of 0 is preserved
    min = float("-inf") if min is None else min
    max = float("inf") if max is None else max

    def _inner(value: int | float):
        if not min <= value <= max:
            raise ValueError(error_message.format(value=value))

    return _inner


def probability(value: float):
    """Ensures that `value` is a valid probability number, i.e. [0,1]."""
    if not 0 <= value <= 1:
        raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.")
Contributor

pydantic defines conint and confloat, which do similar things but in a more generic way. I'd be up to reuse the same naming/interface since that's what people are used to.

For instance, interval(min=0) is unclear compared to conint(gt=0) or conint(ge=0). A probability would be confloat(ge=0.0, le=1.0), etc.
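For comparison, a minimal sketch of the pydantic spelling being referenced (ExampleConfig is illustrative, not part of this PR):

from pydantic import BaseModel, confloat, conint

class ExampleConfig(BaseModel):
    vocab_size: conint(ge=1) = 30000
    hidden_dropout_prob: confloat(ge=0.0, le=1.0) = 0.0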

@gante gante (Member Author) Mar 10, 2025

I'm open to suggestions, but let me share my thought process :D

Given that we can write arbitrary validators, I would like to avoid that particular interface for two reasons:

  1. It heavily relies on non-obvious shortened names, which are a source of confusion -- and against our philosophy in transformers. My goal here was to write something a user without knowledge of strict_dataclass could read: for instance, interval with a min argument on an int should be immediately obvious, whereas conint with gt requires some prior knowledge (con? gt?).
  2. Partial redundancy can result in better UX: a probability is the same as a float between 0.0 and 1.0, but we can write a more precise error message for the user. Take the case of dropout-related variables -- while it is technically a probability, it is (was?) bad practice to set it to a value larger than 0.5, and we could be interested in throwing a warning in that case (a sketch follows below).
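A sketch of what such a redundant-but-more-informative validator could look like (hypothetical name and threshold, following the dropout example in point 2):

import warnings

def dropout_probability(value: float):
    """Accepts [0.0, 1.0] like `probability`, but warns on suspiciously high values."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"Dropout must be a probability in [0.0, 1.0], got {value}.")
    if value > 0.5:
        warnings.warn(f"A dropout of {value} is unusually high and is rarely intended.")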

Contributor

Good points here! OK to keep more explicit and redundant APIs then. I just want things to be precise when it comes to intervals (does "between 0.0 and 1.0" mean 0.0 and 1.0 included or not?)

Collaborator

➕ for @gante's comment

@gante gante (Member Author) commented Mar 10, 2025

@Wauplin notwithstanding the syntax for validation in discussion above, the current commit is a nearly working version. If we pass standard arguments, everything happens as expected 🙌

However, I've hit one limitation, and I have one further request:

  1. limitation: our config classes accept arbitrary kwargs 😢 These are stored in the config as attributes. E.g. on main we can do this:
config = AlbertConfig(foo="bar")
print(config.foo)  # should print "bar"

I could keep the original __init__ method in AlbertConfig, which accepts **kwargs. However, that would add a lot of boilerplate to each config class, resulting in less readable code than without @strict_dataclass. Do you think it would be possible for strict_dataclass to override the __init__-generating machinery of dataclass, so as to have an option to accept arbitrary kwargs? On each class, we could then do @strict_dataclass(accept_kwargs=True), storing the **kwargs on self.kwargs, and on __post_init__ we would handle them on our side.

  2. wishlist: can we pass dataclass kwargs into strict_dataclass? I would like to keep the original __repr__ here for now, i.e. @strict_dataclass(repr=False) 🤗

@Wauplin Wauplin (Contributor) commented Mar 12, 2025

Do you think it would be possible for strict_dataclass to override the __init__-generating machinery of dataclass, so as to have an option to accept arbitrary kwargs?

Definitely possible yes!

we could then do @strict_dataclass(accept_kwargs=True), storing the **kwargs on self.kwargs

So in the example above, you would have

config = AlbertConfig(foo="bar")
print(config.kwargs["foo"])  # should print "bar"

?
(just want to be sure we are aligned since it's not exactly the example you've mentioned. Both solutions are ok to me).

wishlist: can we pass dataclass kwargs into strict_dataclass? I would like to keep the original __repr__ here for now, i.e. @strict_dataclass(repr=False) 🤗

Feature request accepted 🤗

@gante gante (Member Author) commented Mar 17, 2025

@Wauplin

So in the example above, you would have
config = AlbertConfig(foo="bar")
print(config.kwargs["foo"]) # should print "bar"
?
(just want to be sure we are aligned since it's not exactly the example you've mentioned. Both solutions are ok to me).

To be fully BC, config.foo (!= config.kwargs["foo"]) would have to be valid in this case 😢

What I had in mind, to avoid specifying an __init__ in each class to ensure BC:

  • decorate with @strict_dataclass(accept_kwargs=True)

and either (option A)

  • strict_dataclass took care of mapping kwargs into a self.kwargs attribute
  • __post_init__ in each class took care of setting attributes from self.kwargs

or (option B)

  • strict_dataclass would set the extra kwargs as instance attributes without validation, e.g. with a loop over kwargs with setattr (see the sketch below)
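A minimal sketch of option B (hypothetical helper, not huggingface_hub code), where unknown kwargs become plain instance attributes and bypass validation:

def _set_extra_kwargs(instance, extra_kwargs):
    for name, value in extra_kwargs.items():
        # object.__setattr__ bypasses any validating __setattr__ installed by the decorator
        object.__setattr__(instance, name, value)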

@gante gante (Member Author) commented Apr 29, 2025

@Wauplin: Reviving this project. The current issue is that transformers needs to accept (unvalidated) kwargs for BC. WDYT of having the hub class as-is, with transformers extending it to meet our unorthodox needs?

hub would then hold a technically sound class for everyone to use 🤗

@gante gante (Member Author) commented Apr 29, 2025

Potentially related: vllm-project/vllm#14764 -> add shape checkers

@Wauplin Wauplin (Contributor) commented Apr 30, 2025

@gante I reviewed and updated my previous PR. The main change is that instead of the previous @strict_dataclass, you must use both @strict and @dataclass. Without this, I couldn't get perfect type hinting + autocompletion in IDEs, which I think is super useful to have.

The current issue is that transformers needs to accept (unvalidated) kwargs for BC. WDYT of having the hub class as-is, with transformers extending it to meet our unorthodox needs?

This is now doable yes!

@strict(accept_kwargs=True)
@dataclass
class ConfigWithKwargs:
    model_type: str
    vocab_size: int = 16

config = ConfigWithKwargs(model_type="bert", vocab_size=30000, extra_field="extra_value")
print(config)  # ConfigWithKwargs(model_type='bert', vocab_size=30000, *extra_field='extra_value')

Check out documentation on https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses

@gante gante marked this pull request as ready for review May 2, 2025 11:30
Comment on lines 136 to 138
def __post_init__(self):
    """Called after `__init__`: validates the instance."""
    self.validate()
@gante gante (Member Author) May 2, 2025

@Wauplin as you mentioned in a comment on a prior commit: this could be moved to strict (but it's a minor thing, happy to keep it in transformers)

Contributor

let's have it in huggingface_hub directly!

@gante gante requested a review from ArthurZucker May 2, 2025 11:34
@gante gante changed the title [Validation] First implementation of @strict_dataclass from huggingface_hub [Validation] First implementation of @strict from huggingface_hub May 2, 2025
Comment on lines +1855 to +2030
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
    config.validate()
Member Author

validates the config at model init time

@ArthurZucker ArthurZucker (Collaborator) left a comment

love it!
Would be nice to have:

    vocab_size: int = check(int, default=...)

something even more minimal
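A sketch of how such a check helper could be built on top of the PR's validated_field (the name check is the proposal above; this implementation is hypothetical):

def check(expected_type, default=None):
    def _type_check(value):
        if not isinstance(value, expected_type):
            raise TypeError(f"Expected {expected_type.__name__}, got {type(value).__name__}.")
    return validated_field(_type_check, default=default)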

Comment on lines 149 to 156
# Token validation
for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]:
    token_id = getattr(self, token_name)
    if token_id is not None and not 0 <= token_id < self.vocab_size:
        raise ValueError(
            f"{token_name} must be in the vocabulary with size {self.vocab_size}, i.e. between 0 and "
            f"{self.vocab_size - 1}, got {token_id}."
        )
Collaborator

these can be moved to the general PretrainedConfig's validate function for the common args

@gante gante (Member Author) May 8, 2025

Agreed.

And they would have to be lifted to warnings; I'm sure there are config files out there with negative tokens (this was a common trick in the past, e.g. to manipulate generation into not stopping on EOS). A sketch of the softened check follows below.
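The same loop as above, with the exception demoted to a warning (a hypothetical sketch, not code from the PR):

import warnings

for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]:
    token_id = getattr(self, token_name)
    if token_id is not None and not 0 <= token_id < self.vocab_size:
        warnings.warn(
            f"{token_name} is {token_id}, outside the vocabulary range [0, {self.vocab_size - 1}]. "
            "This may be intentional (e.g. a negative id to disable a token), but double-check your config."
        )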

@@ -103,51 +109,51 @@ class AlbertConfig(PretrainedConfig):
>>> configuration = model.config
```"""

vocab_size: int = validated_field(interval(min=1), default=30000)
Collaborator

is there a way to just write validate or check instead of validated_field?

Member Author

(cc @Wauplin)

num_attention_heads: int = validated_field(interval(min=0), default=64)
intermediate_size: int = validated_field(interval(min=1), default=16384)
inner_group_num: int = validated_field(interval(min=0), default=1)
hidden_act: str = validated_field(activation_fn_key, default="gelu_new")
Collaborator

if activation_fn_key is a key, I would rather use ActivationFnKey as a class

Suggested change:
- hidden_act: str = validated_field(activation_fn_key, default="gelu_new")
+ hidden_act: str = check_activation_function(default="gelu_new")

@gante gante (Member Author) May 8, 2025

In Python 3.11, we can define something like ActivationFnKey = Literal[possible values], and then simply have

hidden_act: ActivationFnKey = "gelu_new"

Until then, I think it's best to keep the same pattern as in other validated fields, i.e. validated_field(validation_fn, default). WDYT?
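For reference, a sketch of that alias (the listed values are illustrative, not the full set of activation keys in transformers):

from typing import Literal

ActivationFnKey = Literal["gelu", "gelu_new", "relu", "silu"]

# and, inside the config class:
# hidden_act: ActivationFnKey = "gelu_new"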

Comment on lines 113 to 119
embedding_size: int = validated_field(interval(min=1), default=128)
hidden_size: int = validated_field(interval(min=1), default=4096)
num_hidden_layers: int = validated_field(interval(min=1), default=12)
num_hidden_groups: int = validated_field(interval(min=1), default=1)
num_attention_heads: int = validated_field(interval(min=0), default=64)
intermediate_size: int = validated_field(interval(min=1), default=16384)
inner_group_num: int = validated_field(interval(min=0), default=1)
Collaborator

IMO the intervals are too much, we only need the defaults here!

Member Author

Would you prefer something like

embedding_size: int = validated_field(strictly_non_negative, default=128)

or simply

embedding_size: int = 128

?

IMO, since most people don't read the config classes, I think validation is a nice sanity check 🤗

Contributor

@ArthurZucker @gante I've pushed an update that should greatly reduce the verbosity of validated fields. You can use the @as_validated_field decorator when defining your validators, which avoids calling validated_field on each field. Here is an example:

from dataclasses import dataclass
from typing import Optional

from huggingface_hub.dataclasses import as_validated_field, strict, validated_field


# Case 1: decorated validator with no arguments


@as_validated_field
def probability(value: float):
    if not 0 <= value <= 1:
        raise ValueError(f"Value must be in interval [0, 1], got {value}")


# Case 2: decorated validator with arguments


def interval(min: Optional[int] = None, max: Optional[int] = None):
    @as_validated_field
    def _inner(value: int) -> None:
        if min is not None and value < min:
            raise ValueError(f"Value must be at least {min}, got {value}")
        if max is not None and value > max:
            raise ValueError(f"Value must be at most {max}, got {value}")

    return _inner


# Case 3: multiple validators


def positive(value: int) -> None:
    if not value >= 0:
        raise ValueError(f"Value must be positive, got {value}")


def multiple_of_2(value: int) -> None:
    if value % 2 != 0:
        raise ValueError(f"Value must be a multiple of 2, got {value}")


@strict
@dataclass
class Config:
    # No custom validation (only type checking)
    model_type: str

    # Validator defined using the decorator
    hidden_dropout_prob: float = probability(default=0.0)

    # Validator with args defined using the decorator
    vocab_size: int = interval(min=10)(default=16)

    # Type checking + 2 validators (more verbose but more explicit)
    hidden_size: int = validated_field([positive, multiple_of_2], default=32)

Contributor

With this syntax and correctly defined validators, the AlbertConfig would become:

@strict(accept_kwargs=True)
@dataclass
class AlbertConfig(PretrainedConfig):
    vocab_size: int = interval(min=1)(default=30000)
    embedding_size: int = interval(min=1)(default=128)
    hidden_size: int = interval(min=1)(default=4096)
    num_hidden_layers: int = interval(min=1)(default=12)
    num_hidden_groups: int = interval(min=1)(default=1)
    num_attention_heads: int = interval(min=0)(default=64)
    intermediate_size: int = interval(min=1)(default=16384)
    inner_group_num: int = interval(min=0)(default=1)
    hidden_act: str = activation_fn_key(default="gelu_new")
    hidden_dropout_prob: float = probability(default=0.0)
    attention_probs_dropout_prob: float = probability(default=0.0)
    max_position_embeddings: int = interval(min=0)(default=512)
    type_vocab_size: int = interval(min=1)(default=2)
    initializer_range: float = interval(min=0.0)(default=0.02)
    layer_norm_eps: float = interval(min=0.0)(default=1e-12)
    classifier_dropout_prob: float = probability(default=0.1)
    position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute"
    pad_token_id: Optional[int] = token(default=0)
    bos_token_id: Optional[int] = token(default=2)
    eos_token_id: Optional[int] = token(default=3)

which can hardly be less verbose IMO

Member Author

🔥

@gante gante (Member Author) commented May 14, 2025

@Wauplin PR updated with the latest syntax 🙌 Moving __post_init__ to the hub-side would be a nice addition :)

@ArthurZucker syntax is much shorter now 🤗

@gante gante requested a review from ArthurZucker May 14, 2025 14:17
@Wauplin Wauplin (Contributor) commented May 14, 2025

PR updated with the latest syntax 🙌

🎉

Moving __post_init__ to the hub-side would be a nice addition :)

huggingface/huggingface_hub@27ca2d7 🫶 with docs in https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses#class-validators

@Wauplin Wauplin (Contributor) commented May 14, 2025

With class validators, you can now do this:

@strict(accept_kwargs=True)
@dataclass
class AlbertConfig(PretrainedConfig):
    vocab_size: int = interval(min=1)(default=30000)
    embedding_size: int = interval(min=1)(default=128)
    hidden_size: int = interval(min=1)(default=4096)
    num_hidden_layers: int = interval(min=1)(default=12)
    num_hidden_groups: int = interval(min=1)(default=1)
    num_attention_heads: int = interval(min=0)(default=64)
    intermediate_size: int = interval(min=1)(default=16384)
    inner_group_num: int = interval(min=0)(default=1)
    hidden_act: str = activation_fn_key(default="gelu_new")
    hidden_dropout_prob: float = probability(default=0.0)
    attention_probs_dropout_prob: float = probability(default=0.0)
    max_position_embeddings: int = interval(min=0)(default=512)
    type_vocab_size: int = interval(min=1)(default=2)
    initializer_range: float = interval(min=0.0)(default=0.02)
    layer_norm_eps: float = interval(min=0.0)(default=1e-12)
    classifier_dropout_prob: float = probability(default=0.1)
    position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute"
    pad_token_id: Optional[int] = token(default=0)
    bos_token_id: Optional[int] = token(default=2)
    eos_token_id: Optional[int] = token(default=3)

    # Not part of __init__
    model_type = "albert"

    def validate_architecture(self):
        """Validates the architecture of the model."""
        # Check if the number of attention heads is a divisor of the hidden size
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) must be divisible by the number of attention "
                f"heads ({self.num_attention_heads})."
            )

=> validate_architecture is automatically considered a class validator. Class validators are executed once after __post_init__ and can be re-executed at any time by calling config.validate() explicitly. Note that .validate() is not run after each field assignment, as that could block some use cases (e.g. if two related fields have to be updated at once); see the example below.
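For instance, reusing the earlier test script (the final state is valid, since 65 % 5 == 0):

config = AlbertConfig()
config.hidden_size = 65         # invalid intermediate state (65 % 64 != 0), no error raised yet
config.num_attention_heads = 5  # now 65 % 5 == 0
config.validate()               # class validators, including validate_architecture, re-run here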

Documentation: https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses#class-validators

@Wauplin Wauplin (Contributor) left a comment

(tiny review)

Comment on lines +1855 to +2031
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
    config.validate()

Contributor

Suggested change:
- # class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
- if hasattr(config, "validate"):
-     config.validate()

not needed anymore

Member Author

It's possible to load a config, modify it to an invalid state, and then use that config to instantiate a model :( As such, the model should ensure the config is valid before committing resources.

@gante gante force-pushed the validate_config branch from 493cacb to bf3944b on May 29, 2025 16:21
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante force-pushed the validate_config branch from cf0e3ea to 7152997 on May 30, 2025 13:09
@ArthurZucker ArthurZucker (Collaborator) left a comment

In favor of removing interval when it's obvious!

Comment on lines +112 to +120
vocab_size: int = interval(min=1)(default=30000)
embedding_size: int = interval(min=1)(default=128)
hidden_size: int = interval(min=1)(default=4096)
num_hidden_layers: int = interval(min=1)(default=12)
num_hidden_groups: int = interval(min=1)(default=1)
num_attention_heads: int = interval(min=0)(default=64)
intermediate_size: int = interval(min=1)(default=16384)
inner_group_num: int = interval(min=0)(default=1)
hidden_act: str = activation_fn_key(default="gelu_new")
Collaborator

BTW I think we should only use default here. Let's be less verbose when not needed! Not even sure we need to specify this being an interval!

@gante gante force-pushed the validate_config branch from 64cd349 to 48e9c71 on May 30, 2025 15:49
@gante gante (Member Author) commented May 30, 2025

Note to myself before I go on holidays


Now that strict is released, we can easily run the full CI to detect issues. The PR is nearly in a compatible state (see CI status), but it wouldn't be wise to merge even if CI were green.

Let's ignore the validation logic itself, which is working well, and focus on the pressing issue: handling inheritance of PretrainedConfig. There are a few implementation details of the class that make inheritance challenging, which I'll try to isolate.

Issue: Base class inheritance

We are defining a model-level dataclass that inherits from another class (a non-dataclass). This makes attribute inheritance weird, especially given the base class's __init__. I think it can be solved with:

  1. converting the base class to a dataclass (with @strict, to accept kwargs)
  2. adding a __post_init__ to add defaults
  3. attributes that are not stored under the same name will need a getter and a setter (e.g. attn_implementation, see latest commit; a sketch follows below)
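A sketch of point 3 (hypothetical; the private attribute name _attn_implementation is assumed here for illustration):

@property
def attn_implementation(self):
    return self._attn_implementation

@attn_implementation.setter
def attn_implementation(self, value):
    self._attn_implementation = value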
