-
Notifications
You must be signed in to change notification settings - Fork 371
Adds support for loading F8 e5m2 weights #460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
uint16_t f8_e5m2_to_f16(uint8_t fp8) { | ||
uint8_t sign = (fp8 >> 7) & 0x1; | ||
uint8_t exponent = (fp8 >> 2) & 0x1F; | ||
uint8_t mantissa = fp8 & 0x3; | ||
|
||
uint16_t fp16_sign = sign << 15; | ||
uint16_t fp16_exponent; | ||
uint16_t fp16_mantissa; | ||
|
||
if (exponent == 0 && mantissa == 0) { //zero | ||
return fp16_sign; | ||
} | ||
|
||
if (exponent == 0x1F) { //NAN and INF | ||
fp16_exponent = 0x1F; | ||
fp16_mantissa = mantissa ? (mantissa << 8) : 0; | ||
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; | ||
} | ||
|
||
if (exponent == 0) { //subnormal numbers | ||
fp16_exponent = 0; | ||
fp16_mantissa = (mantissa << 8); | ||
return fp16_sign | fp16_mantissa; | ||
} | ||
|
||
//normal numbers | ||
int16_t true_exponent = (int16_t)exponent - 15 + 15; | ||
if (true_exponent <= 0) { | ||
fp16_exponent = 0; | ||
fp16_mantissa = (mantissa << 8); | ||
} else if (true_exponent >= 0x1F) { | ||
fp16_exponent = 0x1F; | ||
fp16_mantissa = 0; | ||
} else { | ||
fp16_exponent = (uint16_t)true_exponent; | ||
fp16_mantissa = mantissa << 8; | ||
} | ||
|
||
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might be mistaken, but can't this whole thing be replaced with just return (uint16_t)fp8<<8;
, since fp8_e5m2 is basically truncated fp16?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(or rather return static_cast<uint16_t>(fp8) << 8;
)
Adds support for loading F8 e5m2 weights, which is an alternative of f8 e4m3. Added in the similar manner as #359 by @Green-Sky converting to f16.
Tested and seems to work.