@@ -2227,6 +2227,7 @@ def from_pretrained(
2227
2227
cls ,
2228
2228
repo_id : str ,
2229
2229
filename : Optional [str ],
2230
+ additional_files : Optional [List ] = None ,
2230
2231
local_dir : Optional [Union [str , os .PathLike [str ]]] = None ,
2231
2232
local_dir_use_symlinks : Union [bool , Literal ["auto" ]] = "auto" ,
2232
2233
cache_dir : Optional [Union [str , os .PathLike [str ]]] = None ,
@@ -2239,6 +2240,7 @@ def from_pretrained(
2239
2240
Args:
2240
2241
repo_id: The model repo id.
2241
2242
filename: A filename or glob pattern to match the model file in the repo.
2243
+ additional_files: A list of filenames or glob patterns for additional model files (e.g. further shards of a sharded model) to download from the repo. Each entry must match exactly one file.
2242
2244
local_dir: The local directory to save the model to.
2243
2245
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
2244
2246
**kwargs: Additional keyword arguments to pass to the Llama constructor.
@@ -2269,6 +2271,7 @@ def from_pretrained(
2269
2271
rel_path = Path (file ).relative_to (repo_id )
2270
2272
file_list .append (str (rel_path ))
2271
2273
2274
+ # find the main model file (for a sharded model, the first shard):
2272
2275
matching_files = [file for file in file_list if fnmatch .fnmatch (file , filename )] # type: ignore
2273
2276
2274
2277
if len (matching_files ) == 0 :
@@ -2298,6 +2301,35 @@ def from_pretrained(
2298
2301
cache_dir = cache_dir ,
2299
2302
)
2300
2303
2304
+ if additional_files :
2305
+ for additonal_file_name in additional_files :
2306
+ # find the additional shard file:
2307
+ matching_additional_files = [file for file in file_list if fnmatch .fnmatch (file , additonal_file_name )]
2308
+
2309
+ if len (matching_additional_files ) == 0 :
2310
+ raise ValueError (
2311
+ f"No file found in { repo_id } that match { additonal_file_name } \n \n "
2312
+ f"Available Files:\n { json .dumps (file_list )} "
2313
+ )
2314
+
2315
+ if len (matching_additional_files ) > 1 :
2316
+ raise ValueError (
2317
+ f"Multiple files found in { repo_id } matching { additonal_file_name } \n \n "
2318
+ f"Available Files:\n { json .dumps (files )} "
2319
+ )
2320
+
2321
+ (matching_additional_file ,) = matching_additional_files
2322
+
2323
+ # download the additional file
2324
+ hf_hub_download (
2325
+ repo_id = repo_id ,
2326
+ filename = matching_additional_file ,
2327
+ subfolder = subfolder ,
2328
+ local_dir = local_dir ,
2329
+ local_dir_use_symlinks = local_dir_use_symlinks ,
2330
+ cache_dir = cache_dir ,
2331
+ )
2332
+
2301
2333
if local_dir is None :
2302
2334
model_path = hf_hub_download (
2303
2335
repo_id = repo_id ,
@@ -2311,6 +2343,7 @@ def from_pretrained(
2311
2343
else :
2312
2344
model_path = os .path .join (local_dir , filename )
2313
2345
2346
+ # loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
2314
2347
return cls (
2315
2348
model_path = model_path ,
2316
2349
** kwargs ,
0 commit comments