@@ -614,6 +614,48 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
     return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
 }
 
+
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    uint8_t sign     = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> 2) & 0x1F;
+    uint8_t mantissa = fp8 & 0x3;
+
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent;
+    uint16_t fp16_mantissa;
+
+    if (exponent == 0 && mantissa == 0) {  // zero
+        return fp16_sign;
+    }
+
+    if (exponent == 0x1F) {  // NAN and INF
+        fp16_exponent = 0x1F;
+        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
+        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+    }
+
+    if (exponent == 0) {  // subnormal numbers
+        fp16_exponent = 0;
+        fp16_mantissa = (mantissa << 8);
+        return fp16_sign | fp16_mantissa;
+    }
+
+    // normal numbers
+    int16_t true_exponent = (int16_t)exponent - 15 + 15;
+    if (true_exponent <= 0) {
+        fp16_exponent = 0;
+        fp16_mantissa = (mantissa << 8);
+    } else if (true_exponent >= 0x1F) {
+        fp16_exponent = 0x1F;
+        fp16_mantissa = 0;
+    } else {
+        fp16_exponent = (uint16_t)true_exponent;
+        fp16_mantissa = mantissa << 8;
+    }
+
+    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+}
+
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
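Worth noting why the field-by-field mapping above works out so simply: E5M2 uses the same 5-bit exponent and bias (15) as IEEE half precision, so every E5M2 encoding is exactly the top byte of an FP16 encoding and the conversion amounts to a left shift by 8, zeros, subnormals, Inf and NaN included. A minimal standalone sketch (not part of the patch; all names here are illustrative) that checks this claim exhaustively against a manual decode of both formats:

```cpp
// Standalone check: decode every E5M2 byte and the FP16 pattern (byte << 8)
// to float independently, and confirm they agree for all 256 encodings.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// manual decode of an FP16 bit pattern (1-5-10, bias 15)
static float f16_bits_to_float(uint16_t h) {
    int exp = (h >> 10) & 0x1F, man = h & 0x3FF;
    float v = (exp == 0)    ? std::ldexp((float)man, -24)                // subnormal
            : (exp == 0x1F) ? (man ? NAN : INFINITY)                     // NaN / Inf
                            : std::ldexp(1.0f + man / 1024.0f, exp - 15);
    return (h & 0x8000) ? -v : v;
}

// manual decode of an E5M2 bit pattern (1-5-2, bias 15)
static float e5m2_bits_to_float(uint8_t b) {
    int exp = (b >> 2) & 0x1F, man = b & 0x3;
    float v = (exp == 0)    ? std::ldexp((float)man, -16)                // subnormal
            : (exp == 0x1F) ? (man ? NAN : INFINITY)                     // NaN / Inf
                            : std::ldexp(1.0f + man / 4.0f, exp - 15);
    return (b & 0x80) ? -v : v;
}

int main() {
    for (int i = 0; i < 256; i++) {
        float a = e5m2_bits_to_float((uint8_t)i);
        float b = f16_bits_to_float((uint16_t)(i << 8));
        assert((std::isnan(a) && std::isnan(b)) || a == b);  // NaN never compares equal
    }
    std::printf("every E5M2 value maps to FP16 as (x << 8)\n");
    return 0;
}
```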
@@ -627,6 +669,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
         dst[i] = f8_e4m3_to_f16(src[i]);
     }
 }
+void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e5m2_to_f16(src[i]);
+    }
+}
 
 void convert_tensor(void* src,
                     ggml_type src_type,
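The new `_vec` helper follows the same in-place widening pattern as the existing `bf16_to_f32_vec` and `f8_e4m3_to_f16_vec`: the destination buffer is already sized for the 2-byte result, the 1-byte source values sit in its first half, and iterating from the last element down ensures a freshly written `uint16_t` never clobbers a source byte that has not been read yet. A small illustrative sketch of that pattern (hypothetical values, not from the patch), using the shift shortcut from the note above:

```cpp
// Illustration of the in-place widening loop used by the *_to_f16_vec helpers:
// the buffer already has room for n uint16_t, the n uint8_t inputs sit at its
// front, and walking backwards means dst[i] never overwrites an unread src[j].
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n = 4;
    std::vector<uint16_t> buf(n);                           // sized for the FP16 result
    uint8_t* src = reinterpret_cast<uint8_t*>(buf.data());
    const uint8_t e5m2_vals[] = {0x00, 0x3C, 0x40, 0xC2};   // 0.0, 1.0, 2.0, -3.0 in E5M2
    for (int64_t i = 0; i < n; i++) src[i] = e5m2_vals[i];  // pack sources at the front

    uint16_t* dst = buf.data();
    for (int64_t i = n - 1; i >= 0; i--) {                  // backwards: in-place safe
        dst[i] = (uint16_t)src[i] << 8;                     // E5M2 -> FP16 (see above)
    }
    for (int64_t i = 0; i < n; i++) std::printf("0x%04X\n", (unsigned)dst[i]);
    return 0;
}
```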
@@ -863,6 +911,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
+    } else if (dtype == "F8_E5M2") {
+        ttype = GGML_TYPE_F16;
     }
     return ttype;
 }
@@ -976,6 +1026,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             tensor_storage.is_f8_e4m3 = true;
             // f8 -> f16
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E5M2") {
+            tensor_storage.is_f8_e5m2 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else {
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }
@@ -1644,6 +1698,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e5m2) {
+                    // inplace op
+                    f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
                 read_buffer.resize(tensor_storage.nbytes());
@@ -1655,6 +1712,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e5m2) {
+                    // inplace op
+                    f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1670,6 +1730,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             } else if (tensor_storage.is_f8_e4m3) {
                 // inplace op
                 f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
+            } else if (tensor_storage.is_f8_e5m2) {
+                // inplace op
+                f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
             }
 
             if (tensor_storage.type == dst_tensor->type) {