Conversation

@jagrit06 (Member) commented Jan 27, 2025

Proposed changes

  • Build in padding to Winograd kernels
  • Add new fused Winograd kernel
  • Enable weight flipping in Winograd kernels
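
The items above lean on the standard Winograd construction (shown here for F(2x2, 3x3); real kernels typically use larger tiles), so a minimal, framework-free reference sketch may help: transform the 3x3 filter with G, the 4x4 input tile with B^T, multiply elementwise, and transform back with A^T. The C++ below is illustration only, not the MLX Metal kernel; note that it computes cross-correlation, so a flipped-weights variant would reverse the filter along both spatial axes before applying G.

```cpp
// Minimal reference for Winograd F(2x2, 3x3): transform the 3x3 filter with G,
// the 4x4 input tile with B^T, multiply elementwise, transform back with A^T,
// and check against direct 3x3 cross-correlation. Illustration only -- this is
// not the MLX kernel code, just the textbook algorithm such kernels build on.
#include <cstddef>
#include <cstdio>
#include <cmath>

// y = a * b for small fixed-size matrices (dimensions deduced from the arrays).
template <std::size_t M, std::size_t K, std::size_t N>
void matmul(const float (&a)[M][K], const float (&b)[K][N], float (&y)[M][N]) {
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j) {
      y[i][j] = 0.f;
      for (std::size_t k = 0; k < K; ++k) y[i][j] += a[i][k] * b[k][j];
    }
}

int main() {
  // Winograd transform matrices for F(2x2, 3x3) (Lavin & Gray).
  const float Bt[4][4] = {{1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}};
  const float B[4][4]  = {{1, 0, 0, 0}, {0, 1, -1, 1}, {-1, 1, 1, 0}, {0, 0, 0, -1}};
  const float G[4][3]  = {{1, 0, 0}, {.5f, .5f, .5f}, {.5f, -.5f, .5f}, {0, 0, 1}};
  const float Gt[3][4] = {{1, .5f, .5f, 0}, {0, .5f, -.5f, 0}, {0, .5f, .5f, 1}};
  const float At[2][4] = {{1, 1, 1, 0}, {0, 1, -1, -1}};
  const float A[4][2]  = {{1, 0}, {1, 1}, {1, -1}, {0, -1}};

  // Arbitrary 4x4 input tile d and 3x3 filter g.
  const float d[4][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 1, 2, 3}, {4, 5, 6, 7}};
  const float g[3][3] = {{1, 0, -1}, {2, 1, 0}, {0, -1, 1}};

  // U = G g G^T (filter transform), V = B^T d B (input transform).
  float Gg[4][3], U[4][4], Btd[4][4], V[4][4];
  matmul(G, g, Gg);   matmul(Gg, Gt, U);
  matmul(Bt, d, Btd); matmul(Btd, B, V);

  // m = U (elementwise *) V, then Y = A^T m A is the 2x2 output tile.
  float m[4][4], Atm[2][4], Y[2][2];
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j) m[i][j] = U[i][j] * V[i][j];
  matmul(At, m, Atm); matmul(Atm, A, Y);

  // Direct 3x3 cross-correlation over the same tile for comparison.
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) {
      float ref = 0.f;
      for (int r = 0; r < 3; ++r)
        for (int s = 0; s < 3; ++s) ref += d[i + r][j + s] * g[r][s];
      std::printf("Y[%d][%d] = %7.3f  direct = %7.3f  |diff| = %g\n",
                  i, j, Y[i][j], ref, std::fabs(Y[i][j] - ref));
    }
  return 0;
}
```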

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@jagrit06 (Member, Author):

This update on its own should not help training benchmarks like CIFAR, since the added kernel does not do well with large batch sizes. Further updates focused on batches should help that. In the meantime, this should improve batch size = 1 workloads and also eliminate at least one copy of the inputs that might otherwise be made for padding.
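
(As an aside on the padding point: the usual way to avoid materializing a padded copy of the input is to fold the zero padding into the load itself. The C++ below is a generic illustration of that trick, with made-up names and an assumed NHWC layout; it is not the actual Metal kernel code.)

```cpp
#include <cstddef>

// Illustration of "built-in" padding: rather than copying the input into a
// zero-padded buffer first, the load helper returns 0 for out-of-range
// coordinates. Names and the NHWC layout are illustrative, not MLX's.
struct InputView {
  const float* data; // NHWC-contiguous input
  int N, H, W, C;

  // Load x[n, h, w, c], treating coordinates outside [0, H) x [0, W) as
  // zero padding instead of reading from a pre-padded copy.
  float load_padded(int n, int h, int w, int c) const {
    if (h < 0 || h >= H || w < 0 || w >= W) return 0.f;
    std::size_t idx = ((std::size_t(n) * H + h) * W + w) * C + c;
    return data[idx];
  }
};
```

A Winograd input transform can then read tiles that straddle the image border without first making a separate padded copy of the input.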

Each row below lists the input shape, the weight shape, the dtype, the convolution parameters, and the measured time; the last column compares performance to PyTorch.

M3 Max Before:
(4, 128, 128, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.343, -21.22%
(4, 128, 128, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 1.013, -26.64%
(256, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.352, -27.20%
(256, 8, 8, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.340, -26.26%
(4, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 2.788, -5.70%
(1, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.074, +47.40%
(1, 16, 16, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.106, -2.51%
(1, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.780, -7.04%
(1, 128, 128, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 2.261, +16.55%
(1, 16, 16, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.157, -26.23%
(1, 16, 16, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 0.274, -27.10%

M3 Max After:
(4, 128, 128, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.342, -27.23%
(4, 128, 128, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 1.014, -26.30%
(256, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.346, -26.63%
(256, 8, 8, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.340, -25.95%
(4, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 2.179, +20.46%
(1, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.052, +108.45%
(1, 16, 16, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.059, +64.23%
(1, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.692, +4.99%
(1, 128, 128, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 1.634, +58.94%
(1, 16, 16, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.082, +51.14%
(1, 16, 16, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 0.158, +23.45%

M2 Ultra Before:
(4, 128, 128, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.271, -27.74%
(4, 128, 128, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.750, -35.50%
(256, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.272, -25.79%
(256, 8, 8, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.241, -22.92%
(4, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 1.972, -23.75%
(1, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.072, +45.15%
(1, 16, 16, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.121, -15.60%
(1, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.720, -32.91%
(1, 128, 128, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 2.147, -24.63%
(1, 16, 16, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.197, -32.59%
(1, 16, 16, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 0.355, -39.01%

M2 Ultra After:
(4, 128, 128, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.271, -27.77%
(4, 128, 128, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.750, -35.65%
(256, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.271, -25.83%
(256, 8, 8, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.241, -23.95%
(4, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 1.650, -8.85%
(1, 16, 16, 32), (32, 3, 3, 32), float32, (1, 1), (1, 1), 0.046, +122.89%
(1, 16, 16, 64), (64, 3, 3, 64), float32, (1, 1), (1, 1), 0.055, +84.06%
(1, 128, 128, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.572, -15.53%
(1, 128, 128, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 1.258, +28.53%
(1, 16, 16, 128), (128, 3, 3, 128), float32, (1, 1), (1, 1), 0.076, +67.96%
(1, 16, 16, 256), (256, 3, 3, 256), float32, (1, 1), (1, 1), 0.139, +55.04%

@awni (Member) commented Jan 27, 2025

I ran this ResNet inference benchmark on M2 Ultra.

Pre/post below.

TL;DR: nice speedup at batch size = 1, slight slowdown at 32+. Is that expected?

Before:

| Batch Size | Images-per-second | Milliseconds-per-image |
| --- | --- | --- |
| 1 | 608.598 | 1.643 |
| 2 | 1147.084 | 0.872 |
| 4 | 1426.277 | 0.701 |
| 8 | 1869.651 | 0.535 |
| 16 | 2133.113 | 0.469 |
| 32 | 2308.229 | 0.433 |
| 64 | 2529.965 | 0.395 |

After:

| Batch Size | Images-per-second | Milliseconds-per-image |
| --- | --- | --- |
| 1 | 747.178 | 1.338 |
| 2 | 1146.119 | 0.873 |
| 4 | 1539.774 | 0.649 |
| 8 | 1921.938 | 0.520 |
| 16 | 2035.330 | 0.491 |
| 32 | 2062.188 | 0.485 |
| 64 | 2106.746 | 0.475 |

@awni (Member) commented Jan 27, 2025

> since the added kernel does not do well with large batch sizes. Further updates focused on batches should help that.

I guess that's what you meant. Should we just dispatch to the old kernel with batch size 32+?

@jagrit06 (Member, Author):

> since the added kernel does not do well with large batch sizes. Further updates focused on batches should help that.
>
> I guess that's what you meant. Should we just dispatch to the old kernel with batch size 32+?

Done! I moved batch sizes 32+ back to the old routing for now.
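
(For illustration, the routing decision amounts to something like the sketch below; the threshold of 32 comes from this exchange, while the struct and function names are hypothetical stand-ins rather than the actual MLX code.)

```cpp
#include <cstdio>

// Hypothetical sketch of batch-size-based routing between the new fused
// Winograd kernel and the previous Winograd path. Names are made up.
struct ConvParams {
  int N, H, W, C, O; // batch, spatial sizes, input channels, output channels
};

// Stand-ins for the two code paths; the real kernels would be launched here.
void winograd_fused(const ConvParams& p)    { std::printf("fused path, N=%d\n", p.N); }
void winograd_explicit(const ConvParams& p) { std::printf("explicit path, N=%d\n", p.N); }

void dispatch_winograd(const ConvParams& p) {
  if (p.N < 32) {
    // Small batches: the fused kernel avoids separate transform launches
    // and the padded input copy.
    winograd_fused(p);
  } else {
    // Large batches: keep the previous routing until the fused kernel
    // handles them well.
    winograd_explicit(p);
  }
}
```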

@jagrit06 requested review from angeloskath and awni and removed the review request for angeloskath on January 28, 2025 at 19:11
@angeloskath (Member) left a comment

Awesome job as always :-) The MMATiles result in very readable kernels.

I left a few comments; the main one is an issue in the op with the routing to matmul, and the rest are either nitpicks or discussion.


// Iterate over C
for (int c = 0; c < params.C; c += BC) {
#define tmp_load_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
@angeloskath (Member):

Why not implement these outside the kernel?

Also nitpick: h* FA* BC -> h * FA * BC
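
(For illustration only: the same index arithmetic written as a helper outside the kernel body might look like the sketch below; FA, BC, and BO become template parameters here, and the name simply mirrors the macro.)

```cpp
// Sketch of the suggestion: the index macro as a constexpr helper defined
// outside the kernel. FA, BC, and BO are the tile/block sizes that are
// compile-time constants in the surrounding kernel; names are illustrative.
template <int FA, int BC, int BO>
constexpr int tmp_load_wt_idx(int o, int h, int w, int c) {
  return h * FA * BC * BO + w * BC * BO + c * BO + o;
}

// Hypothetical usage inside the load loop:
//   tmp[tmp_load_wt_idx<FA, BC, BO>(o, h, w, c)] = ...;
```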

@jagrit06 (Member, Author):

These ended up here because I was experimenting with the strides we use for loading, and it's easier to follow with the macro right next to where it's read.

@angeloskath (Member) left a comment

Left a comment on the routing to matmul. Otherwise looks great!

/*B_batch_stride = */ 0,
/*matrix_stride_out = */ 0,
/*copies = */ empty_copies);
}
@angeloskath (Member):

I would say the `if` below needs an `else if`.
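
(Generic illustration of the point, not the actual op code: a bare `if` after the first dispatch can fire again for a case the first branch already handled, so the second check should be an `else if` to keep the paths mutually exclusive.)

```cpp
#include <cstdio>

// Stand-ins for the two dispatch targets; names are made up.
void run_special_path() { std::printf("special path\n"); }
void run_general_path() { std::printf("general path\n"); }

void route(bool special_case_ok, bool general_case_ok) {
  if (special_case_ok) {
    run_special_path();
  } else if (general_case_ok) { // a plain `if` here could dispatch the same
                                // input a second time through the general path
    run_general_path();
  }
}
```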

@jagrit06 merged commit 2dc307f into main on Feb 14, 2025 (5 checks passed).
@jagrit06 deleted the winograd_tmp branch on February 14, 2025 at 21:08.