Skip to content

Add missing __syncwarp() in reduce kernel + CUDA 13.0 build fix#15

Merged
JeffreyXiang merged 2 commits intoJeffreyXiang:mainfrom
cuzelac:fix/syncwarp-reduce-kernel
Mar 11, 2026
Merged

Add missing __syncwarp() in reduce kernel + CUDA 13.0 build fix#15
JeffreyXiang merged 2 commits intoJeffreyXiang:mainfrom
cuzelac:fix/syncwarp-reduce-kernel

Conversation

@cuzelac
Copy link
Copy Markdown
Contributor

@cuzelac cuzelac commented Mar 8, 2026

Summary

Two fixes discovered while debugging Trellis2 on NVIDIA Blackwell (RTX 5090, sm_120):

1. Missing __syncwarp() in warp-level reduction

reduce_code_cuda_kernel in neighbor_map.cu performs a warp-level reduction over shared memory without __syncwarp() between iterations. Each iteration reads buf[threadIdx.x + cur_len] after the prior iteration wrote to buf[threadIdx.x]. Without a warp barrier, there is no guarantee the write is visible to other threads before the next read.

While current NVIDIA hardware executes warps in lockstep, this is undefined behavior per the CUDA programming model and may break on future architectures.

2. -allow-unsupported-compiler for CUDA 13.0

CUDA 13.0 with MSVC 2025 (Visual Studio 18) fails to compile without this flag, as nvcc only officially supports MSVC 2019-2022.

Testing

  • Built and tested on RTX 5090 (sm_120), PyTorch 2.10.0+cu130, CUDA 13.0, Windows
  • Full Trellis2 image-to-3D pipeline including sparse convolution operations
  • Multiple successful runs

cuzelac and others added 2 commits March 8, 2026 01:43
The warp-level reduction loop in reduce_code_cuda_kernel reads from
shared memory at buf[threadIdx.x + cur_len] after a prior iteration
wrote to buf[threadIdx.x]. Without __syncwarp(), there is no guarantee
that the write is visible to other threads in the warp before the next
iteration reads it. While current NVIDIA hardware executes warps in
lockstep, this is undefined behavior per the CUDA programming model
and may break on future architectures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
CUDA 13.0 with MSVC 2025 (Visual Studio 18) fails to compile without
this flag, as nvcc only officially supports MSVC 2019-2022. The flag
allows compilation on newer toolchains without affecting runtime
behavior.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@cuzelac
Copy link
Copy Markdown
Contributor Author

cuzelac commented Mar 8, 2026

This may also address:

@JeffreyXiang
Copy link
Copy Markdown
Owner

Thanks! Merged

@JeffreyXiang JeffreyXiang merged commit 9f2f050 into JeffreyXiang:main Mar 11, 2026
@cuzelac
Copy link
Copy Markdown
Contributor Author

cuzelac commented Mar 11, 2026

Happy to contribute - thanks for your work on this!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants