This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Optimize NMS part 2 #14352
Merged
Contributor
@mxnet-label-bot add [Operator, pr-awaiting-review]
zhreshold
approved these changes
Mar 7, 2019
7 tasks
arcadiaphy
reviewed
Mar 7, 2019
    int num_batch) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1;
Member
Using the __ldg intrinsic will fail to compile on some early CUDA architectures.
Member
Author
It will fail on sm 3.0 and earlier (so Fermi and the first Kepler). I can put an #ifdef there, but do we care about those?
Member
In Makefile, sm 30 is in KNOWN_CUDA_ARCHS.
https://github.com/apache/incubator-mxnet/blob/master/Makefile#L385
Member
Author
Then we do ;-). I will introduce the guard, thanks!
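For readers following the thread, the guard discussed here could look roughly like the sketch below. `__ldg` is only available when compiling for compute capability 3.5 or higher, so the load falls back to a plain global read everywhere else. The helper `load_readonly`, the demo kernel `previous_id_kernel`, and the host wrapper `previous_ids` are hypothetical names chosen for illustration and are not taken from the PR.

```cuda
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

// Hypothetical helper: use __ldg where the target architecture supports it
// (compute capability >= 3.5) and a plain global load otherwise, so the
// same source still compiles for sm_30.
__device__ __forceinline__ int32_t load_readonly(const int32_t* ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
  return __ldg(ptr);
#else
  return *ptr;
#endif
}

// Demo kernel mirroring the reviewed line: each thread reads the batch id
// of the element before it, or -1 for the first element.
__global__ void previous_id_kernel(int32_t* previous,
                                   const int32_t* valid_batch_id,
                                   size_t N) {
  size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid < N) {
    previous[tid] = tid > 0 ? load_readonly(valid_batch_id + tid - 1) : -1;
  }
}

// Abort with a message on any CUDA runtime error.
static void check(cudaError_t err) {
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    std::exit(EXIT_FAILURE);
  }
}

// Hypothetical host wrapper used only to exercise the kernel.
std::vector<int32_t> previous_ids(const std::vector<int32_t>& ids) {
  const size_t N = ids.size();
  std::vector<int32_t> out(N);
  if (N == 0) return out;  // nothing to launch

  int32_t* d_in = nullptr;
  int32_t* d_out = nullptr;
  check(cudaMalloc(&d_in, N * sizeof(int32_t)));
  check(cudaMalloc(&d_out, N * sizeof(int32_t)));
  check(cudaMemcpy(d_in, ids.data(), N * sizeof(int32_t),
                   cudaMemcpyHostToDevice));

  const unsigned int threads = 256;
  const unsigned int blocks =
      static_cast<unsigned int>((N + threads - 1) / threads);
  previous_id_kernel<<<blocks, threads>>>(d_out, d_in, N);
  check(cudaGetLastError());

  // cudaMemcpy on the default stream waits for the kernel to finish.
  check(cudaMemcpy(out.data(), d_out, N * sizeof(int32_t),
                   cudaMemcpyDeviceToHost));
  check(cudaFree(d_in));
  check(cudaFree(d_out));
  return out;
}
```

Wrapping the intrinsic in one small device function keeps the preprocessor check in a single place. The `defined(__CUDA_ARCH__)` test also makes the host compilation pass take the plain-load branch.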
vdantu
pushed a commit
to vdantu/incubator-mxnet
that referenced
this pull request
Mar 31, 2019
* Optimize NMS part 2 * Guarding ldg intrinsics
nswamy
pushed a commit
that referenced
this pull request
Apr 5, 2019
* Optimize NMS part 2 * Guarding ldg intrinsics
haohuanw
pushed a commit
to haohuanw/incubator-mxnet
that referenced
this pull request
Jun 23, 2019
* Optimize NMS part 2 * Guarding ldg intrinsics
Description
This PR moves the batch_start calculation in the BoxNMSForward op to a custom kernel, which is much faster than the mshadow-generated one. In the MaskRCNN model it cuts the runtime of that part from 20 ms to 2 us, speeding up single-GPU training by 20% in fp16 mode.
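One way such a kernel can work is sketched below. It rests on assumptions that the description does not state: `valid_batch_id` holds the batch index of every valid box, sorted in non-decreasing order, with all values in `[0, num_batch)`, and `batch_start` has `num_batch + 1` entries so that batch `b` owns the range `[batch_start[b], batch_start[b + 1])`. The names `batch_start_sketch` and `compute_batch_start` are illustrative. Plain loads are used instead of `__ldg` to keep the sketch portable across architectures.

```cuda
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

// Sketch: one thread per valid box. A thread whose batch id is larger than
// its predecessor's marks the start of every batch in (previous, mine].
// The last thread also fills the entries for any trailing empty batches and
// the final end marker. Every entry is written by exactly one thread.
__global__ void batch_start_sketch(int32_t* batch_start,
                                   const int32_t* valid_batch_id,
                                   size_t N,
                                   int num_batch) {
  size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid < N) {
    const int32_t previous = tid > 0 ? valid_batch_id[tid - 1] : -1;
    const int32_t mine = valid_batch_id[tid];
    for (int32_t b = previous + 1; b <= mine; ++b) {
      batch_start[b] = static_cast<int32_t>(tid);
    }
    if (tid == N - 1) {
      for (int32_t b = mine + 1; b <= num_batch; ++b) {
        batch_start[b] = static_cast<int32_t>(N);
      }
    }
  }
}

// Abort with a message on any CUDA runtime error.
static void check(cudaError_t err) {
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    std::exit(EXIT_FAILURE);
  }
}

// Hypothetical host wrapper: returns num_batch + 1 offsets.
std::vector<int32_t> compute_batch_start(const std::vector<int32_t>& ids,
                                         int num_batch) {
  const size_t N = ids.size();
  std::vector<int32_t> out(static_cast<size_t>(num_batch) + 1, 0);
  if (N == 0) return out;  // no valid boxes: every batch is empty

  int32_t* d_ids = nullptr;
  int32_t* d_start = nullptr;
  check(cudaMalloc(&d_ids, N * sizeof(int32_t)));
  check(cudaMalloc(&d_start, out.size() * sizeof(int32_t)));
  check(cudaMemcpy(d_ids, ids.data(), N * sizeof(int32_t),
                   cudaMemcpyHostToDevice));

  const unsigned int threads = 256;
  const unsigned int blocks =
      static_cast<unsigned int>((N + threads - 1) / threads);
  batch_start_sketch<<<blocks, threads>>>(d_start, d_ids, N, num_batch);
  check(cudaGetLastError());

  // cudaMemcpy on the default stream waits for the kernel to finish.
  check(cudaMemcpy(out.data(), d_start, out.size() * sizeof(int32_t),
                   cudaMemcpyDeviceToHost));
  check(cudaFree(d_ids));
  check(cudaFree(d_start));
  return out;
}
```

For example, with ids `{0, 0, 1, 1, 1, 3}` and `num_batch = 4`, batch 2 has no boxes, so its start and end offsets coincide and the result is `{0, 2, 5, 5, 6}`.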