-
Notifications
You must be signed in to change notification settings - Fork 13.7k
vulkan: implement ADD1, ARANGE, FILL, SOFTPLUS, STEP, ROUND, CEIL, FLOOR, TRUNC #17319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
67689fd to
e59509c
Compare
ggml/src/ggml-vulkan/ggml-vulkan.cpp
Outdated
| CREATE_UNARY_RTE(exp) | ||
| #undef CREATE_UNARY_RTE | ||
|
|
||
| ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each invocation does 2 iterations and the WG has 256 threads, so I think this should be 512?
| layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; | ||
|
|
||
| void main() { | ||
| const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these index calculations rely on this logic in ggml_vk_op_f32, but arange doesn't go through ggml_vk_op_f32:
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}
e59509c to
26eeb8a
Compare
|
@jeffbolznv thanks for the review. Addressed and pushed a new version |
| layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; | ||
|
|
||
| void main() { | ||
| const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this one has the same issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed as well! Thanks
26eeb8a to
d7df09c
Compare
jeffbolznv
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes look good to me now. Do you think it makes sense to add a few larger test cases for these to catch those bugs?
do you've any suggestions on what to add? I've tried a few new cases with more elements, but they don' fail with the older version of the PR |
|
Nothing specific, just that they would need to be more than 256k elements. |
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
d7df09c to
2300634
Compare
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2300634 to
15bcb5e
Compare
added new test cases |
mostly mechanical changes, except ROUND that doesn't match directly to Vulkan as there is no equivalent rounding mode (at least didn't manage to find it)