Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Move mps_linear forward to use MPS kernels directly instead of MPSGraph #152210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

jhavukainen
Copy link
Collaborator

@jhavukainen jhavukainen commented Apr 25, 2025

This PR moves mps_linear to use MPSNDArrays and call into the MPS kernel directly instead of going through MPSGraph. It also adds a caching mechanism for reusing MPS kernels as there is also a small overhead attached to creating the kernel object.

The impact of the improvement is relatively more significant for small input kernels where the MPSGraph overhead represents a larger portion of the overall execution time of the operation but the speedup shows for both small and large input sizes as expected.

mps_linear before the changes:

input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x109d67110>
func(*args, **kwargs)
  Median: 199.29 us
  IQR:    9.56 us (196.71 to 206.27)
  979 measurements, 1 runs per measurement, 1 thread

input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x1063b4510>
func(*args, **kwargs)
  Median: 979.29 us
  IQR:    25.29 us (964.83 to 990.13)
  205 measurements, 1 runs per measurement, 1 thread

mps_linear after the changes:

input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10693a190>
func(*args, **kwargs)
  Median: 176.08 us
  IQR:    15.02 us (172.42 to 187.44)
  1103 measurements, 1 runs per measurement, 1 thread

input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10d524dd0>
func(*args, **kwargs)
  Median: 952.56 us
  IQR:    15.63 us (945.47 to 961.10)
  210 measurements, 1 runs per measurement, 1 thread

cc @kulinseth @albanD @malfet @DenisVieriu97

@jhavukainen jhavukainen added the module: mps Related to Apple Metal Performance Shaders framework label Apr 25, 2025
Copy link

pytorch-bot bot commented Apr 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152210

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Apr 25, 2025
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 28, 2025
@jhavukainen jhavukainen force-pushed the dev/joona/mps_linear branch from 49b4351 to 76a1b48 Compare May 2, 2025 19:46
@jhavukainen jhavukainen requested a review from malfet May 6, 2025 21:17
@jhavukainen jhavukainen force-pushed the dev/joona/mps_linear branch from acaeb1d to 06a28bb Compare May 8, 2025 00:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/mps Run MPS tests (subset of trunk) module: mps Related to Apple Metal Performance Shaders framework open source release notes: mps Release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants