@@ -67,6 +67,7 @@ class QuantizedDataType:
     {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
 
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
+    DT_BF16: np.dtype(np.uint16),
     DT_F16: np.dtype(np.float16),
     DT_F32: np.dtype(np.float32),
     DT_I32: np.dtype(np.int32),
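
Context for the new DT_BF16 entry (a note, not part of the patch): stock NumPy has no bfloat16 dtype, so the converter carries the raw 16-bit patterns in a uint16 array and reinterprets them later. A quick check of that limitation:

import numpy as np

try:
    np.dtype('bfloat16')              # stock NumPy rejects this name
except TypeError:
    # No native bf16, so raw bit patterns travel in a 2-byte uint16 carrier.
    print('bfloat16 unsupported; using', np.dtype(np.uint16))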
@@ -276,6 +277,12 @@ def permute(self, n_head: int) -> 'Tensor': ...
     def to_ggml(self) -> 'GGMLCompatibleTensor': ...
 
 
+def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
+    assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
+    fp32_arr = bf16_arr.astype(np.uint32) << 16
+    return fp32_arr.view(np.float32)
+
+
 class UnquantizedTensor(Tensor):
     def __init__(self, ndarray: NDArray) -> None:
         assert isinstance(ndarray, np.ndarray)
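
The bit trick above works because bfloat16 is defined as the upper 16 bits of an IEEE-754 float32, so widening is a zero-fill shift into the high half. A standalone sketch of the same steps bf16_to_fp32 performs, using hand-picked bit patterns:

import numpy as np

# bf16 bit patterns: 0x3F80 -> 1.0, 0xC040 -> -3.0, 0x7F80 -> +inf
bf16_bits = np.array([0x3F80, 0xC040, 0x7F80], dtype=np.uint16)

# Widen to 32 bits, shift into the high half, reinterpret as float32.
fp32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(fp32)  # [  1.  -3.  inf]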
@@ -284,6 +291,8 @@ def __init__(self, ndarray: NDArray) -> None:
 
     def astype(self, data_type: DataType) -> Tensor:
         dtype = DATA_TYPE_TO_NUMPY[data_type]
+        if self.data_type == DT_BF16:
+            self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
 
     def to_ggml(self) -> 'UnquantizedTensor':
@@ -686,6 +695,7 @@ def load(offset: int, elm_count: int) -> NDArray:
         description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
         return LazyStorage(load=load, kind=pid[1], description=description)
 
+    @staticmethod
     def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,  # pyright: ignore[reportSelfClsParameterName]
                                requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
         assert isinstance(storage, LazyStorage)
@@ -696,12 +706,18 @@ def load() -> UnquantizedTensor:
         description = f'pickled storage_offset={storage_offset} in {storage.description}'
         return LazyTensor(load, list(size), storage.kind.data_type, description)
 
+    @staticmethod
+    def rebuild_from_type_v2(func, new_type, args, state):
+        return func(*args)
+
     CLASSES: Dict[Any, Any] = {
+        ('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
         ('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
         ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
         ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
         ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
         ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
+        ('torch', 'Tensor'): LazyTensor,
     }
 
     def find_class(self, module: str, name: str) -> Any:
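
Two notes on this hunk. As I understand it, the rebuild_from_type_v2 stub handles checkpoints written by newer PyTorch versions, which pickle tensors through torch._tensor._rebuild_from_type_v2(func, new_type, args, state); applying func to args (and ignoring the subclass info) routes reconstruction back through lazy_rebuild_tensor_v2 via the same CLASSES table. That table is consulted from find_class, the standard pickle.Unpickler hook called for every global referenced in the stream, which lets the converter return lazy stand-ins without importing torch at all. A minimal sketch of the same substitution pattern (the stub class and its mapping are illustrative, not from the patch):

import pickle

class StubUnpickler(pickle.Unpickler):
    # (module, name) pairs seen in the stream, mapped to local stand-ins.
    STUBS = {('collections', 'OrderedDict'): dict}  # illustrative entry

    def find_class(self, module, name):
        # Substitute a stand-in instead of importing the real class.
        if (module, name) in self.STUBS:
            return self.STUBS[(module, name)]
        return super().find_class(module, name)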
@@ -961,7 +977,7 @@ def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
 
 def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
     wq_type = model["layers.0.attention.wq.weight"].data_type
-    if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
+    if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
         return GGMLFileType.MostlyF16
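
A likely rationale for folding DT_BF16 into the f32 default (my reading; the patch itself doesn't state it): bfloat16 keeps float32's 8-bit exponent, so weights that fit comfortably in a bf16 checkpoint could overflow float16's 5-bit exponent if the converter defaulted to f16 output:

import numpy as np

big = np.float32(1e38)   # in range for bf16 and f32 alike
print(np.float16(big))   # inf -- float16 tops out near 65504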