Failed to visualize 1D DTensor #152848


Closed
wangkuiyi opened this issue May 5, 2025 · 1 comment
Labels: module: dtensor, oncall: distributed

Comments

wangkuiyi (Contributor) commented May 5, 2025

πŸ› Describe the bug

To reproduce this issue, run the following script with torchrun (an example invocation is sketched after the script):

import torch.distributed.tensor.debug
import torch.distributed.tensor as dt
import torch.distributed as dist
import os

rank = int(os.getenv("RANK", "0"))


def render(t, msg):
    if rank == 0:
        print(msg)
        dt.debug.visualize_sharding(t, use_rich=False)


m = dist.init_device_mesh("cuda", (4,))
t = dt.distribute_tensor(torch.ones(4), m, [dt.Shard(dim=0)])
dt.debug.visualize_sharding(t, use_rich=True)  # fails with the IndexError shown below
t = dt.distribute_tensor(torch.ones(4), m, [dt.Replicate()])
dt.debug.visualize_sharding(t, use_rich=True)
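
The script was launched with four processes to match the (4,) device mesh; a representative command (the file name s.py comes from the traceback path below, the exact flags used are an assumption):

torchrun --nproc_per_node=4 s.py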

The first call to visualize_sharding fails with:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/w/l/s.py", line 17, in <module>
[rank0]:     dt.debug.visualize_sharding(t, use_rich=True)
[rank0]:   File "/root/w/pytorch/torch/distributed/tensor/debug/_visualize_sharding.py", line 201, in visualize_sharding
[rank0]:     (offset[1], offset[1] + shape[1] - 1),
[rank0]:      ~~~~~~^^^
[rank0]: IndexError: tuple index out of range
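
The traceback points at a (row, column) span computed from each shard's global offset and shape. For a 1D DTensor those tuples have only one element, so offset[1] and shape[1] are out of range. A minimal sketch of that failure mode, using illustrative values rather than the actual visualize_sharding internals:

# Illustrative values (assumed, not read back from visualize_sharding):
# rank 0's shard of torch.ones(4) sharded on dim 0 over a (4,) mesh.
shape = (1,)   # local shard shape of the 1D tensor
offset = (0,)  # global offset of this shard

rows = (offset[0], offset[0] + shape[0] - 1)      # dim-0 span: fine
try:
    cols = (offset[1], offset[1] + shape[1] - 1)  # there is no dim 1 on a 1D tensor
except IndexError as e:
    print(f"IndexError: {e}")  # "tuple index out of range", matching the traceback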

Versions

# pip list | grep torch
torch              2.8.0a0+git730a077 /root/w/pytorch

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @tianyu-l @XilunWu

wangkuiyi (Contributor, Author) commented:

I am afraid this issue was introduced by my previous PR #152027. I will look into it. Please feel free to assign this issue to me.
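
For reference, one way the span arithmetic could be made 1D-safe, sketched purely as an assumption and not claimed to be the eventual fix: pad the per-shard offset and shape to two dimensions before building the spans.

# Hypothetical helper, not code from _visualize_sharding.py: treat a 1D
# shard as an N x 1 block so the 2D span arithmetic stays in range.
def spans(offset, shape):
    if len(shape) == 1:
        offset, shape = (offset[0], 0), (shape[0], 1)
    rows = (offset[0], offset[0] + shape[0] - 1)
    cols = (offset[1], offset[1] + shape[1] - 1)
    return rows, cols

print(spans((2,), (1,)))  # ((2, 2), (0, 0)): rank 2's single element of the sharded tensor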

wanchaol added the module: dtensor label on May 5, 2025
zou3519 added the oncall: distributed label on May 6, 2025