Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Move mps_linear forward to use MPS kernels directly instead of MPSGraph #152210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from

Conversation

jhavukainen
Copy link
Collaborator

@jhavukainen jhavukainen commented Apr 25, 2025

This PR moves mps_linear to use MPSNDArrays and call into the MPS kernel directly instead of going through MPSGraph. It also adds a caching mechanism for reusing MPS kernels as there is also a small overhead attached to creating the kernel object.

The impact of the improvement is relatively more significant for small input kernels where the MPSGraph overhead represents a larger portion of the overall execution time of the operation but the speedup shows for both small and large input sizes as expected.

mps_linear before the changes:

input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x109d67110>
func(*args, **kwargs)
  Median: 199.29 us
  IQR:    9.56 us (196.71 to 206.27)
  979 measurements, 1 runs per measurement, 1 thread

input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x1063b4510>
func(*args, **kwargs)
  Median: 979.29 us
  IQR:    25.29 us (964.83 to 990.13)
  205 measurements, 1 runs per measurement, 1 thread

mps_linear after the changes:

input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10693a190>
func(*args, **kwargs)
  Median: 176.08 us
  IQR:    15.02 us (172.42 to 187.44)
  1103 measurements, 1 runs per measurement, 1 thread

input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10d524dd0>
func(*args, **kwargs)
  Median: 952.56 us
  IQR:    15.63 us (945.47 to 961.10)
  210 measurements, 1 runs per measurement, 1 thread

cc @kulinseth @albanD @malfet @DenisVieriu97

@jhavukainen jhavukainen added the module: mps Related to Apple Metal Performance Shaders framework label Apr 25, 2025
Copy link

pytorch-bot bot commented Apr 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152210

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 20 Pending, 1 Unrelated Failure

As of commit 8727017 with merge base 1d3e8f3 (image):

NEW FAILURE - The following job has failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Apr 25, 2025
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 28, 2025
@jhavukainen jhavukainen force-pushed the dev/joona/mps_linear branch from 49b4351 to 76a1b48 Compare May 2, 2025 19:46
@jhavukainen jhavukainen requested a review from malfet May 6, 2025 21:17
@jhavukainen jhavukainen force-pushed the dev/joona/mps_linear branch from acaeb1d to 06a28bb Compare May 8, 2025 00:23
Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

@jhavukainen
Copy link
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/152210/head returned non-zero exit code 1

Rebasing (1/5)
hint: Recursive merging with submodules currently only supports trivial cases.
hint: Please manually handle the merging of each conflicted submodule.
hint: This can be accomplished with the following steps:
hint:  - come back to superproject and run:
hint:
hint:       git add third_party/cutlass
hint:
hint:    to record the above merge or update
hint:  - resolve any other conflicts in the superproject
hint:  - commit the resulting index in the superproject
hint:
hint: Disable this message with "git config set advice.submoduleMergeConflict false"
Failed to merge submodule third_party/cutlass (not checked out)
CONFLICT (submodule): Merge conflict in third_party/cutlass
error: could not apply 657bf1e0646... Adding a direct MPS kernel path to linear op and MPS kernel caching mechanism for improved perf.
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 657bf1e0646... Adding a direct MPS kernel path to linear op and MPS kernel caching mechanism for improved perf.

Raised by https://github.com/pytorch/pytorch/actions/runs/14910649119

@jhavukainen jhavukainen force-pushed the dev/joona/mps_linear branch from 06a28bb to f855f95 Compare May 8, 2025 15:58
Comment on lines 116 to 119
bool is_macos_15_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if (is_macos_15_or_newer) {
_mps_linear_nograph(input, weight, bias, output);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it would be nice to just replace it with if() { return;} (Doing it now)

Copy link
Contributor

@malfet malfet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, though feels a bit like too much boilerplate code, but will look into it later

@malfet
Copy link
Contributor

malfet commented May 9, 2025

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/mps Run MPS tests (subset of trunk) Merged module: mps Related to Apple Metal Performance Shaders framework open source release notes: mps Release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants