Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Adds support for loading F8 e5m2 weights #460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 23, 2024

Conversation

LostRuins
Copy link
Contributor

Adds support for loading F8 e5m2 weights, which is an alternative of f8 e4m3. Added in the similar manner as #359 by @Green-Sky converting to f16.

Tested and seems to work.

Comment on lines +618 to +658
uint16_t f8_e5m2_to_f16(uint8_t fp8) {
uint8_t sign = (fp8 >> 7) & 0x1;
uint8_t exponent = (fp8 >> 2) & 0x1F;
uint8_t mantissa = fp8 & 0x3;

uint16_t fp16_sign = sign << 15;
uint16_t fp16_exponent;
uint16_t fp16_mantissa;

if (exponent == 0 && mantissa == 0) { //zero
return fp16_sign;
}

if (exponent == 0x1F) { //NAN and INF
fp16_exponent = 0x1F;
fp16_mantissa = mantissa ? (mantissa << 8) : 0;
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}

if (exponent == 0) { //subnormal numbers
fp16_exponent = 0;
fp16_mantissa = (mantissa << 8);
return fp16_sign | fp16_mantissa;
}

//normal numbers
int16_t true_exponent = (int16_t)exponent - 15 + 15;
if (true_exponent <= 0) {
fp16_exponent = 0;
fp16_mantissa = (mantissa << 8);
} else if (true_exponent >= 0x1F) {
fp16_exponent = 0x1F;
fp16_mantissa = 0;
} else {
fp16_exponent = (uint16_t)true_exponent;
fp16_mantissa = mantissa << 8;
}

return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be mistaken, but can't this whole thing be replaced with just return (uint16_t)fp8<<8;, since fp8_e5m2 is basically truncated fp16?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(or rather return static_cast<uint16_t>(fp8) << 8;)

@leejet leejet merged commit 8f94efa into leejet:master Nov 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants