
Merged. Showing changes from 1 commit (of 23 commits in this pull request).
safe import
tastelikefeet committed Feb 13, 2025
commit e6db19c11f65943ae23a2454d58083502d7ef1ae
14 changes: 9 additions & 5 deletions modelscope/utils/hf_util/auto_class.py
@@ -75,8 +75,12 @@
 else:
 
     from .patcher import get_all_imported_modules, _patch_pretrained_class
-    all_available_modules = _patch_pretrained_class(
-        get_all_imported_modules(), wrap=True)
-
-    for module in all_available_modules:
-        globals()[module.__name__] = module
+    try:
+        all_available_modules = _patch_pretrained_class(
+            get_all_imported_modules(), wrap=True)
+    except Exception:  # noqa
+        import traceback
+        traceback.print_exc()
+    else:
+        for module in all_available_modules:
+            globals()[module.__name__] = module
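Net effect of this hunk: if anything in the Hugging Face patching path raises at import time, ModelScope now prints the traceback and skips re-exporting the wrapped classes instead of failing `import modelscope` outright. A minimal standalone sketch of the same guard pattern (function and parameter names here are illustrative, not ModelScope's):

import traceback

def install_optional_symbols(build_modules, namespace):
    """Run a fallible setup step without letting it break module import."""
    try:
        modules = build_modules()  # may raise if an optional dep is missing or broken
    except Exception:  # noqa - deliberately broad at import time
        traceback.print_exc()  # surface the failure, but keep importing
    else:
        # Runs only when no exception occurred, mirroring the try/else above
        for module in modules:
            namespace[module.__name__] = module

Putting the loop in the `else:` branch rather than inside `try:` keeps the except clause from accidentally swallowing errors raised by the loop itself.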
82 changes: 47 additions & 35 deletions modelscope/utils/hf_util/patcher.py
@@ -29,12 +29,15 @@ def get_all_imported_modules():
         'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
         'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
     ]
+    peft_include_names = ['.*PeftModel.*', '.*Config']
+    diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
     if importlib.util.find_spec('transformers') is not None:
         import transformers
         lazy_module = sys.modules['transformers']
         _import_structure = lazy_module._import_structure
         for key in _import_structure:
+            if 'dummy' in key.lower():
+                continue
             values = _import_structure[key]
             for value in values:
                 # pretrained
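Two things change here. First, allowlist patterns for the peft and diffusers re-exports are declared up front. Second, 'dummy' keys in transformers' `_import_structure` are skipped: transformers builds its top-level namespace lazily from this mapping, and keys such as `utils.dummy_pt_objects` hold placeholder objects that raise as soon as they are touched when an optional backend is missing. A sketch of the skip, assuming a transformers version that exposes `_import_structure` on its lazy module (the patched code relies on the same attribute):

import sys
import importlib.util

if importlib.util.find_spec('transformers') is not None:
    import transformers  # registers a lazy module in sys.modules
    lazy_module = sys.modules['transformers']
    # Keys like 'models.bert' survive; 'utils.dummy_pt_objects' is skipped,
    # so missing-backend placeholders are never resolved into real attributes.
    usable_keys = [
        key for key in lazy_module._import_structure
        if 'dummy' not in key.lower()
    ]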
@@ -60,8 +63,11 @@ def get_all_imported_modules():
         imports = [
             attr for attr in attributes if not attr.startswith('__')
         ]
-        all_imported_modules.extend(
-            [getattr(peft, _import) for _import in imports])
+        all_imported_modules.extend([
+            getattr(peft, _import) for _import in imports if any([
+                re.fullmatch(name, _import) for name in peft_include_names
+            ])
+        ])
 
     if importlib.util.find_spec('diffusers') is not None:
         try:
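With this change, instead of re-exporting everything `dir(peft)` returns, only names that fully match one of the allowlist patterns are kept. `re.fullmatch` anchors the pattern at both ends of the name, so the filter behaves like a strict whitelist. A small self-contained check (the peft names are real exports; the helper is illustrative):

import re

peft_include_names = ['.*PeftModel.*', '.*Config']

def is_allowed(name: str) -> bool:
    # fullmatch means the pattern must cover the entire name, not a substring
    return any(re.fullmatch(pattern, name) for pattern in peft_include_names)

assert is_allowed('PeftModelForCausalLM')  # matches '.*PeftModel.*'
assert is_allowed('LoraConfig')            # matches '.*Config'
assert not is_allowed('get_peft_model')    # helper functions are filtered out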
@@ -73,6 +79,8 @@ def get_all_imported_modules():
             if hasattr(lazy_module, '_import_structure'):
                 _import_structure = lazy_module._import_structure
                 for key in _import_structure:
+                    if 'dummy' in key.lower():
+                        continue
                     values = _import_structure[key]
                     for value in values:
                         if any([
@@ -91,8 +99,13 @@ def get_all_imported_modules():
             imports = [
                 attr for attr in attributes if not attr.startswith('__')
             ]
-            all_imported_modules.extend(
-                [getattr(lazy_module, _import) for _import in imports])
+            all_imported_modules.extend([
+                getattr(lazy_module, _import) for _import in imports
+                if any([
+                    re.fullmatch(name, _import)
+                    for name in diffusers_include_names
+                ])
+            ])
     return all_imported_modules
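The diffusers branch gets the same treatment with a single pattern, `^(?!TF|Flax).*Pipeline$`. The negative lookahead `(?!TF|Flax)` rejects names beginning with a TensorFlow or Flax prefix, and the `Pipeline$` suffix restricts the re-export to pipeline classes. A quick demonstration of how the pattern behaves under `re.fullmatch` (class names taken from diffusers):

import re

pattern = '^(?!TF|Flax).*Pipeline$'

assert re.fullmatch(pattern, 'StableDiffusionPipeline')            # kept
assert not re.fullmatch(pattern, 'FlaxStableDiffusionPipeline')    # Flax prefix rejected
assert not re.fullmatch(pattern, 'StableDiffusionPipelineOutput')  # wrong suffix, dropped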


@@ -526,40 +539,39 @@ def create_commit(
             revision=revision,
             repo_type=repo_type or 'model')
 
-    def load(
-        cls,
-        repo_id_or_path: Union[str, Path],
-        repo_type: Optional[str] = None,
-        token: Optional[str] = None,
-        ignore_metadata_errors: bool = False,
-    ):
-        from modelscope.hub.api import HubApi
-        api = HubApi()
-        api.login(token)
-        if os.path.exists(repo_id_or_path):
-            file_path = repo_id_or_path
-        elif repo_type == 'model' or repo_type is None:
-            from modelscope import model_file_download
-            file_path = model_file_download(repo_id_or_path, 'README.md')
-        elif repo_type == 'dataset':
-            from modelscope import dataset_file_download
-            file_path = dataset_file_download(repo_id_or_path, 'README.md')
-        else:
-            raise ValueError(
-                f'repo_type should be `model` or `dataset`, but now is {repo_type}'
-            )
-
-        with open(file_path, 'r') as f:
-            repo_card = cls(
-                f.read(), ignore_metadata_errors=ignore_metadata_errors)
-        if not hasattr(repo_card.data, 'tags'):
-            repo_card.data.tags = []
-        return repo_card
-
     # Patch repocard.validate
     from huggingface_hub import repocard
     if not hasattr(repocard.RepoCard, '_validate_origin'):
 
+        def load(
+            cls,
+            repo_id_or_path: Union[str, Path],
+            repo_type: Optional[str] = None,
+            token: Optional[str] = None,
+            ignore_metadata_errors: bool = False,
+        ):
+            from modelscope.hub.api import HubApi
+            api = HubApi()
+            api.login(token)
+            if os.path.exists(repo_id_or_path):
+                file_path = repo_id_or_path
+            elif repo_type == 'model' or repo_type is None:
+                from modelscope import model_file_download
+                file_path = model_file_download(repo_id_or_path, 'README.md')
+            elif repo_type == 'dataset':
+                from modelscope import dataset_file_download
+                file_path = dataset_file_download(repo_id_or_path, 'README.md')
+            else:
+                raise ValueError(
+                    f'repo_type should be `model` or `dataset`, but now is {repo_type}'
+                )
+
+            with open(file_path, 'r') as f:
+                repo_card = cls(
+                    f.read(), ignore_metadata_errors=ignore_metadata_errors)
+            if not hasattr(repo_card.data, 'tags'):
+                repo_card.data.tags = []
+            return repo_card
+
         repocard.RepoCard._validate_origin = repocard.RepoCard.validate
         repocard.RepoCard.validate = lambda *args, **kwargs: None
         repocard.RepoCard._load_origin = repocard.RepoCard.load
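The `load` override itself is unchanged; what changed is where it is defined. Moving it inside the `if not hasattr(repocard.RepoCard, '_validate_origin'):` guard ties the definition to the same idempotency check that protects the monkey-patch, so running the patch function twice neither re-wraps the methods nor loses the stashed originals. The hunk is truncated before the corresponding `load` assignment, but the stash-then-replace pattern is visible for `validate`; a sketch of the full pattern, with hypothetical helper names and assuming `load` is attached the same way `validate` is:

from huggingface_hub import repocard

def patch_repocard(new_load):
    # Idempotency guard: the sentinel attribute exists only after a first patch.
    if not hasattr(repocard.RepoCard, '_validate_origin'):
        repocard.RepoCard._validate_origin = repocard.RepoCard.validate
        repocard.RepoCard.validate = lambda *args, **kwargs: None
        repocard.RepoCard._load_origin = repocard.RepoCard.load
        repocard.RepoCard.load = classmethod(new_load)  # assumed to mirror validate

def unpatch_repocard():
    # Restore the stashed originals and drop the sentinel attributes.
    if hasattr(repocard.RepoCard, '_validate_origin'):
        repocard.RepoCard.validate = repocard.RepoCard._validate_origin
        repocard.RepoCard.load = repocard.RepoCard._load_origin
        del repocard.RepoCard._validate_origin
        del repocard.RepoCard._load_origin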