
[device_mesh] improve device selection logic #150897


Open
wants to merge 5 commits into base: gh/wanchaol/370/base

Conversation

@wanchaol (Collaborator) commented Apr 9, 2025

Stack from ghstack (oldest at bottom):

As titled, this PR improves the device selection logic for the case where the user did not
set the device before calling the DeviceMesh constructor. As a device manager, DeviceMesh
should try to set the device for users in a sensible way.

The set_device behavior before this PR:

  • If the user calls init_process_group to initialize a world process group, we assume the user has already called set_device, so we don't set the device for them.
  • If the user does not initialize a world process group themselves, we initialize one for them and follow a heuristic to set the device.
    This is OK, but sometimes the set_device heuristic doesn't work well (e.g. if the user uses TORCH_CUDA_VISIBLE_DEVICES); see the sketch after this list.
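
For reference, the pre-PR fallback heuristic amounted to roughly the following (a hedged sketch, not the exact DeviceMesh source; it assumes CUDA and that the world process group is already initialized):

import torch
import torch.distributed as dist

# Old heuristic: derive the device index purely from the global rank,
# ignoring launcher-provided LOCAL_RANK and any device-visibility remapping.
num_devices_per_host = torch.cuda.device_count()
torch.cuda.set_device(dist.get_rank() % num_devices_per_host)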

So this PR improves the device selection logic to:

  • If the user calls init_process_group to initialize a world process group and the default CUDA context is already initialized, we assume the user must have run some CUDA operation beforehand and therefore must have selected the device themselves.
  • Otherwise, we check whether the launcher (e.g. torchrun) set the "LOCAL_RANK" and "WORLD_SIZE" environment variables; if so, we use "LOCAL_RANK" to set the device for the current process, which is standard practice. (This solves the TORCH_CUDA_VISIBLE_DEVICES issue.)
  • Otherwise, we fall back to the old heuristic. A rough sketch of this flow follows below.
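
A hedged sketch of the new selection flow described above (the helper name _maybe_set_device and the torch.cuda default are illustrative, not the actual DeviceMesh internals):

import os
import torch
import torch.distributed as dist

def _maybe_set_device(device_handle=torch.cuda):
    # 1. The user initialized the world process group and the device backend
    #    (e.g. the default CUDA context) is already initialized: assume the
    #    user has already selected a device, so do nothing.
    if dist.is_initialized() and device_handle.is_initialized():
        return
    # 2. Launcher-provided env vars (e.g. from torchrun): use LOCAL_RANK,
    #    which also behaves correctly under CUDA_VISIBLE_DEVICES remapping.
    if "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ:
        device_handle.set_device(int(os.environ["LOCAL_RANK"]))
        return
    # 3. Fallback: the old global-rank-modulo heuristic shown earlier.
    device_handle.set_device(dist.get_rank() % device_handle.device_count())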

cc @H-Huang @awgu @fegin @fduwjj @wz337 @wconstab @d4l3k @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

[ghstack-poisoned]

pytorch-bot bot commented Apr 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150897

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 0f1daaa with merge base 843e4d1:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Apr 9, 2025
@wanchaol wanchaol added the release notes: distributed (dtensor) release notes category label Apr 9, 2025
wanchaol added 2 commits April 9, 2025 18:13
[ghstack-poisoned]
[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Apr 21, 2025
wanchaol added a commit that referenced this pull request Apr 21, 2025

ghstack-source-id: 8d27c0d
Pull Request resolved: #150897
[ghstack-poisoned]
wanchaol added a commit that referenced this pull request Apr 21, 2025

ghstack-source-id: 7967a39
Pull Request resolved: #150897
[ghstack-poisoned]
wanchaol added a commit that referenced this pull request Apr 21, 2025

ghstack-source-id: 55e85d1
Pull Request resolved: #150897
@wanchaol wanchaol added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 21, 2025
@tianyu-l (Contributor) left a comment


Sorry, I don't have enough context on DeviceMesh, so I'm asking some questions before I can review. Meanwhile, @fegin could unblock if he's able to.

f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
)
device_handle.set_device(get_rank() % num_devices_per_host)
if device_handle and not device_handle.is_initialized():
Contributor


I'm not familiar with the code here. Could you elaborate a bit on what is being done here?
It looks like:

  • before this PR, the code would be executed if device_handle exists and c10d is_initialized() is false.
  • after this PR, the code would be executed if device_handle exists and device_handle's is_initialized() is false.

How is this related to

when user did not set the device before calling the DeviceMesh constructor

Contributor


This is_initialized() is either torch.cuda.is_initialized() or torch.cpu.is_initialized(). For CUDA, this means the device state is initialized, iiuc.

Collaborator Author


Yes, and I also updated the PR summary to reflect the detailed changes this PR makes.

Contributor


According to your PR description, we don't do this set_device when:

If user call init_process_group to init a world process group and the default cuda context is initialized

To translate it into this condition, wouldn't it be?

Suggested change
if device_handle and not device_handle.is_initialized():
if not default_initialized or (device_handle and not device_handle.is_initialized()):

# NOTE: This device selection would only work for homogeneous hardware.
num_devices_per_host = device_handle.device_count()
if (
world_size > num_devices_per_host
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the relationship between world_size and device_handle vs. the global is_initialized call?

Collaborator Author


device_handle is just for device-agnostic code (i.e. the user might have cuda/xpu/mtia). world_size and is_initialized are used to check whether the default world process group is initialized or not.
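
Concretely, the two checks are independent; a hedged sketch (DeviceMesh uses its own internal helper for the handle lookup):

import torch
import torch.distributed as dist

device_type = "cuda"  # could also be "xpu", "mtia", ...
# device_handle: the per-backend module used to write device-agnostic code.
device_handle = getattr(torch, device_type, None)

# World process group check: has the default process group been created?
world_pg_ready = dist.is_initialized()

# Device backend check: has this backend's state (e.g. the CUDA context
# for the current process) been initialized?
backend_ready = device_handle is not None and device_handle.is_initialized()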

and serves as a proxy for communication among the device lists within the cluster.
DeviceMesh could be used to setup the N dimensional device connections across the cluster,
and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on
each dimension of the DeviceMesh separately. DeviceMesh respect the device that user select
@fegin (Contributor) commented Apr 30, 2025


Suggested change
each dimension of the DeviceMesh separately. DeviceMesh respect the device that user select
each dimension of the DeviceMesh separately. DeviceMesh respects the device that user select


DeviceMesh can be used as a context manager.
DeviceMesh can also be used as a context manager when using together with most DTensor APIs.
Contributor


Curious, which APIs does this not work with?

Collaborator Author


IIRC every public DTensor API should work; I just want to be conservative here. Non-DTensor APIs would probably not work, because the context manager just stashes the created device mesh into the mesh env, nothing magic.
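
For example, a hedged usage sketch, assuming a recent PyTorch where torch.distributed.tensor is public (older releases use torch.distributed._tensor) and a torchrun launch so WORLD_SIZE is set:

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# init_device_mesh initializes the world process group if needed.
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Inside the context manager, DTensor APIs that take an optional device_mesh
# argument pick up `mesh` from the stashed mesh env instead.
with mesh:
    dt = distribute_tensor(torch.randn(8, 8), placements=[Shard(0)])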

Labels
  • ciflow/trunk - Trigger trunk jobs on your pull request
  • module: cpu - CPU specific problem (e.g., perf, algorithm)
  • oncall: distributed - Add this issue/PR to distributed oncall triage queue
  • open source
  • release notes: distributed (dtensor) - release notes category
Projects
None yet

4 participants