Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 27ca2d7

Browse files
committed
class validators
1 parent 4c72af3 commit 27ca2d7

4 files changed

Lines changed: 217 additions & 5 deletions

File tree

docs/source/en/package_reference/dataclasses.md

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,49 @@ def positive_int(value: int):
3636
@dataclass
3737
class Config:
3838
model_type: str
39-
hidden_size: int = positive_int(default=32)
40-
vocab_size: int = 16 # Default value
39+
hidden_size: int = positive_int(default=16)
40+
vocab_size: int = 32 # Default value
41+
42+
def validate_big_enough_vocab(self):
43+
if self.vocab_size < self.hidden_size:
44+
raise ValueError(f"vocab_size ({self.vocab_size}) must be greater than hidden_size ({self.hidden_size})")
4145
```
4246

4347
Fields are validated during initialization:
4448

4549
```python
46-
config = Config(model_type="bert", hidden_size=768) # Valid
50+
config = Config(model_type="bert", hidden_size=24) # Valid
4751
config = Config(model_type="bert", hidden_size=-1) # Raises StrictDataclassFieldValidationError
4852
```
4953

54+
Consistency between fields is also validated during initialization (class-wise validation):
55+
56+
```python
57+
# `vocab_size` too small compared to `hidden_size`
58+
config = Config(model_type="bert", hidden_size=32, vocab_size=16) # Raises StrictDataclassClassValidationError
59+
```
60+
5061
Fields are also validated during assignment:
5162

5263
```python
5364
config.hidden_size = 512 # Valid
5465
config.hidden_size = -1 # Raises StrictDataclassFieldValidationError
5566
```
5667

68+
To re-run class-wide validation after assignment, you must call `.validate` explicitly:
69+
70+
```python
71+
config.validate() # Runs all class validators
72+
```
73+
5774
### Custom Validators
5875

5976
You can attach multiple custom validators to fields using [`validated_field`]. A validator is a callable that takes a single argument and raises an exception if the value is invalid.
6077

6178
```python
79+
from dataclasses import dataclass
80+
from huggingface_hub.dataclasses import strict, validated_field
81+
6282
def multiple_of_64(value: int):
6383
if value % 64 != 0:
6484
raise ValueError(f"Value must be a multiple of 64, got {value}")
@@ -76,6 +96,9 @@ In this example, both validators are applied to the `hidden_size` field.
7696
By default, strict dataclasses only accept fields defined in the class. You can allow additional keyword arguments by setting `accept_kwargs=True` in the `@strict` decorator.
7797

7898
```python
99+
from dataclasses import dataclass
100+
from huggingface_hub.dataclasses import strict
101+
79102
@strict(accept_kwargs=True)
80103
@dataclass
81104
class ConfigWithKwargs:
@@ -94,6 +117,8 @@ Strict dataclasses respect type hints and validate them automatically. For examp
94117

95118
```python
96119
from typing import List
120+
from dataclasses import dataclass
121+
from huggingface_hub.dataclasses import strict
97122

98123
@strict
99124
@dataclass
@@ -116,6 +141,42 @@ Supported types include:
116141

117142
And any combination of these types.
118143

144+
### Class validators
145+
146+
Methods named `validate_xxx` are treated as class validators. These methods must only take `self` as an argument. Class validators are run once during initialization, right after `__post_init__`. You can define as many of them as needed—they'll be executed sequentially in the order they appear.
147+
148+
Note that class validators are not automatically re-run when a field is updated after initialization. To manually re-validate the object, you need to call `obj.validate()`.
149+
150+
```py
151+
from dataclasses import dataclass
152+
from huggingface_hub.dataclasses import strict
153+
154+
@strict
155+
@dataclass
156+
class Config:
157+
foo: str
158+
foo_length: int
159+
upper_case: bool = False
160+
161+
def validate_foo_length(self):
162+
if len(self.foo) != self.foo_length:
163+
raise ValueError(f"foo must be {self.foo_length} characters long, got {len(self.foo)}")
164+
165+
def validate_foo_casing(self):
166+
if self.upper_case and self.foo.upper() != self.foo:
167+
raise ValueError(f"foo must be uppercase, got {self.foo}")
168+
169+
config = Config(foo="bar", foo_length=3) # ok
170+
171+
config.upper_case = True
172+
config.validate() # Raises StrictDataclassFieldValidationError
173+
174+
Config(foo="abcd", foo_length=3) # Raises StrictDataclassFieldValidationError
175+
Config(foo="Bar", foo_length=3, upper_case=True) # Raises StrictDataclassFieldValidationError
176+
177+
178+
```
179+
119180
## API Reference
120181

