Add support for MPS devices in CTC prefix scoring #6266
sw005320 merged 1 commit into espnet:master from
Extend device handling to include MPS alongside CUDA and CPU, ensuring proper device assignment and tensor operations.
Code Review
This pull request correctly adds support for MPS devices in CTC prefix scoring, which resolves a device mismatch issue. The changes are effective and address the problem described. However, there's an opportunity to simplify the device detection logic. By using x.device directly, the code can be made more concise, maintainable, and robust against future changes in PyTorch device types, aligning it better with modern practices seen elsewhere in the codebase.
    # (KanTakahiro): Support for CUDA, MPS, and CPU devices
    if x.is_cuda:
        self.device = torch.device("cuda:%d" % x.get_device())
    elif x.device.type == "mps":
        self.device = x.device
    else:
        self.device = torch.device("cpu")
This block for device detection can be significantly simplified. Using x.device directly is more concise, readable, and robust. It will correctly handle CUDA, CPU, MPS, and any other devices PyTorch might support in the future without requiring additional elif branches.
    - # (KanTakahiro): Support for CUDA, MPS, and CPU devices
    - if x.is_cuda:
    -     self.device = torch.device("cuda:%d" % x.get_device())
    - elif x.device.type == "mps":
    -     self.device = x.device
    - else:
    -     self.device = torch.device("cpu")
    + self.device = x.device
In terms of programming, this is indeed a more concise and elegant approach. However, I'm concerned that some newly added devices in PyTorch may not be supported by espnet, potentially leading to unexpected errors. Therefore, I opted to manually write the logic for selecting the device to ensure that the final chosen device is always supported by espnet.
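One possible way to reconcile the two positions is to take the device from the tensor but check it against an explicit allow-list. The following is only a sketch under stated assumptions: `SUPPORTED_DEVICE_TYPES` and `resolve_device` are hypothetical names that do not exist in ESPnet, and the function works on a device type string and index (roughly what `x.device.type` and `x.device.index` provide in PyTorch) so that it runs without PyTorch installed.

```python
from typing import Optional

# Hypothetical allow-list; the PR itself handles CUDA, MPS, and CPU.
SUPPORTED_DEVICE_TYPES = ("cuda", "mps", "cpu")


def resolve_device(device_type: str, index: Optional[int] = None) -> str:
    """Return a device string, rejecting device types outside the allow-list."""
    if device_type not in SUPPORTED_DEVICE_TYPES:
        # Fail loudly instead of silently falling back to another device.
        raise ValueError(f"Unsupported device type: {device_type!r}")
    if device_type == "cpu" or index is None:
        return device_type
    return f"{device_type}:{index}"


print(resolve_device("mps", 0))
print(resolve_device("cpu"))
```

Unlike the `else` branch in the PR, which maps any other device type to CPU, this variant raises an error for an unknown type. Whether failing early is preferable to a CPU fallback is a design choice this sketch does not settle.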
sw005320 left a comment
Thanks for the PR.
This looks good to me.
After the CI passes, I'll merge this PR.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

    @@            Coverage Diff             @@
    ##           master    #6266      +/-   ##
    ==========================================
    - Coverage   56.77%   56.77%   -0.01%
    ==========================================
      Files         889      889
      Lines       84361    84365       +4
    ==========================================
    + Hits        47898    47900       +2
    - Misses      36463    36465       +2
Thanks a lot!
What did you change?
Extend device handling to include MPS alongside CUDA and CPU, ensuring proper device assignment and tensor operations.
Why did you make this change?
The original code did not properly detect MPS (Metal Performance Shaders) devices on Apple Silicon Macs. This caused a device mismatch error during inference when the input tensor was on an MPS device. Specifically, when x.is_cuda returned False for MPS devices, self.device was incorrectly set to cpu. This led to some tensors being created on the CPU while others were on the MPS device, resulting in the error: "Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!".
The modifications ensure that MPS devices are correctly detected and that all tensors are created on the appropriate device, thus avoiding device mismatch errors and enabling proper functionality on Apple Silicon Macs.
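The failure mode described above can be illustrated without PyTorch installed. The sketch below is not the ESPnet code: `FakeTensor` is a hypothetical stand-in that mimics only the attributes the device check inspects (`is_cuda`, `get_device()`, `device`), `old_select` and `new_select` are illustrative names, and devices are represented as plain strings rather than `torch.device` objects.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only tracks where it lives."""

    device_type: str  # "cuda", "mps", or "cpu"
    index: int = 0

    @property
    def is_cuda(self) -> bool:
        return self.device_type == "cuda"

    @property
    def device(self) -> str:
        # Real tensors expose a torch.device here; a string keeps the sketch small.
        if self.device_type == "cpu":
            return "cpu"
        return f"{self.device_type}:{self.index}"

    def get_device(self) -> int:
        return self.index


def old_select(x: FakeTensor) -> str:
    """Behavior described for the original code: anything not CUDA becomes CPU."""
    if x.is_cuda:
        return "cuda:%d" % x.get_device()
    return "cpu"


def new_select(x: FakeTensor) -> str:
    """Behavior after the PR: an MPS input keeps its own device."""
    if x.is_cuda:
        return "cuda:%d" % x.get_device()
    elif x.device_type == "mps":  # the real code checks x.device.type
        return x.device
    return "cpu"


x = FakeTensor("mps")
print("old:", old_select(x), "| input on:", x.device)  # helper device and input disagree
print("new:", new_select(x), "| input on:", x.device)  # both on the same device
```

With the old selection, any tensor later created on `self.device` would live on the CPU while the input stays on MPS, which is the situation that triggers the "Expected all tensors to be on the same device" error quoted above.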
Is your PR small enough?
Yes
Additional Context
Related issue: #6264