File tree Expand file tree Collapse file tree
apex/contrib/nccl_allocator Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -17,9 +17,15 @@ def create_nccl_mem_pool(symmetric: bool | None = None) -> torch.cuda.MemPool:
1717 if symmetric is None :
1818 _pool = torch .cuda .MemPool (_allocator )
1919 else :
20- assert 'symmetric' in get_func_args (torch .cuda .MemPool ), \
21- "symmetric setting with torch.cuda.MemPool requires higher PyTorch version"
22- _pool = torch .cuda .MemPool (_allocator , symmetric = symmetric )
20+ if 'symmetric' in get_func_args (torch .cuda .MemPool ):
21+ _pool = torch .cuda .MemPool (_allocator , symmetric = symmetric )
22+ elif 'symm_mem' in get_func_args (torch .cuda .MemPool ):
23+ # This path handles argument name divergence between
24+ # nvidia pytorch and the official pytorch.
25+ _pool = torch .cuda .MemPool (_allocator , symm_mem = symmetric )
26+ else :
27+ raise ValueError ("symmetric setting with torch.cuda.MemPool requires "
28+ "higher PyTorch version" )
2329 return _pool
2430
2531
You can’t perform that action at this time.
0 commit comments