55from contextlib import nullcontext
66
77
8- __all__ = ["init" , "nccl_mem" ]
8+ __all__ = ["init" , "nccl_mem" , "create_nccl_mem_pool" ]
99
1010
11- _allocator = _apex_nccl_allocator .get_nccl_allocator ()
12- _pool = torch .cuda .MemPool (_allocator )
def create_nccl_mem_pool():
    """Build and return a new ``torch.cuda.MemPool``.

    The pool wraps the allocator object handed back by
    ``_apex_nccl_allocator.get_nccl_allocator()``. Nothing is cached here:
    every call fetches the allocator again and constructs another pool, so
    the caller owns the returned object and decides how long it lives.

    Returns:
        The newly constructed ``torch.cuda.MemPool``.
    """
    return torch.cuda.MemPool(_apex_nccl_allocator.get_nccl_allocator())
1315
1416
1517def init () -> None :
@@ -18,10 +20,11 @@ def init() -> None:
1820
1921
2022class nccl_mem :
21- def __init__ (self , enabled = True , device = None , group = None ):
23+ def __init__ (self , pool , enabled = True , device = None , group = None ):
2224 self .device = None
2325 self .group = None
2426 self .mem_context = None
27+ self .pool = pool
2528
2629 if enabled :
2730 if device is None :
@@ -37,7 +40,7 @@ def __init__(self, enabled = True, device = None, group = None):
3740 else :
3841 self .group = group
3942
40- self .mem_context = torch .cuda .use_mem_pool (_pool )
43+ self .mem_context = torch .cuda .use_mem_pool (self . pool )
4144 else :
4245 self .mem_context = nullcontext ()
4346
@@ -46,15 +49,15 @@ def __enter__(self):
4649 if self .group is not None :
4750 backend = self .group ._get_backend (self .device )
4851 try :
49- backend .deregister_mem_pool (_pool )
52+ backend .deregister_mem_pool (self . pool )
5053 except RuntimeError :
5154 pass
5255
5356 def __exit__ (self , * args ):
5457 if self .group is not None :
5558 backend = self .group ._get_backend (self .device )
5659 try :
57- backend .register_mem_pool (_pool )
60+ backend .register_mem_pool (self . pool )
5861 except RuntimeError :
5962 pass
6063 self .mem_context .__exit__ (* args )
0 commit comments