Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bf903a2

Browse files
Handle temporary arg name divergence in the torch.cuda.MemPool (#1920)
Signed-off-by: Youngeun Kwon <[email protected]>
1 parent bfb500c commit bf903a2

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

apex/contrib/nccl_allocator/nccl_allocator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,15 @@ def create_nccl_mem_pool(symmetric: bool | None = None) -> torch.cuda.MemPool:
1717
if symmetric is None:
1818
_pool = torch.cuda.MemPool(_allocator)
1919
else:
20-
assert 'symmetric' in get_func_args(torch.cuda.MemPool), \
21-
"symmetric setting with torch.cuda.MemPool requires higher PyTorch version"
22-
_pool = torch.cuda.MemPool(_allocator, symmetric=symmetric)
20+
if 'symmetric' in get_func_args(torch.cuda.MemPool):
21+
_pool = torch.cuda.MemPool(_allocator, symmetric=symmetric)
22+
elif 'symm_mem' in get_func_args(torch.cuda.MemPool):
23+
# This path handles argument name divergence between
24+
# nvidia pytorch and the official pytorch.
25+
_pool = torch.cuda.MemPool(_allocator, symm_mem=symmetric)
26+
else:
27+
raise ValueError("symmetric setting with torch.cuda.MemPool requires "
28+
"higher PyTorch version")
2329
return _pool
2430

2531

0 commit comments

Comments
 (0)