Commit 31c72b8

pytorchbot and kwen2501 authored
[a2av] Separate in/out splits into two tensors (#164028)
[a2av] Separate in/out splits into two tensors (#163837)

Old signature:
`all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name)`

New signature:
`all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name)`

i.e. split `in_out_splits` into an IN tensor and an OUT tensor so that we can define the TORCH_LIBRARY signature better. This also brings it in line with the 2D version.

Pull Request resolved: #163837
Approved by: https://github.com/fduwjj
ghstack dependencies: #163886
(cherry picked from commit bbf8aa4)
Co-authored-by: Ke Wen <[email protected]>
1 parent 1cd83de commit 31c72b8
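
For orientation, here is a rough sketch of how a caller drives the new signature, following the updated test below. The variable names (`inp`, `out`, `inp_splits`, `world_size`, `device`, `group_name`) and the `symm_mem` import path are assumptions taken from the test context, not part of this commit:

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Input splits: one int64 entry per peer (how many elements this rank sends to each).
in_splits = symm_mem.empty(world_size, dtype=torch.int64, device=device)
in_splits.copy_(inp_splits)

# Output tensor: row 0 receives the output splits, row 1 the output offsets.
out_splits_offsets = symm_mem.empty((2, world_size), dtype=torch.int64, device=device)

# Sync all ranks so remote symmetric buffers are allocated before the exchange.
dist.barrier()

# New signature: splits are passed as two tensors and nothing is returned.
torch.ops.symm_mem.all_to_all_vdev(inp, out, in_splits, out_splits_offsets, group_name)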

4 files changed: +36 lines, -29 lines

test/distributed/test_nvshmem.py

Lines changed: 13 additions & 8 deletions
@@ -299,28 +299,33 @@ def test_all_to_all_vdev(self) -> None:
             torch.randn(max_inp_numel, dtype=dtype, device=self.device)
         )
         out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
-        in_out_splits = symm_mem.empty(
-            (3, self.world_size), dtype=torch.int64, device=self.device
+        in_splits = symm_mem.empty(
+            self.world_size, dtype=torch.int64, device=self.device
+        )
+        out_splits_offsets = symm_mem.empty(
+            (2, self.world_size), dtype=torch.int64, device=self.device
         )
         # Row 0 is input splits
-        in_out_splits[0].copy_(inp_splits)
+        in_splits.copy_(inp_splits)

         # Sync all ranks to ensure remote tensors are allocated
         dist.barrier()

-        torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)
+        torch.ops.symm_mem.all_to_all_vdev(
+            inp, out, in_splits, out_splits_offsets, group_name
+        )

         # Check input splits (row 0) -- should not change
-        torch.testing.assert_close(in_out_splits[0], inp_splits)
+        torch.testing.assert_close(in_splits, inp_splits)

         # Check output splits (row 1)
-        torch.testing.assert_close(in_out_splits[1], out_splits)
+        torch.testing.assert_close(out_splits_offsets[0], out_splits)

         # Check output offsets (row 2)
         out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
         # output offsets from `all_to_all_vdev` is exclusive scan
-        self.assertEqual(in_out_splits[2][0], 0)
-        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+        self.assertEqual(out_splits_offsets[1][0], 0)
+        torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1])

         # Check data
         expected = torch.empty(out_numel, dtype=dtype, device=self.device)
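
The offset assertions above hinge on row 1 of `out_splits_offsets` being an exclusive prefix sum of row 0, whereas `torch.cumsum` is inclusive. A minimal standalone illustration of that relationship (made-up values, no distributed setup required):

import torch

out_splits = torch.tensor([4, 0, 3, 2])      # row 0: elements received from each peer
out_offsets = torch.zeros_like(out_splits)   # row 1: start position of each peer's chunk
out_offsets[1:] = torch.cumsum(out_splits, dim=0)[:-1]  # shift the inclusive scan by one
print(out_offsets)  # tensor([0, 4, 4, 7])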

torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp

Lines changed: 1 addition & 1 deletion
@@ -502,7 +502,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
   m.def(
-      "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+      "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name) -> ()");
   m.def(
       "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
   m.def(

torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu

Lines changed: 19 additions & 18 deletions
@@ -180,16 +180,15 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
 // - input splits (IN)
 // - output splits (OUT) and
 // - source offsets (OUT).
-__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t team) {
+__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto input_splits = in_out_splits;
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int tid = threadIdx.x;

   CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
@@ -214,15 +213,15 @@ __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t te
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
-__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, nvshmem_team_t team) {
+__global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
   int blocks_per_peer = max(gridDim.x / npes, 1);
@@ -277,29 +276,31 @@ static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
   return std::min(num_blocks, max_blocks);
 }

-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name) {
   /* Perform AllToAllv operation using NVSHMEM, with split information provided on device.
    * Arguments:
    *  - `input` is the input tensor
    *  - `out` is the output tensor
-   *  - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order):
-       input splits (IN)
-       output splits (OUT) and
-       output offsets (OUT).
+   *  - `in_splits` is a 1D tensor of size (npes), containing the input splits
+   *  - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order):
+       output splits and output offsets.
    */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
-  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
+  auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();

   void* input_ptr = input.data_ptr();
   void* output_ptr = out.mutable_data_ptr();
-  int64_t* splits_ptr = (int64_t*)(in_out_splits.mutable_data_ptr());
+  int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr());
+  int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr());

   TORCH_CHECK_EQ(input.device(), out.device());
   auto device = input.device();
@@ -311,7 +312,8 @@ at::Tensor all_to_all_vdev(
   // Exchange output splits and source offsets
   // Use collective launch because kernel involves nvshmem barrier
   void* args0[] = {
-      &splits_ptr,
+      &in_splits_ptr,
+      &out_splits_offsets_ptr,
       &team};
   nvshmemx_collective_launch(
       (const void*)exchangeSplitAndOffset,
@@ -335,7 +337,7 @@ at::Tensor all_to_all_vdev(
   void* args1[] = {
       &input_ptr,
       &output_ptr,
-      &splits_ptr,
+      &out_splits_offsets_ptr,
       &stride_bytes,
       &team};
   nvshmemx_collective_launch(
@@ -345,7 +347,6 @@ at::Tensor all_to_all_vdev(
       args1,
       0,
       stream);
-  return out;
 }

 // Start of `all_to_all_vdev_2d`

torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh

Lines changed: 3 additions & 2 deletions
@@ -32,10 +32,11 @@ at::Tensor nvshmem_all_to_all(
     at::Tensor& out,
     std::string group_name);

-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name);

 void all_to_all_vdev_2d(
