@@ -673,9 +673,9 @@ macro_rules! dispatch_unary_simd {
673673 ( $elem: ty, $op: ty, $lhs: expr, $( $ty: ty) ,* ) => { { $lhs } } ;
674674}
675675
676- // Helper function to broadcast two tensors to a common shape for comparison operations
676+ // Helper function to broadcast two tensors to a common shape for binary operations
677677// Returns broadcasted views that can be safely zipped
678- fn broadcast_for_comparison < ' a , E : Copy , S1 , S2 > (
678+ fn broadcast_for_binary_ops < ' a , E : Copy , S1 , S2 > (
679679 lhs : & ' a ndarray:: ArrayBase < S1 , ndarray:: IxDyn > ,
680680 rhs : & ' a ndarray:: ArrayBase < S2 , ndarray:: IxDyn > ,
681681) -> (
@@ -842,16 +842,17 @@ where
842842 }
843843
844844 pub fn remainder ( lhs : SharedArray < E > , rhs : SharedArray < E > ) -> SharedArray < E > {
845- // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique
846- let mut out = lhs. into_owned ( ) ;
847- Zip :: from ( & mut out) . and ( & rhs) . for_each ( |out_elem, & b| {
848- // out_elem holds lhs value; read it before overwriting with remainder
849- let a_f = ( * out_elem) . to_f64 ( ) ;
850- let b_f = b. to_f64 ( ) ;
851- let r = a_f - b_f * ( a_f / b_f) . floor ( ) ;
852- * out_elem = r. elem ( ) ;
853- } ) ;
854- out. into_shared ( )
845+ let ( lhs, rhs) = broadcast_for_binary_ops ( & lhs, & rhs) ;
846+
847+ Zip :: from ( & lhs)
848+ . and ( & rhs)
849+ . map_collect ( |& a, & b| {
850+ let a_f = a. to_f64 ( ) ;
851+ let b_f = b. to_f64 ( ) ;
852+ let r = a_f - b_f * ( a_f / b_f) . floor ( ) ;
853+ r. elem ( )
854+ } )
855+ . into_shared ( )
855856 }
856857
857858 pub fn remainder_scalar ( lhs : SharedArray < E > , rhs : E ) -> SharedArray < E >
@@ -983,7 +984,7 @@ where
983984 ) ;
984985
985986 // Use the helper to broadcast both arrays to a common shape
986- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
987+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
987988 // Now we can safely zip and compare
988989 Zip :: from ( & lhs_broadcast)
989990 . and ( & rhs_broadcast)
@@ -1183,7 +1184,7 @@ where
11831184 ) ;
11841185
11851186 // Use the helper to broadcast both arrays to a common shape
1186- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1187+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
11871188 // Now we can safely zip and compare
11881189 Zip :: from ( & lhs_broadcast)
11891190 . and ( & rhs_broadcast)
@@ -1231,7 +1232,7 @@ where
12311232 ) ;
12321233
12331234 // Use the helper to broadcast both arrays to a common shape
1234- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1235+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
12351236 // Now we can safely zip and compare
12361237 Zip :: from ( & lhs_broadcast)
12371238 . and ( & rhs_broadcast)
@@ -1266,7 +1267,7 @@ where
12661267 ) ;
12671268
12681269 // Use the helper to broadcast both arrays to a common shape
1269- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1270+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
12701271 // Now we can safely zip and compare
12711272 Zip :: from ( & lhs_broadcast)
12721273 . and ( & rhs_broadcast)
@@ -1301,7 +1302,7 @@ where
13011302 ) ;
13021303
13031304 // Use the helper to broadcast both arrays to a common shape
1304- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1305+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
13051306
13061307 // Now we can safely zip and compare
13071308 Zip :: from ( & lhs_broadcast)
@@ -1446,7 +1447,7 @@ impl NdArrayBoolOps {
14461447 } ;
14471448
14481449 // Use the helper to broadcast both arrays to a common shape
1449- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1450+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
14501451 // Now we can safely zip and compare
14511452 Zip :: from ( & lhs_broadcast)
14521453 . and ( & rhs_broadcast)
@@ -1472,7 +1473,7 @@ impl NdArrayBoolOps {
14721473 } ;
14731474
14741475 // Use the helper to broadcast both arrays to a common shape
1475- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1476+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
14761477 // Now we can safely zip and compare
14771478 Zip :: from ( & lhs_broadcast)
14781479 . and ( & rhs_broadcast)
@@ -1488,7 +1489,7 @@ impl NdArrayBoolOps {
14881489 } ;
14891490
14901491 // Use the helper to broadcast both arrays to a common shape
1491- let ( lhs_broadcast, rhs_broadcast) = broadcast_for_comparison ( & lhs, & rhs) ;
1492+ let ( lhs_broadcast, rhs_broadcast) = broadcast_for_binary_ops ( & lhs, & rhs) ;
14921493 // Now we can safely zip and compare
14931494 Zip :: from ( & lhs_broadcast)
14941495 . and ( & rhs_broadcast)
0 commit comments