Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Change f32::midpoint to upcast to f64
This has been verified by kani as a correct optimization

see: #110840 (comment)

The new implementation is branchless, and only differs in which NaN
values are produced (if any are produced at all). Which is fine to change.
Aside from NaN handling, this implementation produces bitwise identical
results to the original implementation.

The new implementation is gated on targets that have a fast 64-bit
floating point implementation in hardware, and on WASM.
  • Loading branch information
RustyYato committed Jun 2, 2024
commit 849c5254afdeb96a740cfac989a382fde69b6067
55 changes: 36 additions & 19 deletions library/core/src/num/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1016,25 +1016,42 @@ impl f32 {
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
pub fn midpoint(self, other: f32) -> f32 {
const LO: f32 = f32::MIN_POSITIVE * 2.;
const HI: f32 = f32::MAX / 2.;

let (a, b) = (self, other);
let abs_a = a.abs_private();
let abs_b = b.abs_private();

if abs_a <= HI && abs_b <= HI {
// Overflow is impossible
(a + b) / 2.
} else if abs_a < LO {
// Not safe to halve a
a + (b / 2.)
} else if abs_b < LO {
// Not safe to halve b
(a / 2.) + b
} else {
// Not safe to halve a and b
(a / 2.) + (b / 2.)
cfg_if! {
if #[cfg(any(
target_arch = "x86_64",
target_arch = "aarch64",
all(any(target_arch="riscv32", target_arch= "riscv64"), target_feature="d"),
all(target_arch = "arm", target_feature="vfp2"),
target_arch = "wasm32",
target_arch = "wasm64",
))] {
// whitelist the faster implementation to targets that have known good 64-bit float
// implementations. Falling back to the branchy code on targets that don't have
// 64-bit hardware floats or buggy implementations.
// see: https://github.com/rust-lang/rust/pull/121062#issuecomment-2123408114
((f64::from(self) + f64::from(other)) / 2.0) as f32
} else {
const LO: f32 = f32::MIN_POSITIVE * 2.;
const HI: f32 = f32::MAX / 2.;

let (a, b) = (self, other);
let abs_a = a.abs_private();
let abs_b = b.abs_private();

if abs_a <= HI && abs_b <= HI {
// Overflow is impossible
(a + b) / 2.
} else if abs_a < LO {
// Not safe to halve a
a + (b / 2.)
} else if abs_b < LO {
// Not safe to halve b
(a / 2.) + b
} else {
// Not safe to halve a and b
(a / 2.) + (b / 2.)
}
}
}
}

Expand Down
29 changes: 26 additions & 3 deletions library/core/tests/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ assume_usize_width! {
}

macro_rules! test_float {
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => {
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr, $max_exp:expr) => {
mod $modname {
#[test]
fn min() {
Expand Down Expand Up @@ -870,6 +870,27 @@ macro_rules! test_float {
assert!(($nan as $fty).midpoint(1.0).is_nan());
assert!((1.0 as $fty).midpoint($nan).is_nan());
assert!(($nan as $fty).midpoint($nan).is_nan());

// test if large differences in magnitude are still correctly computed.
// NOTE: that because of how small x and y are, x + y can never overflow
// so (x + y) / 2.0 is always correct
// in particular, `2.pow(i)` will never be at the max exponent, so it could
// be safely doubled, while j is significantly smaller.
for i in $max_exp.saturating_sub(64)..$max_exp {
for j in 0..64u8 {
let large = <$fty>::from(2.0f32).powi(i);
// a much smaller number, such that there is no chance of overflow to test
// potential double rounding in midpoint's implementation.
let small = <$fty>::from(2.0f32).powi($max_exp - 1)
* <$fty>::EPSILON
* <$fty>::from(j);

let naive = (large + small) / 2.0;
let midpoint = large.midpoint(small);

assert_eq!(naive, midpoint);
}
}
}
#[test]
fn rem_euclid() {
Expand Down Expand Up @@ -902,7 +923,8 @@ test_float!(
f32::NAN,
f32::MIN,
f32::MAX,
f32::MIN_POSITIVE
f32::MIN_POSITIVE,
f32::MAX_EXP
);
test_float!(
f64,
Expand All @@ -912,5 +934,6 @@ test_float!(
f64::NAN,
f64::MIN,
f64::MAX,
f64::MIN_POSITIVE
f64::MIN_POSITIVE,
f64::MAX_EXP
);