-
Notifications
You must be signed in to change notification settings - Fork 207
Add support for Warp Shuffle-based reductions for arch >= 3.0 #750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
killeent
wants to merge
10
commits into
torch:master
Choose a base branch
from
killeent:min-smem
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
23e7187
block wide reduction with multiple values to reduce at once
killeent 2592252
add function for determining the amount of smem required for a reduct…
killeent e09788f
move random functions to use reduce smem, round up to smem multiple o…
killeent c2717fe
use reduceSmemSize in reduceAll, round up to warpsize multiple thread…
killeent abbdfc5
pass thcstate to reduceSmemSize
killeent c7e8c90
small changes to mode --> make sure we always have at least a warp # …
killeent 720d24a
small fix in nthreadlocal; add doc
killeent 875601f
implement warp shuffle based reduction; enable for arch >= 3.0
killeent 28cb961
rebase
killeent 576ae9f
only first warp executes last stage of warp reduction
killeent File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the old or new code here.
Isn't
numValsalways less than blockDim.x, because it is the number of threads with N active values? In other words, it has nothing to do withN, because allnumValsthreads have N values?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the new code:
numValsis the overall slice size for the reduction. So IfnumValsis 512 andN= 2 then the first 256 threads in the block have valid input values. ForreduceBlock, numVals represents the number of threads whose values should be considered valid. So in the above example, the first 256 threads have values that should be reduced, hence the division.The old code was doing something incorrect. However, because the local reduction uses init, it still succeeded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but then what about the case where numVals is not a multiple of N?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case the
THCCeilDivshould round it up - if we had 513 values in the above case then the first 257 threads should have valid input values.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but then only one of the 2 input values for the 257th thread would then be valid? where's the guarantee that all provided input values are either valid reduction values, or identity values? if they're identity values for the reduction, then we don't really need numVals at all, except possibly to handle tiny reduction cases (smaller than the block size) for slightly greater efficiency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just means its part of the input that we need to consider. The resulting value from the 257th thread is the combination of the input value and a identity value, but the presence of the input value means it must take part in the reduction.
Yes, in theory we could imagine a case where block size > numVals -> this is common in the code for mode, example, where we round up.