Commit d3e8093

convert: support DT_BF16 tensors (ggml-org#1309)
Co-authored-by: Pavol Rusnak <[email protected]>
1 parent 360cfe5 commit d3e8093

File tree: 1 file changed

convert.py

Lines changed: 17 additions & 1 deletion
@@ -67,6 +67,7 @@ class QuantizedDataType:
     {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
 
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
+    DT_BF16: np.dtype(np.uint16),
     DT_F16: np.dtype(np.float16),
     DT_F32: np.dtype(np.float32),
     DT_I32: np.dtype(np.int32),
@@ -276,6 +277,12 @@ def permute(self, n_head: int) -> 'Tensor': ...
     def to_ggml(self) -> 'GGMLCompatibleTensor': ...
 
 
+def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
+    assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
+    fp32_arr = bf16_arr.astype(np.uint32) << 16
+    return fp32_arr.view(np.float32)
+
+
 class UnquantizedTensor(Tensor):
     def __init__(self, ndarray: NDArray) -> None:
         assert isinstance(ndarray, np.ndarray)
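
The helper works because bfloat16 keeps float32's sign bit, all 8 exponent bits, and the top 7 mantissa bits: widening to uint32 and shifting left by 16 reconstructs the exact float32 bit pattern with the low mantissa bits zeroed, and .view() reinterprets the buffer without any value conversion. Exercised standalone (sample inputs assumed for illustration):

import numpy as np

def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
    assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
    fp32_arr = bf16_arr.astype(np.uint32) << 16  # move bf16 bits into the high half
    return fp32_arr.view(np.float32)             # reinterpret bits, no value conversion

bits = np.array([0x3F80, 0xC000, 0x7F80], dtype=np.uint16)
print(bf16_to_fp32(bits))  # [  1.  -2.  inf]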
@@ -284,6 +291,8 @@ def __init__(self, ndarray: NDArray) -> None:
 
     def astype(self, data_type: DataType) -> Tensor:
         dtype = DATA_TYPE_TO_NUMPY[data_type]
+        if self.data_type == DT_BF16:
+            self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
 
     def to_ggml(self) -> 'UnquantizedTensor':
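
With this change, astype on a DT_BF16 tensor first widens the raw bits losslessly to float32 and only then casts to the requested dtype. Sketched outside the class (hypothetical helper name, same arithmetic as above):

import numpy as np

def astype_sketch(bf16_bits: np.ndarray, dtype: type) -> np.ndarray:
    f32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # bf16 -> f32
    return f32.astype(dtype)                                    # f32 -> target dtype

# 0x3FC0 is the bf16 pattern for 1.5; requesting float16 goes bf16 -> f32 -> f16.
print(astype_sketch(np.array([0x3FC0], dtype=np.uint16), np.float16))  # [1.5]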
@@ -686,6 +695,7 @@ def load(offset: int, elm_count: int) -> NDArray:
         description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
         return LazyStorage(load=load, kind=pid[1], description=description)
 
+    @staticmethod
     def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,  # pyright: ignore[reportSelfClsParameterName]
                                requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
         assert isinstance(storage, LazyStorage)
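
The added @staticmethod appears to make explicit that this rebuild helper is fetched from the CLASSES table and invoked as a bare callable during unpickling, with no instance argument in front (hence the pyright ignore on the signature). A tiny illustration of that calling convention, with made-up names:

class Rebuilders:
    @staticmethod
    def rebuild(storage, offset, size):
        # receives exactly the pickled arguments, no implicit instance
        return (storage, offset, size)

print(Rebuilders.rebuild('stub-storage', 0, (2, 2)))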
@@ -696,12 +706,18 @@ def load() -> UnquantizedTensor:
         description = f'pickled storage_offset={storage_offset} in {storage.description}'
         return LazyTensor(load, list(size), storage.kind.data_type, description)
 
+    @staticmethod
+    def rebuild_from_type_v2(func, new_type, args, state):
+        return func(*args)
+
     CLASSES: Dict[Any, Any] = {
+        ('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
         ('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
         ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
         ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
         ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
         ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
+        ('torch', 'Tensor'): LazyTensor,
     }
 
     def find_class(self, module: str, name: str) -> Any:
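
For context, judging by the find_class method below the table, this sits on a pickle.Unpickler subclass whose find_class override resolves torch globals to lightweight stand-ins instead of importing real torch classes. Newer PyTorch checkpoints route tensor reconstruction through torch._tensor._rebuild_from_type_v2, which the new entry unwraps by simply calling func(*args) and ignoring the subclass and state arguments. A minimal sketch of the dispatch pattern, with stub names of my own:

import pickle

def stub_rebuild_tensor_v2(*args):
    return ('lazy-tensor', args)  # stand-in for torch._utils._rebuild_tensor_v2

class SketchUnpickler(pickle.Unpickler):
    CLASSES = {
        ('torch._utils', '_rebuild_tensor_v2'): stub_rebuild_tensor_v2,
    }

    def find_class(self, module, name):
        # resolve torch globals to our stand-ins; defer everything else
        if not module.startswith('torch'):
            return super().find_class(module, name)
        return self.CLASSES[(module, name)]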
@@ -961,7 +977,7 @@ def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
 
 def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
     wq_type = model["layers.0.attention.wq.weight"].data_type
-    if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
+    if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
         return GGMLFileType.MostlyF16
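
Defaulting bf16 models to f32 output is the conservative choice: bfloat16 shares float32's 8-bit exponent, while float16 has only 5 exponent bits and a maximum finite value of 65504, so a bf16 -> f16 conversion can overflow weights that f32 preserves. A quick check (bit pattern chosen for illustration):

import numpy as np

# 0x7E00 is the bf16 pattern for 2**125 (~4.25e37): representable in
# bf16/f32, far beyond float16's maximum of 65504.
big = (np.array([0x7E00], dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
print(big)                     # [4.2535296e+37]
print(big.astype(np.float16))  # [inf] -- would overflow under an f16 default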
