MPS Support for loading float64 models like OWSM as float32#6246
MPS Support for loading float64 models like OWSM as float32#6246sw005320 merged 10 commits intoespnet:masterfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a fix to allow loading models with float64 precision on mps devices, which do not support this data type. The approach of loading the model to the CPU first and then transferring it to the mps device with float32 precision is sound and correctly implemented in espnet2/tasks/abs_task.py.
I've added a review comment on the new test file, test/espnet2/bin/test_macos_mps_model.py, to suggest a more robust way of asserting the model's properties after loading. The current assertions are not sufficient to confirm that the fix works as intended.
|
@dosubot, I added a test and it passes locally on my mac, however I've configured it to only run on mac (when mps is available because it is specific to mps) and it doesn't appear there are github actions configured to run the tests on macos. Is this right? If so, is it something I should change or is it in order as is? |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #6246 +/- ##
=======================================
Coverage 56.77% 56.77%
=======================================
Files 889 889
Lines 84337 84340 +3
=======================================
+ Hits 47880 47882 +2
- Misses 36457 36458 +1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This is because the changes only apply to macos with apple silicon. As mentioned above, the coverage is 100% when the tests are run on a macos device. |
|
The imports are unused but needed to load the necessary fixtures. Can I tell flake8 to allow this so the tests can pass or would it be preferred if I instead duplicated the fixture code to be directly in the file? I could also move the test into the file with the fixtures but this test seems somewhat separate and the guidelines said to prefer to split tests out into separate files |
|
Hmm. The other changes look good to me. |
|
|
Hello @Fhrozen, thanks for the mention! And thank you, @SanderGi, for raising this important point about the macOS tests. I noticed in the repository's It would be beneficial to check this workflow to ensure it's correctly configured to trigger and execute the tests you've added, especially those that are specific to MPS and macOS. Sometimes, new test files or specific conditions might require updates to the CI configuration. If the existing macOS workflow isn't picking up your new tests, or if there are specific environment variables or dependencies needed for MPS that aren't configured, that would be the next step to investigate. Let me know if you'd like me to take a closer look at the |
|
@gemini-code-assist Could you take a closer look at the |
|
Hello @Fhrozen, thanks for the follow-up! I can certainly take a closer look at what might be going on with the For
If you can share the relevant sections of your |
The cause is line 22 of # flake8
"$(dirname $0)"/test_flake8.sh espnetI have instructed flake8 to ignore the falsely flagged unused imports for now. Let me know if we want to tweak the test_python_espnet1 github action to only run flake8 on espnet1 |
The |
|
Thank you for the clarification, @SanderGi! That explains why the tests aren't being run on the macOS CI. Given that the If you decide to proceed with modifying the |
|
Thanks, @SanderGi! |
What did you change?
Changed the code that builds a model when used with the
mpsdevice to first load onto thecpu(since it supports most dtypes like float64), and then load ontompswith a dtype supported bymps(float32).Why did you make this change?
Previously loading a model like OWSM with the
mpsdevice causedTypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.and because the dtype used to be set separately after the device, there was no way to set the right dtype before the model was loaded onto mps and failed. By loading the model onto a valid dtype initially, the dtype can correctly be changed later.Is your PR small enough?
Yes, only 2 files (< 20) and only 47 changed lines (< 1000). One of the files is the test file.
Additional Context
This solves issue #6244