Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cf0e3ea

Browse files
committed
no properties
1 parent 58bdea8 commit cf0e3ea

1 file changed

Lines changed: 51 additions & 92 deletions

File tree

‎src/transformers/configuration_utils.py‎

Lines changed: 51 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ def _set_base_config_defaults(self, **kwargs):
222222
pre-`@strict` and `@strict`-powered config classes to co-exist: the former pass `**kwargs` to `__init__`
223223
and the latter sets all attributes in the class directly before calling this function.
224224
225-
Note that we need to keep pre-`@strict` support for backwards compatibility with custom config classes from the
226-
Hub.
225+
Note that we need to keep pre-`@strict` support for backwards compatibility with custom config classes from
226+
the Hub.
227227
"""
228228

229229
def _default_if_unset(attribute_name, default_value):
@@ -263,11 +263,7 @@ def _default_if_unset(attribute_name, default_value):
263263
self.finetuning_task = _default_if_unset("finetuning_task", None)
264264
self.id2label = _default_if_unset("id2label", None)
265265
self.label2id = _default_if_unset("label2id", None)
266-
if self.label2id is not None and not isinstance(self.label2id, dict):
267-
raise ValueError("Argument label2id should be a dictionary.")
268266
if self.id2label is not None:
269-
if not isinstance(self.id2label, dict):
270-
raise ValueError("Argument id2label should be a dictionary.")
271267
num_labels = kwargs.pop("num_labels", None)
272268
if num_labels is not None and len(self.id2label) != num_labels:
273269
logger.warning(
@@ -302,19 +298,6 @@ def _default_if_unset(attribute_name, default_value):
302298

303299
# regression / multi-label classification
304300
self.problem_type = _default_if_unset("problem_type", None)
305-
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
306-
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
307-
raise ValueError(
308-
f"The config parameter `problem_type` was not understood: received {self.problem_type} "
309-
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
310-
)
311-
312-
# TPU arguments
313-
if _default_if_unset("xla_device", None) is not None:
314-
logger.warning(
315-
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
316-
"safely remove it from your `config.json` file."
317-
)
318301

319302
# Name or path to the pretrained checkpoint
320303
self._name_or_path = str(_default_if_unset("name_or_path", ""))
@@ -344,84 +327,60 @@ def _default_if_unset(attribute_name, default_value):
344327
logger.error(f"Can't set {key} with value {value} for {self}")
345328
raise err
346329

347-
# Finally, resets class-level properties (setters and getters)
348-
self._reset_properties()
330+
# Finally, builds additional general attributes (these used to be properties)
331+
self.name_or_path = getattr(self, "_name_or_path", None)
332+
self.output_attentions = getattr(self, "_output_attentions", False)
333+
self.use_return_dict = self.return_dict and not self.torchscript
334+
self.num_labels = getattr(self, "_num_labels", None) or len(self.id2label)
335+
if self.id2label is None or len(self.id2label) != self.num_labels:
336+
self.id2label = {i: f"LABEL_{i}" for i in range(self.num_labels)}
337+
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
338+
339+
# Manually call validation, for non-`@strict`-powered classes
340+
self.validate_base_pretrained_config_attributes()
341+
342+
@property
343+
def _attn_implementation(self):
344+
# This property is made private for now (as it cannot be changed and a
345+
# PreTrainedModel.use_attn_implementation method needs to be implemented.)
346+
if hasattr(self, "_attn_implementation_internal"):
347+
if self._attn_implementation_internal is None:
348+
# `config.attn_implementation` should never be None, for backward compatibility.
349+
return "eager"
350+
else:
351+
return self._attn_implementation_internal
352+
else:
353+
return "eager"
354+
355+
@_attn_implementation.setter
356+
def _attn_implementation(self, value):
357+
self._attn_implementation_internal = value
349358

350-
def _reset_properties(self):
359+
def validate_base_pretrained_config_attributes(self):
351360
"""
352-
BC: some old attribute names share the name with newer class-level properties. `@strict`-powered classes may
353-
overwrite these properties in the instance at __init__ time, so we have to reset them.
361+
Part of `@strict`-powered validation. Validates the contents of most attributes in `PretrainedConfig`. These
362+
checks used to be part of __init__.
354363
"""
364+
# `output_attentions`
365+
if self.output_attentions and self._attn_implementation != "eager":
366+
raise ValueError(
367+
"The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
368+
f"{self._attn_implementation}. Please set it to 'eager' instead."
369+
)
355370

356-
# Define all wanted properties as usual
357-
@property
358-
def name_or_path(self) -> str:
359-
return getattr(self, "_name_or_path", None)
360-
361-
@name_or_path.setter
362-
def name_or_path(self, value):
363-
self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
364-
365-
@property
366-
def output_attentions(self) -> bool:
367-
"""
368-
`bool`: Whether or not the model should return all attentions.
369-
"""
370-
return self._output_attentions
371-
372-
@output_attentions.setter
373-
def output_attentions(self, value):
374-
if self._attn_implementation != "eager":
375-
raise ValueError(
376-
"The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
377-
f"{self._attn_implementation}. Please set it to 'eager' instead."
378-
)
379-
self._output_attentions = value
380-
381-
@property
382-
def use_return_dict(self) -> bool:
383-
"""
384-
`bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
385-
"""
386-
# If torchscript is set, force `return_dict=False` to avoid jit errors
387-
return self.return_dict and not self.torchscript
388-
389-
@property
390-
def num_labels(self) -> int:
391-
"""
392-
`int`: The number of labels for classification models.
393-
"""
394-
return self._num_labels or len(self.id2label)
395-
396-
@num_labels.setter
397-
def num_labels(self, num_labels: int):
398-
if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
399-
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
400-
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
401-
402-
@property
403-
def _attn_implementation(self):
404-
# This property is made private for now (as it cannot be changed and a
405-
# PreTrainedModel.use_attn_implementation method needs to be implemented.)
406-
if hasattr(self, "_attn_implementation_internal"):
407-
if self._attn_implementation_internal is None:
408-
# `config.attn_implementation` should never be None, for backward compatibility.
409-
return "eager"
410-
else:
411-
return self._attn_implementation_internal
412-
else:
413-
return "eager"
414-
415-
@_attn_implementation.setter
416-
def _attn_implementation(self, value):
417-
self._attn_implementation_internal = value
371+
# `label2id`, `id2label`
372+
if self.label2id is not None and not isinstance(self.label2id, dict):
373+
raise ValueError("label2id should be a dictionary.")
374+
if self.id2label is not None and not isinstance(self.id2label, dict):
375+
raise ValueError("id2label should be a dictionary.")
418376

419-
# Set them in the class (properties must be set at a class level)
420-
setattr(PretrainedConfig, "name_or_path", name_or_path)
421-
setattr(PretrainedConfig, "output_attentions", output_attentions)
422-
setattr(PretrainedConfig, "use_return_dict", use_return_dict)
423-
setattr(PretrainedConfig, "num_labels", num_labels)
424-
setattr(PretrainedConfig, "_attn_implementation", _attn_implementation)
377+
# `problem_type`
378+
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
379+
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
380+
raise ValueError(
381+
f"The config parameter `problem_type` was not understood: received {self.problem_type} "
382+
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
383+
)
425384

426385
def validate_token_ids(self):
427386
"""Part of `@strict`-powered validation. Validates the contents of the special tokens."""

0 commit comments

Comments (0)