Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 70660fe

Browse files
fix(ndarray): broadcast remainder operands (#5002)
1 parent 412aa66 commit 70660fe

2 files changed

Lines changed: 33 additions & 20 deletions

File tree

crates/burn-backend-tests/tests/tensor/int/ops/remainder.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ fn should_support_int_remainder_basic() {
1414
output.into_data().assert_eq(&expected, false);
1515
}
1616

17+
#[test]
18+
fn should_support_int_remainder_broadcast() {
19+
let device = Default::default();
20+
let lhs = TestTensorInt::<2>::from_data(TensorData::from([[10, 20, 30]]), &device);
21+
let rhs = TestTensorInt::<2>::from_data(TensorData::from([[7]]), &device);
22+
23+
let output = lhs.remainder(rhs);
24+
let expected = TensorData::from([[3, 6, 2]]);
25+
26+
output.into_data().assert_eq(&expected, false);
27+
}
28+
1729
#[test]
1830
fn should_support_int_remainder_basic_scalar() {
1931
let data = TensorData::from([-3, -2, -1, 1, 2, 3]);

crates/burn-ndarray/src/ops/base.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -673,9 +673,9 @@ macro_rules! dispatch_unary_simd {
673673
($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }};
674674
}
675675

676-
// Helper function to broadcast two tensors to a common shape for comparison operations
676+
// Helper function to broadcast two tensors to a common shape for binary operations
677677
// Returns broadcasted views that can be safely zipped
678-
fn broadcast_for_comparison<'a, E: Copy, S1, S2>(
678+
fn broadcast_for_binary_ops<'a, E: Copy, S1, S2>(
679679
lhs: &'a ndarray::ArrayBase<S1, ndarray::IxDyn>,
680680
rhs: &'a ndarray::ArrayBase<S2, ndarray::IxDyn>,
681681
) -> (
@@ -842,16 +842,17 @@ where
842842
}
843843

844844
pub fn remainder(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
845-
// Use into_owned() instead of clone() - only copies if shared, avoids copy if unique
846-
let mut out = lhs.into_owned();
847-
Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| {
848-
// out_elem holds lhs value; read it before overwriting with remainder
849-
let a_f = (*out_elem).to_f64();
850-
let b_f = b.to_f64();
851-
let r = a_f - b_f * (a_f / b_f).floor();
852-
*out_elem = r.elem();
853-
});
854-
out.into_shared()
845+
let (lhs, rhs) = broadcast_for_binary_ops(&lhs, &rhs);
846+
847+
Zip::from(&lhs)
848+
.and(&rhs)
849+
.map_collect(|&a, &b| {
850+
let a_f = a.to_f64();
851+
let b_f = b.to_f64();
852+
let r = a_f - b_f * (a_f / b_f).floor();
853+
r.elem()
854+
})
855+
.into_shared()
855856
}
856857

857858
pub fn remainder_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E>
@@ -983,7 +984,7 @@ where
983984
);
984985

985986
// Use the helper to broadcast both arrays to a common shape
986-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
987+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
987988
// Now we can safely zip and compare
988989
Zip::from(&lhs_broadcast)
989990
.and(&rhs_broadcast)
@@ -1183,7 +1184,7 @@ where
11831184
);
11841185

11851186
// Use the helper to broadcast both arrays to a common shape
1186-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1187+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
11871188
// Now we can safely zip and compare
11881189
Zip::from(&lhs_broadcast)
11891190
.and(&rhs_broadcast)
@@ -1231,7 +1232,7 @@ where
12311232
);
12321233

12331234
// Use the helper to broadcast both arrays to a common shape
1234-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1235+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
12351236
// Now we can safely zip and compare
12361237
Zip::from(&lhs_broadcast)
12371238
.and(&rhs_broadcast)
@@ -1266,7 +1267,7 @@ where
12661267
);
12671268

12681269
// Use the helper to broadcast both arrays to a common shape
1269-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1270+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
12701271
// Now we can safely zip and compare
12711272
Zip::from(&lhs_broadcast)
12721273
.and(&rhs_broadcast)
@@ -1301,7 +1302,7 @@ where
13011302
);
13021303

13031304
// Use the helper to broadcast both arrays to a common shape
1304-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1305+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
13051306

13061307
// Now we can safely zip and compare
13071308
Zip::from(&lhs_broadcast)
@@ -1446,7 +1447,7 @@ impl NdArrayBoolOps {
14461447
};
14471448

14481449
// Use the helper to broadcast both arrays to a common shape
1449-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1450+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
14501451
// Now we can safely zip and compare
14511452
Zip::from(&lhs_broadcast)
14521453
.and(&rhs_broadcast)
@@ -1472,7 +1473,7 @@ impl NdArrayBoolOps {
14721473
};
14731474

14741475
// Use the helper to broadcast both arrays to a common shape
1475-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1476+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
14761477
// Now we can safely zip and compare
14771478
Zip::from(&lhs_broadcast)
14781479
.and(&rhs_broadcast)
@@ -1488,7 +1489,7 @@ impl NdArrayBoolOps {
14881489
};
14891490

14901491
// Use the helper to broadcast both arrays to a common shape
1491-
let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1492+
let (lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops(&lhs, &rhs);
14921493
// Now we can safely zip and compare
14931494
Zip::from(&lhs_broadcast)
14941495
.and(&rhs_broadcast)

0 commit comments

Comments
 (0)