@@ -180,16 +180,15 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
 // - input splits (IN)
 // - output splits (OUT) and
 // - source offsets (OUT).
-__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t team) {
+__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto input_splits = in_out_splits;
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int tid = threadIdx.x;

   CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
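
For orientation, a minimal sketch of what the exchange step could look like under the new two-buffer layout. This is illustrative only (the hunk elides the kernel body): it assumes `prefixSum` is a block-cooperative exclusive scan into its first argument, and the real kernel would additionally end with a team-level barrier/quiet so peers observe the writes.

    // Exclusive prefix sum of our input splits = where each peer's chunk
    // starts in our send buffer.
    __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
    prefixSum(peer_offsets, input_splits, npes);
    __syncthreads();
    if (tid < npes) {
      int peer_global = nvshmem_team_translate_pe(team, tid, NVSHMEM_TEAM_WORLD);
      // Peer `tid` will receive input_splits[tid] rows from us ...
      nvshmem_int64_p(output_splits + mype, input_splits[tid], peer_global);
      // ... starting at this offset in our (symmetric) send buffer.
      nvshmem_int64_p(source_offsets + mype, peer_offsets[tid], peer_global);
    }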
@@ -214,15 +213,15 @@ __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t te
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
-__global__ void allToAllV(void* send_data, void* recv_data, int64_t* in_out_splits, size_t stride, nvshmem_team_t team) {
+__global__ void allToAllV(void* send_data, void* recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
   int blocks_per_peer = max(gridDim.x / npes, 1);
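
Since this kernel now receives only the output splits and source offsets, the transfer is naturally expressed as a pull: each rank fetches its chunk from every peer's send buffer at that peer's source offset. A compressed, assumption-laden sketch of one block's share follows; the slicing of a peer's chunk across its `blocks_per_peer` blocks is elided, and `out_offsets` stands in for an exclusive prefix sum of `output_splits` that this hunk does not show.

    int peer = bid / blocks_per_peer;  // consecutive block groups serve one peer each
    int peer_global = nvshmem_team_translate_pe(team, peer, NVSHMEM_TEAM_WORLD);
    char* dst = (char*)recv_data + out_offsets[peer] * stride;     // local write position
    char* src = (char*)send_data + source_offsets[peer] * stride;  // symmetric remote address
    // The whole block cooperates on one non-blocking get from `peer`.
    nvshmemx_getmem_nbi_block(dst, src, output_splits[peer] * stride, peer_global);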
@@ -277,29 +276,31 @@ static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
   return std::min(num_blocks, max_blocks);
 }

-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name) {
   /* Perform AllToAllv operation using NVSHMEM, with split information provided on device.
    * Arguments:
    * - `input` is the input tensor
    * - `out` is the output tensor
-   * - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order):
-       input splits (IN)
-       output splits (OUT) and
-       output offsets (OUT).
+   * - `in_splits` is a 1D tensor of size (npes), containing the input splits
+   * - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order):
+       output splits and output offsets.
   */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
-  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
+  auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();

   void* input_ptr = input.data_ptr();
   void* output_ptr = out.mutable_data_ptr();
-  int64_t* splits_ptr = (int64_t*)(in_out_splits.mutable_data_ptr());
+  int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr());
+  int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr());

   TORCH_CHECK_EQ(input.device(), out.device());
   auto device = input.device();
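
With the new signature, a call site allocates the split metadata as two symmetric tensors and reads the results back out of `out_splits_offsets` after the op. A hypothetical sketch, where `symm_empty` stands in for whatever symmetric-memory allocator the surrounding code uses, and `npes`, `device`, and `send_counts` are assumed in scope:

    at::Tensor in_splits = symm_empty({npes}, at::kLong, device);             // IN
    at::Tensor out_splits_offsets = symm_empty({2, npes}, at::kLong, device); // OUT
    in_splits.copy_(send_counts);  // rows destined for each peer
    all_to_all_vdev(input, out, in_splits, out_splits_offsets, group_name);
    at::Tensor out_splits = out_splits_offsets[0];   // rows received from each peer
    at::Tensor out_offsets = out_splits_offsets[1];  // where each peer's rows start in `out`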
@@ -311,7 +312,8 @@ at::Tensor all_to_all_vdev(
   // Exchange output splits and source offsets
   // Use collective launch because kernel involves nvshmem barrier
   void* args0[] = {
-      &splits_ptr,
+      &in_splits_ptr,
+      &out_splits_offsets_ptr,
       &team};
   nvshmemx_collective_launch(
       (const void*)exchangeSplitAndOffset,
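
For reference, the rest of this call (truncated by the hunk) follows the standard `nvshmemx_collective_launch` shape: kernel pointer, grid and block dims, the packed argument array, dynamic shared memory size, and the stream. The dims below are assumptions, not values taken from this diff:

    nvshmemx_collective_launch(
        (const void*)exchangeSplitAndOffset,
        dim3(1),                   // assumed: one block, one thread per peer
        dim3(THREADS_PER_BLOCK),
        args0,
        0,                         // no dynamic shared memory
        stream);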
@@ -335,7 +337,7 @@ at::Tensor all_to_all_vdev(
   void* args1[] = {
       &input_ptr,
       &output_ptr,
-      &splits_ptr,
+      &out_splits_offsets_ptr,
       &stride_bytes,
       &team};
   nvshmemx_collective_launch(
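
Note that CUDA kernel argument arrays are positional: `args1` must mirror `allToAllV`'s new parameter list exactly, so (&input_ptr, &output_ptr, &out_splits_offsets_ptr, &stride_bytes, &team) lines up with (send_data, recv_data, out_splits_offsets, stride, team).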
@@ -345,7 +347,6 @@ at::Tensor all_to_all_vdev(
       args1,
       0,
       stream);
-  return out;
 }

 // Start of `all_to_all_vdev_2d`
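
Since `all_to_all_vdev` now returns `void`, callers can no longer chain on a returned tensor; the results are delivered in place, in `out` and in the two rows of `out_splits_offsets`.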