@@ -222,8 +222,8 @@ def _set_base_config_defaults(self, **kwargs):
222222 pre-`@strict` and `@strict`-powered config classes to co-exist: the former pass `**kwargs` to `__init__`
223223 and the latter sets all attributes in the class directly before calling this function.
224224
225- Note that we need to keep pre-`@strict` support for backwards compatibility with custom config classes from the
226- Hub.
225+ Note that we need to keep pre-`@strict` support for backwards compatibility with custom config classes from
226+ the Hub.
227227 """
228228
229229 def _default_if_unset (attribute_name , default_value ):
@@ -263,11 +263,7 @@ def _default_if_unset(attribute_name, default_value):
263263 self .finetuning_task = _default_if_unset ("finetuning_task" , None )
264264 self .id2label = _default_if_unset ("id2label" , None )
265265 self .label2id = _default_if_unset ("label2id" , None )
266- if self .label2id is not None and not isinstance (self .label2id , dict ):
267- raise ValueError ("Argument label2id should be a dictionary." )
268266 if self .id2label is not None :
269- if not isinstance (self .id2label , dict ):
270- raise ValueError ("Argument id2label should be a dictionary." )
271267 num_labels = kwargs .pop ("num_labels" , None )
272268 if num_labels is not None and len (self .id2label ) != num_labels :
273269 logger .warning (
@@ -302,19 +298,6 @@ def _default_if_unset(attribute_name, default_value):
302298
303299 # regression / multi-label classification
304300 self .problem_type = _default_if_unset ("problem_type" , None )
305- allowed_problem_types = ("regression" , "single_label_classification" , "multi_label_classification" )
306- if self .problem_type is not None and self .problem_type not in allowed_problem_types :
307- raise ValueError (
308- f"The config parameter `problem_type` was not understood: received { self .problem_type } "
309- "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
310- )
311-
312- # TPU arguments
313- if _default_if_unset ("xla_device" , None ) is not None :
314- logger .warning (
315- "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
316- "safely remove it from your `config.json` file."
317- )
318301
319302 # Name or path to the pretrained checkpoint
320303 self ._name_or_path = str (_default_if_unset ("name_or_path" , "" ))
@@ -344,84 +327,60 @@ def _default_if_unset(attribute_name, default_value):
344327 logger .error (f"Can't set { key } with value { value } for { self } " )
345328 raise err
346329
347- # Finally, resets class-level properties (setters and getters)
348- self ._reset_properties ()
330+ # Finally, builds additional general attributes (these used to be properties)
331+ self .name_or_path = getattr (self , "_name_or_path" , None )
332+ self .output_attentions = getattr (self , "_output_attentions" , False )
333+ self .use_return_dict = self .return_dict and not self .torchscript
334+ self .num_labels = getattr (self , "_num_labels" , None ) or len (self .id2label )
335+ if self .id2label is None or len (self .id2label ) != self .num_labels :
336+ self .id2label = {i : f"LABEL_{ i } " for i in range (self .num_labels )}
337+ self .label2id = dict (zip (self .id2label .values (), self .id2label .keys ()))
338+
339+ # Manually call validation, for non-`@strict`-powered classes
340+ self .validate_base_pretrained_config_attributes ()
341+
342+ @property
343+ def _attn_implementation (self ):
344+ # This property is made private for now (as it cannot be changed and a
345+ # PreTrainedModel.use_attn_implementation method needs to be implemented.)
346+ if hasattr (self , "_attn_implementation_internal" ):
347+ if self ._attn_implementation_internal is None :
348+ # `config.attn_implementation` should never be None, for backward compatibility.
349+ return "eager"
350+ else :
351+ return self ._attn_implementation_internal
352+ else :
353+ return "eager"
354+
355+ @_attn_implementation .setter
356+ def _attn_implementation (self , value ):
357+ self ._attn_implementation_internal = value
349358
350- def _reset_properties (self ):
359+ def validate_base_pretrained_config_attributes (self ):
351360 """
352- BC: some old attribute names share the name with newer class-level properties. `@strict`-powered classes may
353- overwrite these properties in the instance at __init__ time, so we have to reset them .
361+ Part of `@strict`-powered validation. Validates the contents of most attributes in `PretrainedConfig`. These
362+ checks used to be part of __init__.
354363 """
364+ # `output_attentions`
365+ if self .output_attentions and self ._attn_implementation != "eager" :
366+ raise ValueError (
367+ "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
368+ f"{ self ._attn_implementation } . Please set it to 'eager' instead."
369+ )
355370
356- # Define all wanted properties as usual
357- @property
358- def name_or_path (self ) -> str :
359- return getattr (self , "_name_or_path" , None )
360-
361- @name_or_path .setter
362- def name_or_path (self , value ):
363- self ._name_or_path = str (value ) # Make sure that name_or_path is a string (for JSON encoding)
364-
365- @property
366- def output_attentions (self ) -> bool :
367- """
368- `bool`: Whether or not the model should returns all attentions.
369- """
370- return self ._output_attentions
371-
372- @output_attentions .setter
373- def output_attentions (self , value ):
374- if self ._attn_implementation != "eager" :
375- raise ValueError (
376- "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
377- f"{ self ._attn_implementation } . Please set it to 'eager' instead."
378- )
379- self ._output_attentions = value
380-
381- @property
382- def use_return_dict (self ) -> bool :
383- """
384- `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
385- """
386- # If torchscript is set, force `return_dict=False` to avoid jit errors
387- return self .return_dict and not self .torchscript
388-
389- @property
390- def num_labels (self ) -> int :
391- """
392- `int`: The number of labels for classification models.
393- """
394- return self ._num_labels or len (self .id2label )
395-
396- @num_labels .setter
397- def num_labels (self , num_labels : int ):
398- if not hasattr (self , "id2label" ) or self .id2label is None or len (self .id2label ) != num_labels :
399- self .id2label = {i : f"LABEL_{ i } " for i in range (num_labels )}
400- self .label2id = dict (zip (self .id2label .values (), self .id2label .keys ()))
401-
402- @property
403- def _attn_implementation (self ):
404- # This property is made private for now (as it cannot be changed and a
405- # PreTrainedModel.use_attn_implementation method needs to be implemented.)
406- if hasattr (self , "_attn_implementation_internal" ):
407- if self ._attn_implementation_internal is None :
408- # `config.attn_implementation` should never be None, for backward compatibility.
409- return "eager"
410- else :
411- return self ._attn_implementation_internal
412- else :
413- return "eager"
414-
415- @_attn_implementation .setter
416- def _attn_implementation (self , value ):
417- self ._attn_implementation_internal = value
371+ # `label2id`, `id2label`
372+ if self .label2id is not None and not isinstance (self .label2id , dict ):
373+ raise ValueError ("label2id should be a dictionary." )
374+ if self .id2label is not None and not isinstance (self .id2label , dict ):
375+ raise ValueError ("id2label should be a dictionary." )
418376
419- # Set them in the class (properties must be set at a class level)
420- setattr (PretrainedConfig , "name_or_path" , name_or_path )
421- setattr (PretrainedConfig , "output_attentions" , output_attentions )
422- setattr (PretrainedConfig , "use_return_dict" , use_return_dict )
423- setattr (PretrainedConfig , "num_labels" , num_labels )
424- setattr (PretrainedConfig , "_attn_implementation" , _attn_implementation )
377+ # `problem_type`
378+ allowed_problem_types = ("regression" , "single_label_classification" , "multi_label_classification" )
379+ if self .problem_type is not None and self .problem_type not in allowed_problem_types :
380+ raise ValueError (
381+ f"The config parameter `problem_type` was not understood: received { self .problem_type } "
382+ "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
383+ )
425384
426385 def validate_token_ids (self ):
427386 """Part of `@strict`-powered validation. Validates the contents of the special tokens."""
0 commit comments