|
6 | 6 | import pytest |
7 | 7 |
|
8 | 8 | from huggingface_hub.dataclasses import _is_validator, as_validated_field, strict, type_validator, validated_field |
9 | | -from huggingface_hub.errors import StrictDataclassDefinitionError, StrictDataclassFieldValidationError |
| 9 | +from huggingface_hub.errors import ( |
| 10 | + StrictDataclassClassValidationError, |
| 11 | + StrictDataclassDefinitionError, |
| 12 | + StrictDataclassFieldValidationError, |
| 13 | +) |
10 | 14 |
|
11 | 15 |
|
12 | 16 | def positive_int(value: int): |
@@ -441,3 +445,96 @@ def test_strict_requires_dataclass(): |
441 | 445 | @strict |
442 | 446 | class InvalidConfig: |
443 | 447 | model_type: str |
| 448 | + |
| 449 | + |
| 450 | +class TestClassValidation: |
| 451 | + @strict |
| 452 | + @dataclass |
| 453 | + class ParentConfig: |
| 454 | + foo: str = "bar" |
| 455 | + foo_length: int = 3 |
| 456 | + |
| 457 | + def validate_foo_length(self): |
| 458 | + if len(self.foo) != self.foo_length: |
| 459 | + raise ValueError(f"foo must be {self.foo_length} characters long, got {len(self.foo)}") |
| 460 | + |
| 461 | + @strict |
| 462 | + @dataclass |
| 463 | + class ChildConfig(ParentConfig): |
| 464 | + number: int = 42 |
| 465 | + |
| 466 | + def validate_number_multiple_of_foo_length(self): |
| 467 | + if self.number % self.foo_length != 0: |
| 468 | + raise ValueError(f"number must be a multiple of foo_length ({self.foo_length}), got {self.number}") |
| 469 | + |
| 470 | + @strict |
| 471 | + @dataclass |
| 472 | + class OtherChildConfig(ParentConfig): |
| 473 | + number: int = 42 |
| 474 | + |
| 475 | + @strict |
| 476 | + @dataclass |
| 477 | + class ChildConfigWithPostInit(ParentConfig): |
| 478 | + def __post_init__(self): |
| 479 | + # Let's assume post_init doubles each value |
| 480 | + # Validation is ran AFTER __post_init__ |
| 481 | + self.foo = self.foo * 2 |
| 482 | + self.foo_length = self.foo_length * 2 |
| 483 | + |
| 484 | + def test_parent_config_validation(self): |
| 485 | + # Test valid initialization |
| 486 | + config = self.ParentConfig(foo="bar", foo_length=3) |
| 487 | + assert config.foo == "bar" |
| 488 | + assert config.foo_length == 3 |
| 489 | + |
| 490 | + # Test invalid initialization |
| 491 | + with pytest.raises(StrictDataclassClassValidationError): |
| 492 | + self.ParentConfig(foo="bar", foo_length=4) |
| 493 | + |
| 494 | + def test_child_config_validation(self): |
| 495 | + # Test valid initialization |
| 496 | + config = self.ChildConfig(foo="bar", foo_length=3, number=42) |
| 497 | + assert config.foo == "bar" |
| 498 | + assert config.foo_length == 3 |
| 499 | + assert config.number == 42 |
| 500 | + |
| 501 | + # Test invalid initialization |
| 502 | + with pytest.raises(StrictDataclassClassValidationError): |
| 503 | + self.ChildConfig(foo="bar", foo_length=4, number=40) |
| 504 | + |
| 505 | + with pytest.raises(StrictDataclassClassValidationError): |
| 506 | + self.ChildConfig(foo="bar", foo_length=3, number=43) |
| 507 | + |
| 508 | + def test_other_child_config_validation(self): |
| 509 | + # Test valid initialization |
| 510 | + config = self.OtherChildConfig(foo="bar", foo_length=3, number=43) |
| 511 | + assert config.foo == "bar" |
| 512 | + assert config.foo_length == 3 |
| 513 | + assert config.number == 43 # not validated => did not fail |
| 514 | + |
| 515 | + # Test invalid initialization |
| 516 | + with pytest.raises(StrictDataclassClassValidationError): |
| 517 | + self.OtherChildConfig(foo="bar", foo_length=4, number=42) |
| 518 | + |
| 519 | + def test_validate_after_init(self): |
| 520 | + # Test valid initialization |
| 521 | + config = self.ParentConfig(foo="bar", foo_length=3) |
| 522 | + |
| 523 | + # Attributes can be updated after initialization |
| 524 | + config.foo = "abcd" |
| 525 | + config.foo_length = 4 |
| 526 | + config.validate() # Explicit call required |
| 527 | + |
| 528 | + # Explicit validation fails |
| 529 | + config.foo_length = 5 |
| 530 | + with pytest.raises(StrictDataclassClassValidationError): |
| 531 | + config.validate() |
| 532 | + |
| 533 | + def test_validation_runs_after_post_init(self): |
| 534 | + config = self.ChildConfigWithPostInit(foo="bar", foo_length=3) |
| 535 | + assert config.foo == "barbar" |
| 536 | + assert config.foo_length == 6 |
| 537 | + |
| 538 | + with pytest.raises(StrictDataclassClassValidationError, match="foo must be 4 characters long, got 6"): |
| 539 | + # post init doubles the value and then the validation fails |
| 540 | + self.ChildConfigWithPostInit(foo="bar", foo_length=2) |
0 commit comments