dist.barrier() hangs after calling async_save #123447

Open
LucasLLC opened this issue Apr 5, 2024 · 9 comments
Labels: has workaround, oncall: distributed, release notes: distributed (checkpoint), triaged

Comments

@LucasLLC
Contributor

LucasLLC commented Apr 5, 2024

🐛 Describe the bug

Reproduced below, dist.barrier() fails after calls to torch.distributed.checkpoint.async_save.

Interestingly enough, this does not happen if we first call all_reduce (commented out in the example).

Potentially related to: #95895 (comment)

import os
import torch
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F  # used by TestDummyModel.forward
import shutil

CHECKPOINT_DIR = "test_checkpoint"

class TestDummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(8, 16)
        self.net2 = nn.Linear(16, 32)
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Linear(64, 8)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    device = "cuda"
    #device = "cpu"

    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)

    model = TestDummyModel().to(device)
    state_dict = model.state_dict()

    f = dcp.async_save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    # For some reason, running the following all_reduce first makes the barrier below not hang:
    # dist.all_reduce(torch.zeros((1,), device=device))
    dist.barrier()
    if isinstance(f.result(), Exception):
        print(f.result())
    else:
        print("finished saving")


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Versions

PyTorch version: 2.4.0a0+gite0c9764
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: 17.0.6 (CentOS 17.0.6-5.el9)
CMake version: version 3.26.5
Libc version: glibc-2.34

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.12.0-0_fbk16_zion_7661_geb00762ce6d2-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA PG509-210
GPU 1: NVIDIA PG509-210
GPU 2: NVIDIA PG509-210
GPU 3: NVIDIA PG509-210
GPU 4: NVIDIA PG509-210
GPU 5: NVIDIA PG509-210
GPU 6: NVIDIA PG509-210
GPU 7: NVIDIA PG509-210

Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.9.4
/usr/lib64/libcudnn_adv_infer.so.8.9.4
/usr/lib64/libcudnn_adv_train.so.8.9.4
/usr/lib64/libcudnn_cnn_infer.so.8.9.4
/usr/lib64/libcudnn_cnn_train.so.8.9.4
/usr/lib64/libcudnn_ops_infer.so.8.9.4
/usr/lib64/libcudnn_ops_train.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn.so.8.9.2
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.2
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.2
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.2
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.2
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.2
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 4
Stepping: 11
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 1801.0000
CPU min MHz: 800.0000
BogoMIPS: 3600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 96 MiB (96 instances)
L3 cache: 132 MiB (4 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-23,96-119
NUMA node1 CPU(s): 24-47,120-143
NUMA node2 CPU(s): 48-71,144-167
NUMA node3 CPU(s): 72-95,168-191
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==0.990
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] optree==0.10.0
[pip3] pytorch-triton==3.0.0+901819d2b6
[pip3] torch==2.4.0a0+gite0c9764
[pip3] torchaudio==2.2.0.dev20240129+cu121
[pip3] torchpippy==0.2.0+27d9145
[pip3] torchsnapshot==0.1.0
[pip3] torchtnt==0.2.3
[pip3] torchtune==0.0.1
[pip3] torchvision==0.17.0
[pip3] torchx==0.6.0
[pip3] triton==2.2.0
[conda] numpy 1.26.0 pypi_0 pypi
[conda] optree 0.10.0 pypi_0 pypi
[conda] pytorch-triton 3.0.0+901819d2b6 pypi_0 pypi
[conda] torch 2.4.0a0+gite0c9764 dev_0
[conda] torchaudio 2.2.0.dev20240129+cu121 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchpippy 0.2.0+27d9145 dev_0
[conda] torchsnapshot 0.1.0 dev_0
[conda] torchtnt 0.2.3 dev_0
[conda] torchtune 0.0.1 pypi_0 pypi
[conda] torchvision 0.18.0a0+b1123cf dev_0
[conda] torchx 0.6.0 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @rohan-varma @penguinwu

@LucasLLC LucasLLC self-assigned this Apr 5, 2024
@LucasLLC LucasLLC added the release notes: distributed (checkpoint), oncall: distributed, and triaged labels Apr 5, 2024
@kwen2501
Contributor

kwen2501 commented Apr 5, 2024

dist.barrier() fails after calls to torch.distributed.checkpoint.async_save

How does it fail? Do you have an error log?

@LucasLLC
Contributor Author

LucasLLC commented Apr 5, 2024

@kwen2501 I do not, but it hangs

@wconstab
Contributor

wconstab commented Apr 9, 2024

@LucasLLC maybe you could get a little more info and share it. Like a stacktrace of where the program is hanging?
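One way to capture such a trace without extra tooling is Python's standard `faulthandler` module, which can dump the stack of every thread once a timeout expires. The following is only a minimal sketch: the helper names `enable_hang_dump` and `disable_hang_dump` and the timeout value are illustrative and do not come from this thread or from PyTorch.

```python
import faulthandler
import sys


def enable_hang_dump(timeout_s, file=None):
    """Arm a watchdog that dumps all thread stacks after timeout_s seconds.

    Hypothetical helper for debugging hangs. The process keeps running after
    the dump (exit=False), so the hang itself is left untouched.
    """
    target = sys.stderr if file is None else file
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=target, exit=False)


def disable_hang_dump():
    """Cancel the watchdog once the suspicious call has returned."""
    faulthandler.cancel_dump_traceback_later()
```

In the repro, this could be called as `enable_hang_dump(60)` right before `dist.barrier()` and cancelled with `disable_hang_dump()` right after it, so a dump only appears when the barrier does not return within a minute.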

@LucasLLC
Contributor Author

LucasLLC commented Apr 16, 2024

Creating a new pg seemed to alleviate the issue for me. This is probably still unsafe, since process groups are generally not thought of as being thread safe, but it might be a good workaround for any users facing this issue for now. The only changes below are creating a new group with group = dist.new_group(...) and then passing this group to the async_save call.

import os
import torch
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F  # used by TestDummyModel.forward
import shutil

CHECKPOINT_DIR = "test_checkpoint"

class TestDummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(8, 16)
        self.net2 = nn.Linear(16, 32)
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Linear(64, 8)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    device = "cuda"
    #device = "cpu"
    
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    group = dist.new_group(ranks=list(range(dist.get_world_size())), backend="gloo")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)

    model = TestDummyModel().to(device)
    state_dict = model.state_dict()

    f = dcp.async_save(state_dict, checkpoint_id=CHECKPOINT_DIR, process_group=group)

    # For some reason, running the following all_reduce first makes the barrier below not hang:
    # dist.all_reduce(torch.zeros((1,), device=device))
    dist.barrier()
    if isinstance(f.result(), Exception):
        print(f.result())
    else:
        print("finished saving")


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Another interesting note: when I circled back to replicate this, the hang did not show up unless I changed the relevant lines to:

    # Requires `import time` at the top of the script.
    for idx in range(10):
        print(f"{dist.get_rank()=} {idx=} Calling async save")
        time.sleep(dist.get_rank())
        f = dcp.async_save(
            state_dict, checkpoint_id=CHECKPOINT_DIR, process_group=group
        )
        time.sleep(dist.get_rank())
        print(f"{dist.get_rank()=} {idx=} Calling barrier")
        dist.barrier()
        dist.barrier()
        dist.barrier()

@LucasLLC
Contributor Author

@wconstab , @kwen2501

From this dump, it looks like this is a true collective hang (both threads are waiting at the same time):

Thread 3642049 (idle): "MainThread"
    barrier (torch/distributed/distributed_c10d.py:3694)
    wrapper (torch/distributed/c10d_logger.py:78)
    run (playground.py:60)
    _wrap (torch/multiprocessing/spawn.py:75)
    run (multiprocessing/process.py:108)
    _bootstrap (multiprocessing/process.py:314)
    _main (multiprocessing/spawn.py:135)
    spawn_main (multiprocessing/spawn.py:122)
    <module> (<string>:1)
Thread 3655815 (idle): "ThreadPoolExecutor-0_0"
    all_gather (torch/distributed/distributed_c10d.py:2865)
    wrapper (torch/distributed/c10d_logger.py:78)
    gather_object (torch/distributed/distributed_c10d.py:2542)
    wrapper (torch/distributed/c10d_logger.py:78)
    gather_object (torch/distributed/checkpoint/utils.py:106)
    reduce_scatter (torch/distributed/checkpoint/utils.py:167)
    _save_state_dict (torch/distributed/checkpoint/state_dict_saver.py:286)
    save (torch/distributed/checkpoint/state_dict_saver.py:148)
    inner_func (torch/distributed/checkpoint/utils.py:427)
    wrapper (torch/distributed/checkpoint/logger.py:64)
    run (concurrent/futures/thread.py:58)
    _worker (concurrent/futures/thread.py:83)
    run (threading.py:975)
    _bootstrap_inner (threading.py:1038)
    _bootstrap (threading.py:995)

@wconstab
Contributor

ok. if the problem was a hang that was caused by a race between the allgather and the barrier, then the solution is to put the allgather and the barrier on separate logical streams, rather than having them be sequential.

Whether it is thread-safe to use a PG created in one thread on another thread is something I think we could get more input on from @kwen2501 and @minsii. We are creating a new pg via new_group in the main thread, and it is dedicated for use in the side thread, so we never use that PG from 2 threads. However, under the hood, if that PG is sharing some nccl communicator resources with the other PGs, we may still be violating the thread safety of nccl.
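The ordering race described in this comment can be illustrated with a toy model in plain Python. It is a sketch under one assumption: every rank must issue collectives on a given communicator in the same order. It does not model NCCL's thread safety or CUDA streams, and the function name `find_order_mismatch` and the communicator labels are made up for illustration.

```python
from collections import defaultdict


def find_order_mismatch(per_rank_calls):
    """Return the first (communicator, position, ops) where ranks disagree, else None.

    per_rank_calls[rank] lists (communicator, op) pairs in the order that rank
    happened to issue them, with calls from both threads interleaved.
    """
    # Split each rank's interleaved calls into one ordered queue per communicator.
    queues = []
    for calls in per_rank_calls:
        by_comm = defaultdict(list)
        for comm, op in calls:
            by_comm[comm].append(op)
        queues.append(by_comm)

    for comm in sorted({c for q in queues for c in q}):
        sequences = [q.get(comm, []) for q in queues]
        for pos in range(max(len(s) for s in sequences)):
            ops = tuple(s[pos] if pos < len(s) else None for s in sequences)
            if len(set(ops)) > 1:
                return comm, pos, ops
    return None


# Shared communicator: the main thread's barrier and the checkpoint thread's
# all_gather may reach it in a different order on each rank.
shared = [
    [("default", "barrier"), ("default", "all_gather")],  # rank 0
    [("default", "all_gather"), ("default", "barrier")],  # rank 1
]
# Dedicated checkpoint communicator: the interleaving no longer matters.
separate = [
    [("default", "barrier"), ("ckpt", "all_gather")],  # rank 0
    [("ckpt", "all_gather"), ("default", "barrier")],  # rank 1
]

print(find_order_mismatch(shared))    # ('default', 0, ('barrier', 'all_gather'))
print(find_order_mismatch(separate))  # None
```

With a shared communicator, rank 0 enters the barrier while rank 1 enters the all_gather, so neither collective can complete. With a dedicated communicator for the checkpoint thread, each communicator sees the same sequence on every rank regardless of how the threads interleave.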

@ajWithNucleus

ajWithNucleus commented May 6, 2025

This problem continues to exist even with the PyTorch 2.7 release. Is there any solution or fix for this other than the proposed workaround?

@kwen2501 @minsii @wconstab

@minsii
Contributor

minsii commented May 7, 2025

Looking first at the hanging stack, there is a concurrent allgather (inside dcp.async_save on Thread 3655815) and a barrier (on the main thread). There seem to be two layers to the problem:

  1. At the NCCL layer, I don't think NCCL provides any thread safety when two threads call into the same communicator, due to internal shared resources (see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
  2. At the NCCL PG layer, I think an internal stream (from the PyTorch stream pool) is assigned to the given PG. Thus, when the two CPU threads use the same PG, they would schedule collective kernels to the same stream, which would also cause a deadlock. @kwen2501 may know more precisely how this works in the latest PG code.

To address the issue, I think we should use two separate PGs: one passed into dcp, and the other used by the main thread.

@fegin
Contributor

fegin commented May 8, 2025

Thanks @minsii. @ajWithNucleus, dcp.async_save() does provide an argument (process_group) for passing the process group. Could you try it?
