Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d8200a1

Browse files
authored
replace apex.parallel.flat_dist_call in test with coalescing manager (#1913)
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent a255aa6 commit d8200a1

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

apex/contrib/test/optimizers/test_distributed_fused_lamb.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1-
import os
21
import inspect
2+
33
import torch
44
from torch.cuda.amp import GradScaler
55
from torch.testing._internal import common_utils
6-
from apex.parallel.distributed import flat_dist_call
6+
from torch.distributed.distributed_c10d import _coalescing_manager
7+
78
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB
89
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
910

11+
12+
def flat_dist_call(param_list: list[torch.Tensor], op, args):
13+
with _coalescing_manager(async_ops=True) as cm:
14+
for p in param_list:
15+
op(p, *args)
16+
17+
cm.wait()
18+
19+
1020
def get_init_weights_func():
1121
@torch.no_grad()
1222
def init_weights(m):

0 commit comments

Comments
 (0)