Commit 8f94efa

feat: add support for loading F8_E5M2 weights (leejet#460)
1 parent 0758544 commit 8f94efa

2 files changed: 67 additions, 1 deletion

model.cpp

Lines changed: 63 additions & 0 deletions
@@ -614,6 +614,48 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
     return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
 }
 
+
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> 2) & 0x1F;
+    uint8_t mantissa = fp8 & 0x3;
+
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent;
+    uint16_t fp16_mantissa;
+
+    if (exponent == 0 && mantissa == 0) {  // zero
+        return fp16_sign;
+    }
+
+    if (exponent == 0x1F) {  // NaN and Inf
+        fp16_exponent = 0x1F;
+        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
+        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+    }
+
+    if (exponent == 0) {  // subnormal numbers
+        fp16_exponent = 0;
+        fp16_mantissa = (mantissa << 8);
+        return fp16_sign | fp16_mantissa;
+    }
+
+    // normal numbers
+    int16_t true_exponent = (int16_t)exponent - 15 + 15;
+    if (true_exponent <= 0) {
+        fp16_exponent = 0;
+        fp16_mantissa = (mantissa << 8);
+    } else if (true_exponent >= 0x1F) {
+        fp16_exponent = 0x1F;
+        fp16_mantissa = 0;
+    } else {
+        fp16_exponent = (uint16_t)true_exponent;
+        fp16_mantissa = mantissa << 8;
+    }
+
+    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+}
+
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {

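For reference (not part of the commit): F8_E5M2 packs a sign bit, 5 exponent bits with bias 15, and 2 mantissa bits, while IEEE half precision (F16) uses the same 5-bit exponent and the same bias, so the conversion above only has to copy sign and exponent and shift the 2 mantissa bits into the top of the 10-bit F16 mantissa; that is also why true_exponent = exponent - 15 + 15 reduces to the E5M2 exponent itself. A minimal spot-check sketch, assuming it is compiled and linked together with model.cpp (the byte values and expected F16 bit patterns are hand-computed here, not taken from the repository's tests):

#include <cassert>
#include <cstdint>
#include <cstdio>

uint16_t f8_e5m2_to_f16(uint8_t fp8);  // declared here, defined in model.cpp

int main() {
    assert(f8_e5m2_to_f16(0x3C) == 0x3C00);  // sign 0, exp 01111, mant 00 ->  1.0
    assert(f8_e5m2_to_f16(0xC0) == 0xC000);  // sign 1, exp 10000, mant 00 -> -2.0
    assert(f8_e5m2_to_f16(0x7B) == 0x7B00);  // largest finite E5M2 value, 57344.0
    assert(f8_e5m2_to_f16(0x7C) == 0x7C00);  // +Inf maps to F16 +Inf
    assert(f8_e5m2_to_f16(0x01) == 0x0100);  // smallest subnormal, 2^-16
    printf("f8_e5m2_to_f16 spot checks passed\n");
    return 0;
}
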
@@ -627,6 +669,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
         dst[i] = f8_e4m3_to_f16(src[i]);
     }
 }
+void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e5m2_to_f16(src[i]);
+    }
+}
 
 void convert_tensor(void* src,
                     ggml_type src_type,

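As with the existing bf16 and f8_e4m3 helpers, the "support inplace op" comment refers to the caller passing the same buffer as src and dst: the raw one-byte E5M2 values occupy the first n bytes of an n*2-byte F16 allocation, and walking from n-1 down to 0 ensures each two-byte write never overwrites a source byte that has not been read yet. A rough illustration of that calling pattern (the buffer contents are invented; only the aliasing matters):

#include <cstdint>
#include <vector>

uint16_t f8_e5m2_to_f16(uint8_t fp8);                             // model.cpp
void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n);  // model.cpp

int main() {
    const int64_t n = 4;
    std::vector<uint16_t> buf(n);  // n elements worth of F16 storage
    uint8_t* raw = reinterpret_cast<uint8_t*>(buf.data());
    raw[0] = 0x3C;  // 1.0  (raw E5M2 bytes fill only the first n bytes)
    raw[1] = 0xC0;  // -2.0
    raw[2] = 0x44;  // 4.0
    raw[3] = 0x00;  // 0.0

    // src and dst alias the same memory; the widening runs back to front.
    f8_e5m2_to_f16_vec(raw, buf.data(), n);
    // buf now holds the F16 bit patterns 0x3C00, 0xC000, 0x4400, 0x0000.
    return 0;
}
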
@@ -863,6 +911,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
+    } else if (dtype == "F8_E5M2") {
+        ttype = GGML_TYPE_F16;
     }
     return ttype;
 }

@@ -976,6 +1026,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             tensor_storage.is_f8_e4m3 = true;
             // f8 -> f16
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E5M2") {
+            tensor_storage.is_f8_e5m2 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else {
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }

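The assertion holds because str_to_ggml_type() above maps F8_E5M2 to GGML_TYPE_F16, so tensor_storage.nbytes() counts two bytes per element while the safetensors payload stores one byte per element; the same 2:1 ratio is what nbytes_to_read() in model.h relies on. A throwaway arithmetic sketch under that assumption (the tensor shape is invented for illustration):

#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical 320 x 1280 weight stored as F8_E5M2 in a safetensors file.
    int64_t nelements        = 320 * 1280;     // 409600 values
    int64_t tensor_data_size = nelements;      // 1 byte per element on disk
    int64_t nbytes           = nelements * 2;  // held in memory as GGML_TYPE_F16
    assert(nbytes == tensor_data_size * 2);    // mirrors the GGML_ASSERT above
    assert(nbytes / 2 == tensor_data_size);    // what nbytes_to_read() returns
    return 0;
}
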
@@ -1644,6 +1698,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e5m2) {
+                    // inplace op
+                    f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
                 read_buffer.resize(tensor_storage.nbytes());

@@ -1655,6 +1712,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e5m2) {
+                    // inplace op
+                    f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,

@@ -1670,6 +1730,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e5m2) {
+                    // inplace op
+                    f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 if (tensor_storage.type == dst_tensor->type) {


model.h

Lines changed: 4 additions & 1 deletion
@@ -36,6 +36,7 @@ struct TensorStorage {
     ggml_type type = GGML_TYPE_F32;
     bool is_bf16 = false;
     bool is_f8_e4m3 = false;
+    bool is_f8_e5m2 = false;
    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
     int n_dims = 0;
 

@@ -65,7 +66,7 @@ struct TensorStorage {
     }
 
     int64_t nbytes_to_read() const {
-        if (is_bf16 || is_f8_e4m3) {
+        if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
             return nbytes() / 2;
         } else {
             return nbytes();

@@ -115,6 +116,8 @@ struct TensorStorage {
             type_name = "bf16";
         } else if (is_f8_e4m3) {
             type_name = "f8_e4m3";
+        } else if (is_f8_e5m2) {
+            type_name = "f8_e5m2";
         }
         ss << name << " | " << type_name << " | ";
         ss << n_dims << " [";