121182
### `@strict`

src/huggingface_hub/dataclasses.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
overload,
1818
)
1919

20-
from .errors import StrictDataclassDefinitionError, StrictDataclassFieldValidationError
20+
from .errors import (
21+
StrictDataclassClassValidationError,
22+
StrictDataclassDefinitionError,
23+
StrictDataclassFieldValidationError,
24+
)
2125

2226

2327
Validator_T = Callable[[Any], None]
@@ -174,6 +178,46 @@ def __repr__(self) -> str:
174178

175179
cls.__repr__ = __repr__ # type: ignore [method-assign]
176180

181+
# List all public methods starting with `validate_` => class validators.
182+
class_validators = []
183+
if hasattr(cls, "__class_validators__"):
184+
# If inheriting from a class with class validators, add them to the list
185+
class_validators += cls.__class_validators__ # type: ignore [attr-defined]
186+
187+
for name, method in cls.__dict__.items():
188+
if name.startswith("validate_") and callable(method):
189+
if len(inspect.signature(method).parameters) != 1:
190+
raise StrictDataclassDefinitionError(
191+
f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
192+
" Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
193+
" are considered to be class validators."
194+
)
195+
class_validators.append(method)
196+
197+
cls.__class_validators__ = class_validators # type: ignore [attr-defined]
198+
199+
# Add `validate` method to the class
200+
def validate(self: T) -> None:
201+
"""Run class validators on the instance."""
202+
for validator in cls.__class_validators__: # type: ignore [attr-defined]
203+
try:
204+
validator(self)
205+
except (ValueError, TypeError) as e:
206+
raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e
207+
208+
cls.validate = validate # type: ignore
209+
210+
# Run class validators after initialization
211+
initial_init = cls.__init__
212+
213+
@wraps(initial_init)
214+
def init_with_validate(self, *args, **kwargs) -> None:
215+
"""Run class validators after initialization."""
216+
initial_init(self, *args, **kwargs) # type: ignore [call-arg]
217+
cls.validate(self) # type: ignore [attr-defined]
218+
219+
setattr(cls, "__init__", init_with_validate)
220+
177221
return cls
178222

179223
# Return wrapped class or the decorator itself
@@ -441,6 +485,7 @@ def _is_validator(validator: Any) -> bool:
441485
"strict",
442486
"validated_field",
443487
"Validator_T",
488+
"StrictDataclassClassValidationError",
444489
"StrictDataclassDefinitionError",
445490
"StrictDataclassFieldValidationError",
446491
]

src/huggingface_hub/errors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,15 @@ def __init__(self, field: str, cause: Exception):
349349
super().__init__(error_message)
350350

351351

352+
class StrictDataclassClassValidationError(StrictDataclassError):
353+
"""Exception thrown when a strict dataclass fails validation on a class validator."""
354+
355+
def __init__(self, validator: str, cause: Exception):
356+
error_message = f"Class validation error for validator '{validator}':"
357+
error_message += f"\n {cause.__class__.__name__}: {cause}"
358+
super().__init__(error_message)
359+
360+
352361
# XET ERRORS
353362

354363

tests/test_utils_strict_dataclass.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import pytest
77

88
from huggingface_hub.dataclasses import _is_validator, as_validated_field, strict, type_validator, validated_field
9-
from huggingface_hub.errors import StrictDataclassDefinitionError, StrictDataclassFieldValidationError
9+
from huggingface_hub.errors import (
10+
StrictDataclassClassValidationError,
11+
StrictDataclassDefinitionError,
12+
StrictDataclassFieldValidationError,
13+
)
1014

