Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 312acb4

Browse files
authored
[nccl_allocator] Adds helper API to create pool (#1877)
1 parent 0024663 commit 312acb4

6 files changed

Lines changed: 22 additions & 15 deletions

File tree

apex/contrib/examples/nccl_allocator/allreduce.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@
1313

1414
torch.cuda.set_device(local_rank)
1515
dist.init_process_group(backend="nccl")
16-
17-
with nccl_allocator.nccl_mem():
16+
pool = nccl_allocator.create_nccl_mem_pool()
17+
with nccl_allocator.nccl_mem(pool):
1818
a = torch.ones(1024 * 1024 * 2, device="cuda")
1919
dist.all_reduce(a)
2020

apex/contrib/examples/nccl_allocator/cache.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def print_used_mem(string, nvsmi, device_id = 0):
2121

2222
print_used_mem("", nvsmi)
2323

24-
with nccl_allocator.nccl_mem():
24+
pool = nccl_allocator.create_nccl_mem_pool()
25+
with nccl_allocator.nccl_mem(pool):
2526
for i in range(nrep):
2627
out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB
2728
nccl_mem.append(out)
@@ -42,7 +43,7 @@ def print_used_mem(string, nvsmi, device_id = 0):
4243

4344
del nccl_mem
4445
nccl_mem2 = []
45-
with nccl_allocator.nccl_mem():
46+
with nccl_allocator.nccl_mem(pool):
4647
for i in range(nrep):
4748
out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB
4849
nccl_mem2.append(out)

apex/contrib/examples/nccl_allocator/change_cuda_allocator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
nccl_allocator.init()
55
nrep = 6
6-
with nccl_allocator.nccl_mem():
6+
pool = nccl_allocator.create_nccl_mem_pool()
7+
with nccl_allocator.nccl_mem(pool):
78
for i in range(nrep):
89
out = torch.randn(1024).cuda()
910

apex/contrib/examples/nccl_allocator/toy_ddp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def forward(self, x):
3636
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
3737

3838
data_ptrs = []
39-
with nccl_allocator.nccl_mem():
39+
pool = nccl_allocator.create_nccl_mem_pool()
40+
with nccl_allocator.nccl_mem(pool):
4041
for param in ddp_model.parameters():
4142
param.grad = torch.empty_like(param)
4243
data_ptrs.append(param.grad.data_ptr())

apex/contrib/nccl_allocator/nccl_allocator.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from contextlib import nullcontext
66

77

8-
__all__ = ["init", "nccl_mem"]
8+
__all__ = ["init", "nccl_mem", "create_nccl_mem_pool"]
99

1010

11-
_allocator = _apex_nccl_allocator.get_nccl_allocator()
12-
_pool = torch.cuda.MemPool(_allocator)
11+
def create_nccl_mem_pool():
12+
_allocator = _apex_nccl_allocator.get_nccl_allocator()
13+
_pool = torch.cuda.MemPool(_allocator)
14+
return _pool
1315

1416

1517
def init() -> None:
@@ -18,10 +20,11 @@ def init() -> None:
1820

1921

2022
class nccl_mem:
21-
def __init__(self, enabled = True, device = None, group = None):
23+
def __init__(self, pool, enabled = True, device = None, group = None):
2224
self.device = None
2325
self.group = None
2426
self.mem_context = None
27+
self.pool = pool
2528

2629
if enabled:
2730
if device is None:
@@ -37,7 +40,7 @@ def __init__(self, enabled = True, device = None, group = None):
3740
else:
3841
self.group = group
3942

40-
self.mem_context = torch.cuda.use_mem_pool(_pool)
43+
self.mem_context = torch.cuda.use_mem_pool(self.pool)
4144
else:
4245
self.mem_context = nullcontext()
4346

@@ -46,15 +49,15 @@ def __enter__(self):
4649
if self.group is not None:
4750
backend = self.group._get_backend(self.device)
4851
try:
49-
backend.deregister_mem_pool(_pool)
52+
backend.deregister_mem_pool(self.pool)
5053
except RuntimeError:
5154
pass
5255

5356
def __exit__(self, *args):
5457
if self.group is not None:
5558
backend = self.group._get_backend(self.device)
5659
try:
57-
backend.register_mem_pool(_pool)
60+
backend.register_mem_pool(self.pool)
5861
except RuntimeError:
5962
pass
6063
self.mem_context.__exit__(*args)

apex/contrib/optimizers/distributed_fused_adam.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,12 +1191,13 @@ def _init_grad_buffer(self) -> None:
11911191
[buffer_size], dtype=grad_sync_dtype, device=self.device,
11921192
)
11931193
else:
1194-
with nccl_allocator.nccl_mem():
1194+
pool = nccl_allocator.create_nccl_mem_pool()
1195+
with nccl_allocator.nccl_mem(pool):
11951196
self._grad_buffers[dtypes] = torch.zeros(
11961197
[buffer_size], dtype=grad_sync_dtype, device=self.device,
11971198
)
11981199
shard_buffer_size = buffer_size // self.distributed_size
1199-
with nccl_allocator.nccl_mem():
1200+
with nccl_allocator.nccl_mem(pool):
12001201
self._shard_grad_buffers[dtypes] = torch.zeros(
12011202
[shard_buffer_size], dtype=grad_sync_dtype, device=self.device,
12021203
)

0 commit comments

Comments
 (0)