Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2eda0ac

Browse files
[contrib/nccl_allocator] Add an interface for the NCCL symmetric registration (#1915)
* add an interface for symmetric reg Signed-off-by: Youngeun Kwon <[email protected]> * update Signed-off-by: Youngeun Kwon <[email protected]> * Update apex/contrib/nccl_allocator/nccl_allocator.py Co-authored-by: Masaki Kozuki <[email protected]> --------- Signed-off-by: Youngeun Kwon <[email protected]> Co-authored-by: Masaki Kozuki <[email protected]>
1 parent d8200a1 commit 2eda0ac

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

apex/contrib/nccl_allocator/nccl_allocator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,19 @@
77

88
__all__ = ["init", "nccl_mem", "create_nccl_mem_pool"]
99

10+
def get_func_args(func):
11+
import inspect
12+
sig = inspect.signature(func)
13+
return [arg.name for arg in sig.parameters.values()]
1014

11-
def create_nccl_mem_pool():
15+
def create_nccl_mem_pool(symmetric: bool | None = None) -> torch.cuda.MemPool:
1216
_allocator = _apex_nccl_allocator.get_nccl_allocator()
13-
_pool = torch.cuda.MemPool(_allocator)
17+
if symmetric is None:
18+
_pool = torch.cuda.MemPool(_allocator)
19+
else:
20+
assert 'symmetric' in get_func_args(torch.cuda.MemPool), \
21+
"symmetric setting with torch.cuda.MemPool requires higher PyTorch version"
22+
_pool = torch.cuda.MemPool(_allocator, symmetric=symmetric)
1423
return _pool
1524

1625

0 commit comments

Comments
 (0)