1115

1216
def positive_int(value: int):
@@ -441,3 +445,96 @@ def test_strict_requires_dataclass():
441445
@strict
442446
class InvalidConfig:
443447
model_type: str
448+
449+
450+
class TestClassValidation:
451+
@strict
452+
@dataclass
453+
class ParentConfig:
454+
foo: str = "bar"
455+
foo_length: int = 3
456+
457+
def validate_foo_length(self):
458+
if len(self.foo) != self.foo_length:
459+
raise ValueError(f"foo must be {self.foo_length} characters long, got {len(self.foo)}")
460+
461+
@strict
462+
@dataclass
463+
class ChildConfig(ParentConfig):
464+
number: int = 42
465+
466+
def validate_number_multiple_of_foo_length(self):
467+
if self.number % self.foo_length != 0:
468+
raise ValueError(f"number must be a multiple of foo_length ({self.foo_length}), got {self.number}")
469+
470+
@strict
471+
@dataclass
472+
class OtherChildConfig(ParentConfig):
473+
number: int = 42
474+
475+
@strict
476+
@dataclass
477+
class ChildConfigWithPostInit(ParentConfig):
478+
def __post_init__(self):
479+
# Let's assume post_init doubles each value
480+
# Validation is ran AFTER __post_init__
481+
self.foo = self.foo * 2
482+
self.foo_length = self.foo_length * 2
483+
484+
def test_parent_config_validation(self):
485+
# Test valid initialization
486+
config = self.ParentConfig(foo="bar", foo_length=3)
487+
assert config.foo == "bar"
488+
assert config.foo_length == 3
489+
490+
# Test invalid initialization
491+
with pytest.raises(StrictDataclassClassValidationError):
492+
self.ParentConfig(foo="bar", foo_length=4)
493+
494+
def test_child_config_validation(self):
495+
# Test valid initialization
496+
config = self.ChildConfig(foo="bar", foo_length=3, number=42)
497+
assert config.foo == "bar"
498+
assert config.foo_length == 3
499+
assert config.number == 42
500+
501+
# Test invalid initialization
502+
with pytest.raises(StrictDataclassClassValidationError):
503+
self.ChildConfig(foo="bar", foo_length=4, number=40)
504+
505+
with pytest.raises(StrictDataclassClassValidationError):
506+
self.ChildConfig(foo="bar", foo_length=3, number=43)
507+
508+
def test_other_child_config_validation(self):
509+
# Test valid initialization
510+
config = self.OtherChildConfig(foo="bar", foo_length=3, number=43)
511+
assert config.foo == "bar"
512+
assert config.foo_length == 3
513+
assert config.number == 43 # not validated => did not fail
514+
515+
# Test invalid initialization
516+
with pytest.raises(StrictDataclassClassValidationError):
517+
self.OtherChildConfig(foo="bar", foo_length=4, number=42)
518+
519+
def test_validate_after_init(self):
520+
# Test valid initialization
521+
config = self.ParentConfig(foo="bar", foo_length=3)
522+
523+
# Attributes can be updated after initialization
524+
config.foo = "abcd"
525+
config.foo_length = 4
526+
config.validate() # Explicit call required
527+
528+
# Explicit validation fails
529+
config.foo_length = 5
530+
with pytest.raises(StrictDataclassClassValidationError):
531+
config.validate()
532+
533+
def test_validation_runs_after_post_init(self):
534+
config = self.ChildConfigWithPostInit(foo="bar", foo_length=3)
535+
assert config.foo == "barbar"
536+
assert config.foo_length == 6
537+
538+
with pytest.raises(StrictDataclassClassValidationError, match="foo must be 4 characters long, got 6"):
539+
# post init doubles the value and then the validation fails
540+
self.ChildConfigWithPostInit(foo="bar", foo_length=2)

0 commit comments

Comments
 (0)