[DispatchCreation] Set split reduction size for GEMM with large k dim #22357
Conversation
Signed-off-by: yzhang93 <[email protected]>
Force-pushed from f74c75e to 3bdf63c
Signed-off-by: yzhang93 <[email protected]>
```mlir
// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-split-reduction-sizes))" --split-input-file > %t
// RUN: FileCheck %s < %t
```
We should pipe to FileCheck.
Yeah, sure. This was simply copied from other tests, but I can modify those as well.
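For reference, the piped form the reviewer is suggesting would look something like the following (a sketch only; the pass-pipeline string is taken from the test above, and piping `iree-opt` output directly into `FileCheck` is the usual lit idiom rather than going through a temporary `%t` file):

```mlir
// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-split-reduction-sizes))" --split-input-file | FileCheck %s
```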
```cpp
for (int64_t i = 0; i < tileSizes.size(); i++) {
  int64_t lowerBound = llvm::divideCeil(tileSizes[i], limitParallelLoops);
```
Use `enumerate(tileSizes)`.
Done.
MaheshRavishankar
left a comment
Just a few nits and clarifications. Otherwise looks good.
```cpp
  return std::nullopt;
}

if (linalgOp.getNumParallelLoops() < 2) {
```
Why? You plan to revisit that later?
This is to guarantee that the op is matmul-like, similar to this check:
```cpp
linalgOp.getNumParallelLoops() >= 2;
```
That is more for naming purposes. It's fine for now, but if there is a batch dimension, that will be three parallel loops (and you are accounting for that below).
```cpp
SmallVector<int64_t> tileSizes = std::move(*maybeSizes);
int64_t outputSize = mSize * nSize * batchSize;
int64_t limitParallelLoops;
if (outputSize < 16 * 16) {
```
I am not fully understanding the link between outputSize and limitParallelLoops: limitParallelLoops seems to increase as outputSize gets smaller. I think a better name might help, and I can suggest one once I understand it better.
limitParallelLoops is the cap on the number of parallel loops produced by split reduction; it does not account for the number of workgroups. When the output size is small, the op is likely to be distributed to fewer workgroups, so we need more parallel loops from split reduction to fill the machine.
Ok, that makes sense.
Signed-off-by: yzhang93 <[email protected]>
[DispatchCreation] Set split reduction size for GEMM with large k dim (iree-org#22357)

This PR adds basic support for setting the split reduction size for matmul-like ops with a large K dim. Note that the constant thresholds are empirically chosen based on limited data (1x1 filter weight backward convs) and may not generalize to all cases. It's challenging to find a single threshold that applies to all shapes. The bottom line is to improve performance for extremely large K cases while not degrading many smaller shapes.

Signed-off-by: yzhang93 <[email protected]>
… weight backward convs (#22491)

This PR is a follow-up for #22275. It removes the constraint of splitting only the input channel dimension and adds support for splitting across multiple dimensions. The heuristics for setting multi-dimension tile sizes are similar to those for GEMM in #22357. More than half of the tracked weight backward shapes benefit from this change.

Example runtime comparison for `convbfp16 -n 16 -c 16 -H 225 -W 225 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 4 -t 1 --in_layout NHWC --out_layout NHWC --fil_layout NHWC --iter 100`:

- Without split reduction: 19352.8 ms
- Split only the input channel dimension: 1445.1 ms
- Split multiple reduction dimensions: 371.7 ms

Signed-off-by: yzhang93 <[email protected]>