Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add support for MPS devices in CTC prefix scoring#6266

Merged
sw005320 merged 1 commit intoespnet:masterfrom
KanTakahiro:fix/CTCPrefixScoreTH-MPS-support
Oct 20, 2025
Merged

Add support for MPS devices in CTC prefix scoring#6266
sw005320 merged 1 commit intoespnet:masterfrom
KanTakahiro:fix/CTCPrefixScoreTH-MPS-support

Conversation

@KanTakahiro
Copy link
Contributor

Extend device handling to include MPS alongside CUDA and CPU, ensuring proper device assignment and tensor operations.

What did you change?

  1. Added Support for MPS Devices: The original code only supported CUDA and CPU devices. The modified code adds support for MPS (Metal Performance Shaders) devices, which allows for GPU acceleration on Mac devices. This makes the code more versatile and efficient across different hardware platforms.
  2. Ensured Tensor Creation on Correct Device: The modified code ensures that tensors are created on the correct device by specifying the device in the torch.as_tensor function. This avoids potential issues with device mismatch, ensuring that all tensor operations are performed on the same device.

Why did you make this change?

The original code did not properly detect MPS (Metal Performance Shaders) devices on Apple Silicon Macs. This caused a device mismatch error during inference when the input tensor was on an MPS device. Specifically, when x.is_cuda returned False for MPS devices, self.device was incorrectly set to cpu. This led to some tensors being created on the CPU while others were on the MPS device, resulting in the error: "Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!".
The modifications ensure that MPS devices are correctly detected and that all tensors are created on the appropriate device, thus avoiding device mismatch errors and enabling proper functionality on Apple Silicon Macs.


Is your PR small enough?

Yes


Additional Context

Relative issue: #6264

Extend device handling to include MPS alongside CUDA and CPU, ensuring
proper device assignment and tensor operations.
@dosubot dosubot bot added size:S This PR changes 10-29 lines, ignoring generated files. ASR Automatic speech recogntion labels Oct 20, 2025
@mergify mergify bot added the ESPnet1 label Oct 20, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly adds support for MPS devices in CTC prefix scoring, which resolves a device mismatch issue. The changes are effective and address the problem described. However, there's an opportunity to simplify the device detection logic. By using x.device directly, the code can be made more concise, maintainable, and robust against future changes in PyTorch device types, aligning it better with modern practices seen elsewhere in the codebase.

Comment on lines +41 to +47
# (KanTakahiro): Support for CUDA, MPS, and CPU devices
if x.is_cuda:
self.device = torch.device("cuda:%d" % x.get_device())
elif x.device.type == "mps":
self.device = x.device
else:
self.device = torch.device("cpu")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block for device detection can be significantly simplified. Using x.device directly is more concise, readable, and robust. It will correctly handle CUDA, CPU, MPS, and any other devices PyTorch might support in the future without requiring additional elif branches.

Suggested change
# (KanTakahiro): Support for CUDA, MPS, and CPU devices
if x.is_cuda:
self.device = torch.device("cuda:%d" % x.get_device())
elif x.device.type == "mps":
self.device = x.device
else:
self.device = torch.device("cpu")
self.device = x.device

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of programming, this is indeed a more concise and elegant approach. However, I'm concerned that some newly added devices in PyTorch may not be supported by espnet, potentially leading to unexpected errors. Therefore, I opted to manually write the logic for selecting the device to ensure that the final chosen device is always supported by espnet.

Copy link
Contributor

@sw005320 sw005320 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR.
This looks good to me.
After the CI passes, I'll merge this PR.

@dosubot dosubot bot added the lgtm This PR has been approved by a maintainer label Oct 20, 2025
@codecov
Copy link

codecov bot commented Oct 20, 2025

Codecov Report

❌ Patch coverage is 66.66667% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 56.77%. Comparing base (cc1f815) to head (0f449c0).
⚠️ Report is 3 commits behind head on master.

Files with missing lines Patch % Lines
espnet/nets/ctc_prefix_score.py 66.66% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6266      +/-   ##
==========================================
- Coverage   56.77%   56.77%   -0.01%     
==========================================
  Files         889      889              
  Lines       84361    84365       +4     
==========================================
+ Hits        47898    47900       +2     
- Misses      36463    36465       +2     
Flag Coverage Δ
test_integration_espnet2 46.80% <66.66%> (+<0.01%) ⬆️
test_integration_espnetez 36.92% <0.00%> (-0.01%) ⬇️
test_python_espnet2 51.20% <66.66%> (-0.01%) ⬇️
test_python_espnetez 12.81% <0.00%> (-0.01%) ⬇️
test_utils 18.77% <ø> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@sw005320 sw005320 merged commit 10c354e into espnet:master Oct 20, 2025
61 of 77 checks passed
@sw005320
Copy link
Contributor

Thanks a lot!

@Fhrozen Fhrozen modified the milestones: v.202512, v.202511 Nov 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ASR Automatic speech recogntion ESPnet1 lgtm This PR has been approved by a maintainer size:S This PR changes 10-29 lines, ignoring generated files.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants