
Adds support for loading F8 e5m2 weights #460


Merged: 1 commit merged on Nov 23, 2024
63 changes: 63 additions & 0 deletions model.cpp
@@ -614,6 +614,48 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
}


uint16_t f8_e5m2_to_f16(uint8_t fp8) {
    uint8_t sign     = (fp8 >> 7) & 0x1;
    uint8_t exponent = (fp8 >> 2) & 0x1F;
    uint8_t mantissa = fp8 & 0x3;

    uint16_t fp16_sign = sign << 15;
    uint16_t fp16_exponent;
    uint16_t fp16_mantissa;

    if (exponent == 0 && mantissa == 0) {  // zero
        return fp16_sign;
    }

    if (exponent == 0x1F) {  // NaN and Inf
        fp16_exponent = 0x1F;
        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
    }

    if (exponent == 0) {  // subnormal numbers
        fp16_exponent = 0;
        fp16_mantissa = (mantissa << 8);
        return fp16_sign | fp16_mantissa;
    }

    // normal numbers: e5m2 and fp16 share the same exponent bias (15),
    // so the biased exponent carries over unchanged
    int16_t true_exponent = (int16_t)exponent - 15 + 15;
    if (true_exponent <= 0) {
        fp16_exponent = 0;
        fp16_mantissa = (mantissa << 8);
    } else if (true_exponent >= 0x1F) {
        fp16_exponent = 0x1F;
        fp16_mantissa = 0;
    } else {
        fp16_exponent = (uint16_t)true_exponent;
        fp16_mantissa = mantissa << 8;
    }

    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}

Comment on lines +618 to +658
Contributor


I might be mistaken, but can't this whole thing be replaced with just return (uint16_t)fp8 << 8;, since fp8_e5m2 is basically a truncated fp16?

Contributor


(or rather return static_cast<uint16_t>(fp8) << 8;)
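
For reference, here is a minimal standalone sketch (not part of this PR; the helper name and sample values are illustrative) of the truncation shortcut discussed above. Because e5m2 uses the same sign bit, the same 5-bit exponent with bias 15, and the leading mantissa bits of IEEE fp16, an e5m2 byte is exactly the top 8 bits of the corresponding half-precision encoding, so widening and shifting reproduces the fp16 value with the low mantissa bits zeroed:

#include <cstdint>
#include <cstdio>

// Truncation shortcut: an e5m2 byte is the top 8 bits of an IEEE fp16 value,
// so widening and shifting left by 8 yields the matching fp16 bit pattern.
// (Helper name is illustrative only, not from the PR.)
static uint16_t f8_e5m2_to_f16_truncate(uint8_t fp8) {
    return static_cast<uint16_t>(fp8) << 8;
}

int main() {
    // Spot checks: 0x3C -> 0x3C00 (1.0), 0xC0 -> 0xC000 (-2.0),
    // 0x7C -> 0x7C00 (+inf), 0x7E -> 0x7E00 (NaN), 0x01 -> 0x0100 (subnormal).
    const uint8_t samples[] = {0x00, 0x3C, 0xC0, 0x7C, 0x7E, 0x01};
    for (uint8_t s : samples) {
        printf("0x%02X -> 0x%04X\n", s, f8_e5m2_to_f16_truncate(s));
    }
    return 0;
}

The printed fp16 bit patterns should match what the branch-by-branch f8_e5m2_to_f16 above produces for the same inputs, including the zero, NaN, infinity, and subnormal cases.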

void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
    // support inplace op
    for (int64_t i = n - 1; i >= 0; i--) {
@@ -627,6 +669,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
        dst[i] = f8_e4m3_to_f16(src[i]);
    }
}
void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
    // support inplace op
    for (int64_t i = n - 1; i >= 0; i--) {
        dst[i] = f8_e5m2_to_f16(src[i]);
    }
}

void convert_tensor(void* src,
                    ggml_type src_type,
@@ -863,6 +911,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
        ttype = GGML_TYPE_F32;
    } else if (dtype == "F8_E4M3") {
        ttype = GGML_TYPE_F16;
    } else if (dtype == "F8_E5M2") {
        ttype = GGML_TYPE_F16;
    }
    return ttype;
}
@@ -976,6 +1026,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
            tensor_storage.is_f8_e4m3 = true;
            // f8 -> f16
            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
        } else if (dtype == "F8_E5M2") {
            tensor_storage.is_f8_e5m2 = true;
            // f8 -> f16
            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
        } else {
            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
        }
@@ -1629,6 +1683,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
            } else if (tensor_storage.is_f8_e4m3) {
                // inplace op
                f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
            } else if (tensor_storage.is_f8_e5m2) {
                // inplace op
                f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
            }
        } else {
            read_buffer.resize(tensor_storage.nbytes());
@@ -1640,6 +1697,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
            } else if (tensor_storage.is_f8_e4m3) {
                // inplace op
                f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
            } else if (tensor_storage.is_f8_e5m2) {
                // inplace op
                f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
            }

            convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1655,6 +1715,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
            } else if (tensor_storage.is_f8_e4m3) {
                // inplace op
                f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
            } else if (tensor_storage.is_f8_e5m2) {
                // inplace op
                f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
            }

            if (tensor_storage.type == dst_tensor->type) {
5 changes: 4 additions & 1 deletion model.h
@@ -34,6 +34,7 @@ struct TensorStorage {
    ggml_type type = GGML_TYPE_F32;
    bool is_bf16 = false;
    bool is_f8_e4m3 = false;
    bool is_f8_e5m2 = false;
    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
    int n_dims = 0;

@@ -63,7 +64,7 @@ struct TensorStorage {
}

    int64_t nbytes_to_read() const {
        if (is_bf16 || is_f8_e4m3) {
        if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
            return nbytes() / 2;
        } else {
            return nbytes();
@@ -113,6 +114,8 @@ struct TensorStorage {
type_name = "bf16";
} else if (is_f8_e4m3) {
type_name = "f8_e4m3";
} else if (is_f8_e5m2) {
type_name = "f8_e5m2";
}
ss << name << " | " << type_name << " | ";
ss << n_dims << " [";