Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 84c0920

Browse files
Gnurroabetlen
andauthored
feat: Add loading sharded GGUF files from HuggingFace with Llama.from_pretrained(additional_files=[...]) . Closes abetlen#1341
Co-authored-by: Andrei <[email protected]>
1 parent 29afcfd commit 84c0920

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

llama_cpp/llama.py

+33
Original file line numberDiff line numberDiff line change
@@ -2227,6 +2227,7 @@ def from_pretrained(
22272227
cls,
22282228
repo_id: str,
22292229
filename: Optional[str],
2230+
additional_files: Optional[List] = None,
22302231
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
22312232
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
22322233
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
@@ -2239,6 +2240,7 @@ def from_pretrained(
22392240
Args:
22402241
repo_id: The model repo id.
22412242
filename: A filename or glob pattern to match the model file in the repo.
2243+
additional_files: A list of filenames or glob patterns to match additional model files in the repo.
22422244
local_dir: The local directory to save the model to.
22432245
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
22442246
**kwargs: Additional keyword arguments to pass to the Llama constructor.
@@ -2269,6 +2271,7 @@ def from_pretrained(
22692271
rel_path = Path(file).relative_to(repo_id)
22702272
file_list.append(str(rel_path))
22712273

2274+
# find the only/first shard file:
22722275
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
22732276

22742277
if len(matching_files) == 0:
@@ -2298,6 +2301,35 @@ def from_pretrained(
22982301
cache_dir=cache_dir,
22992302
)
23002303

2304+
if additional_files:
2305+
for additonal_file_name in additional_files:
2306+
# find the additional shard file:
2307+
matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)]
2308+
2309+
if len(matching_additional_files) == 0:
2310+
raise ValueError(
2311+
f"No file found in {repo_id} that match {additonal_file_name}\n\n"
2312+
f"Available Files:\n{json.dumps(file_list)}"
2313+
)
2314+
2315+
if len(matching_additional_files) > 1:
2316+
raise ValueError(
2317+
f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
2318+
f"Available Files:\n{json.dumps(files)}"
2319+
)
2320+
2321+
(matching_additional_file,) = matching_additional_files
2322+
2323+
# download the additional file
2324+
hf_hub_download(
2325+
repo_id=repo_id,
2326+
filename=matching_additional_file,
2327+
subfolder=subfolder,
2328+
local_dir=local_dir,
2329+
local_dir_use_symlinks=local_dir_use_symlinks,
2330+
cache_dir=cache_dir,
2331+
)
2332+
23012333
if local_dir is None:
23022334
model_path = hf_hub_download(
23032335
repo_id=repo_id,
@@ -2311,6 +2343,7 @@ def from_pretrained(
23112343
else:
23122344
model_path = os.path.join(local_dir, filename)
23132345

2346+
# loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
23142347
return cls(
23152348
model_path=model_path,
23162349
**kwargs,

0 commit comments

Comments
 (0)