[device_mesh] improve device selection logic #150897
base: gh/wanchaol/370/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150897
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 0f1daaa with merge base 843e4d1. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Sorry, I don't have enough context on DeviceMesh, so I'm asking some questions before I can review. Meanwhile, perhaps @fegin could unblock.
f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" | ||
) | ||
device_handle.set_device(get_rank() % num_devices_per_host) | ||
if device_handle and not device_handle.is_initialized(): |
I'm not familiar with the code here. Could you elaborate a bit on what is being done here?
It looks like:
- before this PR, the code would be executed if `device_handle` exists and c10d's `is_initialized()` is false.
- after this PR, the code would be executed if `device_handle` exists and `device_handle`'s `is_initialized()` is false.

How is this related to "when user did not set the device before calling the DeviceMesh constructor"?
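To make the contrast concrete, here is a minimal sketch of the two conditions as described in the bullets above (an illustration, not the PR's code; `device_handle` stands in for the backend module, assuming CUDA):

```python
import torch
import torch.distributed as dist

device_handle = torch.cuda  # stand-in for the device-agnostic handle

# Before this PR: gated on the c10d (process group) state.
runs_before = device_handle is not None and not dist.is_initialized()

# After this PR: gated on the device runtime state instead.
runs_after = device_handle is not None and not device_handle.is_initialized()
```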
This `is_initialized()` is either `torch.cuda.is_initialized()` or `torch.cpu.is_initialized()`. For CUDA, this means the device state is initialized, iiuc.
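As illustration, a minimal sketch of how a device-agnostic handle lookup along these lines might work (`_get_device_handle` is used here as a hypothetical helper name):

```python
import torch

def _get_device_handle(device_type: str):
    # Hypothetical helper: resolve the module implementing the device API,
    # e.g. torch.cuda for "cuda" or torch.xpu for "xpu", so the same code
    # path works across backends.
    return getattr(torch, device_type, None)

handle = _get_device_handle("cuda")
if handle is not None:
    # For CUDA this is torch.cuda.is_initialized(): whether the CUDA
    # runtime state has been set up in this process.
    print(handle.is_initialized())
```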
Yes, I also updated the PR summary to reflect the detailed changes this PR makes.
According to your PR description, we don't do this `set_device` if "user call `init_process_group` to init a world process group and the default cuda context is initialized". To translate that into this condition, wouldn't it be:

```diff
-if device_handle and not device_handle.is_initialized():
+if not default_initialized or (device_handle and not device_handle.is_initialized()):
```
```python
# NOTE: This device selection would only work for homogeneous hardware.
num_devices_per_host = device_handle.device_count()
if (
    world_size > num_devices_per_host
```
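The excerpt above is truncated; a hedged reconstruction of the surrounding heuristic, based only on the fragments visible in this review (the error f-string quoted earlier and the `set_device` call), might look like:

```python
import torch
import torch.distributed as dist

device_handle = torch.cuda  # stand-in for the device-agnostic handle
device_type = "cuda"

world_size = dist.get_world_size()
num_devices_per_host = device_handle.device_count()

# Homogeneous-hardware check: every host must contribute the same number
# of devices, so the world size must tile evenly across hosts.
if world_size > num_devices_per_host and world_size % num_devices_per_host != 0:
    raise RuntimeError(
        f"DeviceMesh only supports homogeneous hardware, but found "
        f"{world_size} ranks and {num_devices_per_host} {device_type} devices!"
    )
# Pick this rank's local device: ranks are laid out host-by-host.
device_handle.set_device(dist.get_rank() % num_devices_per_host)
```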
Could you educate me on the relationship between `world_size` and `device_handle` vs. the global `is_initialized` call?
`device_handle` is just for device-agnostic code (i.e. the user might have cuda/xpu/mtia). `world_size` and `is_initialized` are used to check whether the default world process group is initialized or not.
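As a toy illustration of that second check (a single-process gloo group; the local TCP init method is just an assumption to make the snippet self-contained):

```python
import torch.distributed as dist

# The "global" is_initialized here refers to the default world process group.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",  # assumption: any free local port
    rank=0,
    world_size=1,
)
assert dist.is_initialized()       # the default world group now exists
assert dist.get_world_size() == 1  # world_size is read from that group
dist.destroy_process_group()
```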
```
and serves as a proxy for communication among the device lists within the cluster.
DeviceMesh could be used to setup the N dimensional device connections across the cluster,
and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on
each dimension of the DeviceMesh separately. DeviceMesh respect the device that user select
```
```diff
-each dimension of the DeviceMesh separately. DeviceMesh respect the device that user select
+each dimension of the DeviceMesh separately. DeviceMesh respects the device that user select
```
```diff
-DeviceMesh can be used as a context manager.
+DeviceMesh can also be used as a context manager when using together with most DTensor APIs.
```
Curious, which APIs does this not work with?
iirc every public DTensor API should work; I just want to be conservative here. All non-DTensor APIs would probably not work, because the context manager is just stashing the created device mesh in the mesh env, nothing magic.
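For context, a small usage sketch of DeviceMesh as a context manager with a DTensor API (assumes a CUDA setup and that the script is launched via torchrun so the default process group environment is in place):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Build a 1-D mesh over all local GPUs.
mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))

big_tensor = torch.randn(8, 16)
with mesh:  # the mesh is stashed in the mesh env for the duration
    # distribute_tensor picks up the ambient mesh from the context,
    # so no explicit device_mesh argument is needed here.
    dtensor = distribute_tensor(big_tensor, placements=[Shard(0)])
```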
Stack from ghstack (oldest at bottom):

As titled, this PR improves the device selection logic for the case when the user did not set the device before calling the DeviceMesh constructor. As a device manager, DeviceMesh should try to set the device for users in a sensible way.
The behavior of set_device before this PR:
- If the user called `init_process_group` to init a world process group, we assume the user already called set_device, and we don't set the device for them. This is OK, but sometimes the set_device heuristic wouldn't work well (i.e. if the user uses `TORCH_CUDA_VISIBLE_DEVICES`).
So this PR improves the device selection logic to:
- If the user called `init_process_group` to init a world process group and the default cuda context is initialized, then we assume the user must have called some cuda operation before, and therefore must have selected the device themselves.
- Otherwise, set the device for the user (rank % number of devices per host, as in the snippet discussed above), which also addresses the `TORCH_CUDA_VISIBLE_DEVICES` issue. A sketch of this flow follows below.
cc @H-Huang @awgu @fegin @fduwjj @wz337 @wconstab @d4l3k @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168