diff --git a/CMakeLists.txt b/CMakeLists.txt index aaf1bde..7d3f3a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,17 +106,41 @@ if(APPLE) target_compile_definitions(plume PRIVATE PLUME_APPLE_RETINA_ENABLED) endif() + # Find Metal Shader Converter runtime headers + # Prefer environment variable, fall back to default installation path + if(DEFINED ENV{METAL_SHADER_CONVERTER}) + set(METAL_SHADER_CONVERTER_PATH "$ENV{METAL_SHADER_CONVERTER}") + else() + set(METAL_SHADER_CONVERTER_PATH "/usr/local/lib/metal_irconverter_runtime") + endif() + + if(EXISTS "${METAL_SHADER_CONVERTER_PATH}/metal_irconverter_runtime") + set(METAL_SHADER_CONVERTER_INCLUDE "${METAL_SHADER_CONVERTER_PATH}") + message(STATUS "Plume - Metal Shader Converter runtime: ${METAL_SHADER_CONVERTER_INCLUDE}") + elseif(EXISTS "/usr/local/include/metal_irconverter_runtime") + set(METAL_SHADER_CONVERTER_INCLUDE "/usr/local/include") + message(STATUS "Plume - Metal Shader Converter runtime: ${METAL_SHADER_CONVERTER_INCLUDE}") + else() + message(WARNING "Plume - Metal Shader Converter runtime headers not found. Ray tracing will be disabled.") + set(METAL_SHADER_CONVERTER_INCLUDE "") + endif() + target_include_directories(plume PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/contrib/metal-cpp ) + if(METAL_SHADER_CONVERTER_INCLUDE) + target_include_directories(plume PRIVATE ${METAL_SHADER_CONVERTER_INCLUDE}) + target_compile_definitions(plume PRIVATE PLUME_METAL_RAYTRACING_ENABLED) + endif() + # Compile and embed internal Metal shaders include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/PlumeShaders.cmake) plume_build_file_to_c() file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/shaders") - _plume_compile_metal_impl(plume "${CMAKE_CURRENT_SOURCE_DIR}/shaders/plume_clear.metal" plume_clear) - _plume_compile_metal_impl(plume "${CMAKE_CURRENT_SOURCE_DIR}/shaders/plume_resolve.metal" plume_resolve) + _plume_compile_native_metal(plume "${CMAKE_CURRENT_SOURCE_DIR}/shaders/plume_clear.metal" plume_clear) + _plume_compile_native_metal(plume "${CMAKE_CURRENT_SOURCE_DIR}/shaders/plume_resolve.metal" plume_resolve) endif() # Add examples if requested diff --git a/cmake/PlumeShaders.cmake b/cmake/PlumeShaders.cmake index 46408d7..a24033a 100644 --- a/cmake/PlumeShaders.cmake +++ b/cmake/PlumeShaders.cmake @@ -1,6 +1,11 @@ # PlumeShaders.cmake # Public shader compilation API for Plume RHI # +# Architecture: +# Layer 1 (Primitives): Single-operation functions for each tool +# Layer 2 (Pipelines): Platform-aware composition of primitives +# Layer 3 (Public API): User-facing functions with nice defaults +# # Usage: # include(path/to/plume/cmake/PlumeShaders.cmake) # plume_shaders_init() @@ -8,32 +13,28 @@ # plume_compile_vertex_shader(my_target shaders/main.vert.hlsl mainVert VSMain) # plume_compile_pixel_shader(my_target shaders/main.frag.hlsl mainFrag PSMain) # plume_compile_compute_shader(my_target shaders/compute.hlsl computeShader CSMain) -# -# Advanced usage with extra options: -# plume_compile_pixel_shader(my_target shaders/main.frag.hlsl mainFrag PSMain -# EXTRA_ARGS -D MULTISAMPLING -O0 -# INCLUDE_DIRS ${CMAKE_SOURCE_DIR}/src -# SHADER_MODEL 6_3) -# -# # Spec constants mode (SPIRV + Metal only, no DXIL): -# plume_compile_pixel_shader(my_target shaders/main.frag.hlsl mainFrag PSMain SPEC_CONSTANTS) -# -# # Library shader (DXIL only, Windows): -# plume_compile_library_shader(my_target shaders/lib.hlsl libShader) -# -# Bring your own DXC/SPIRV-Cross (set before calling plume_shaders_init): -# set(PLUME_DXC_EXECUTABLE "/path/to/dxc") -# set(PLUME_DXC_LIB_DIR "/path/to/lib") # macOS/Linux only +# plume_compile_rt_shader(my_target shaders/rt.hlsl rtShaders) # # Output: -# HLSL shaders compile to: +# Stage shaders compile to: # - SPIR-V (all platforms): {OUTPUT_NAME}BlobSPIRV in shaders/{OUTPUT_NAME}.hlsl.spirv.h # - DXIL (Windows only): {OUTPUT_NAME}BlobDXIL in shaders/{OUTPUT_NAME}.hlsl.dxil.h -# - Metal (Apple only): {OUTPUT_NAME}BlobMSL in shaders/{OUTPUT_NAME}.metal.h (via SPIR-V cross-compilation) +# - Metal (Apple only): {OUTPUT_NAME}BlobMSL in shaders/{OUTPUT_NAME}.hlsl.metal.h +# +# RT library shaders compile to: +# - Windows: {OUTPUT_NAME}BlobDXIL in shaders/{OUTPUT_NAME}.hlsl.dxil.h +# - Linux: {OUTPUT_NAME}BlobSPIRV in shaders/{OUTPUT_NAME}.hlsl.spirv.h +# - Apple: {OUTPUT_NAME}BlobMetalLib in shaders/{OUTPUT_NAME}.metallib.h include("${CMAKE_CURRENT_LIST_DIR}/modules/PlumeFileToC.cmake") include("${CMAKE_CURRENT_LIST_DIR}/modules/PlumeDXC.cmake") include("${CMAKE_CURRENT_LIST_DIR}/modules/PlumeSpirvCross.cmake") +include("${CMAKE_CURRENT_LIST_DIR}/modules/PlumeRootSignature.cmake") +include("${CMAKE_CURRENT_LIST_DIR}/modules/PlumeCombineRTMetallibs.cmake") + +# ============================================================================ +# Initialization +# ============================================================================ # Initialize shader compilation infrastructure # Call this once before using other plume_compile_* functions @@ -54,287 +55,564 @@ function(plume_shaders_init) plume_fetch_spirv_cross() endif() + # Build helper tools (Apple RT shaders) + if(APPLE) + plume_build_generate_root_signature() + plume_build_combine_rt_metallibs() + endif() + plume_build_file_to_c() # Create output directory file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/shaders") endfunction() -# Internal: Compile HLSL to a specific format (spirv or dxil) -# Optional args: INCLUDE_DIRS, EXTRA_ARGS, SHADER_MODEL, OUTPUT_DIR -function(_plume_compile_hlsl_impl TARGET_NAME SHADER_SOURCE SHADER_TYPE OUTPUT_NAME OUTPUT_FORMAT ENTRY_POINT) - # Parse optional arguments - cmake_parse_arguments(PARSE_ARGV 6 ARG "" "SHADER_MODEL;OUTPUT_DIR" "INCLUDE_DIRS;EXTRA_ARGS") +# ============================================================================ +# Layer 1: Primitives - Single-operation functions for each tool +# ============================================================================ + +# Run DXC to compile HLSL to SPIRV or DXIL +# Arguments: +# TARGET - CMake target (for dependencies) +# SOURCE - HLSL source file +# OUTPUT - Output binary file path +# PROFILE - Shader profile (vs_6_0, ps_6_0, lib_6_3, etc.) +# ENTRY_POINT - Entry point name (empty string for libraries) +# FORMAT - Output format: "spirv" or "dxil" +# Options: +# SPIRV_RT - Add SPIRV ray tracing extensions (vulkan1.1spirv1.4 + SPV_KHR_ray_tracing) +# INVERT_Y - Add -fvk-invert-y for vertex shaders +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +function(_plume_dxc TARGET SOURCE OUTPUT PROFILE ENTRY_POINT FORMAT) + cmake_parse_arguments(ARG "SPIRV_RT;INVERT_Y" "" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) plume_get_dxc_command(DXC_CMD) - if(OUTPUT_FORMAT STREQUAL "spirv") - set(OUTPUT_EXT "spv") - set(BLOB_SUFFIX "SPIRV") - set(FORMAT_FLAGS ${PLUME_DXC_SPV_OPTS}) - elseif(OUTPUT_FORMAT STREQUAL "dxil") - set(OUTPUT_EXT "dxil") - set(BLOB_SUFFIX "DXIL") + # Base format flags + if(FORMAT STREQUAL "spirv") + if(ARG_SPIRV_RT) + # SPIRV with ray tracing extensions + set(FORMAT_FLAGS "-spirv" "-fspv-target-env=vulkan1.1spirv1.4" + "-fspv-extension=SPV_KHR_ray_tracing" + "-fspv-extension=SPV_EXT_descriptor_indexing" + "-fvk-use-dx-layout") + else() + # Standard SPIRV + set(FORMAT_FLAGS ${PLUME_DXC_SPV_OPTS}) + endif() + elseif(FORMAT STREQUAL "dxil") set(FORMAT_FLAGS ${PLUME_DXC_DXIL_OPTS}) else() - message(FATAL_ERROR "Unknown output format: ${OUTPUT_FORMAT}") - endif() - - # Use custom output directory if provided - if(ARG_OUTPUT_DIR) - set(OUT_DIR "${ARG_OUTPUT_DIR}") - else() - set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") + message(FATAL_ERROR "_plume_dxc: Unknown format '${FORMAT}'. Use 'spirv' or 'dxil'.") endif() - file(MAKE_DIRECTORY "${OUT_DIR}") - set(SHADER_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.${OUTPUT_EXT}") - set(C_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.${OUTPUT_FORMAT}.c") - set(H_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.${OUTPUT_FORMAT}.h") - - # Use provided shader model or default to 6_0 - if(ARG_SHADER_MODEL) - set(SM_VERSION "${ARG_SHADER_MODEL}") - else() - set(SM_VERSION "6_0") - endif() - - # Determine shader profile and type-specific args - if(SHADER_TYPE STREQUAL "vertex") - set(SHADER_PROFILE "vs_${SM_VERSION}") - set(DXC_TYPE_ARGS "-fvk-invert-y") - elseif(SHADER_TYPE STREQUAL "pixel" OR SHADER_TYPE STREQUAL "fragment") - set(SHADER_PROFILE "ps_${SM_VERSION}") - set(DXC_TYPE_ARGS "") - elseif(SHADER_TYPE STREQUAL "compute") - set(SHADER_PROFILE "cs_${SM_VERSION}") - set(DXC_TYPE_ARGS "") - elseif(SHADER_TYPE STREQUAL "geometry") - set(SHADER_PROFILE "gs_${SM_VERSION}") - set(DXC_TYPE_ARGS "") - elseif(SHADER_TYPE STREQUAL "ray") - set(SHADER_PROFILE "lib_6_3") - set(DXC_TYPE_ARGS ${PLUME_DXC_RT_OPTS}) - elseif(SHADER_TYPE STREQUAL "library") - set(SHADER_PROFILE "lib_${SM_VERSION}") - set(DXC_TYPE_ARGS "-D;LIBRARY") - else() - message(FATAL_ERROR "Unknown shader type: ${SHADER_TYPE}. Use: vertex, pixel/fragment, compute, geometry, ray, or library") + # Type-specific flags + set(TYPE_FLAGS "") + if(ARG_INVERT_Y AND FORMAT STREQUAL "spirv") + list(APPEND TYPE_FLAGS "-fvk-invert-y") endif() - # Build include directory flags + # Include directories set(INCLUDE_FLAGS "") - foreach(INCLUDE_DIR ${ARG_INCLUDE_DIRS}) - list(APPEND INCLUDE_FLAGS "-I${INCLUDE_DIR}") + foreach(DIR ${ARG_INCLUDE_DIRS}) + list(APPEND INCLUDE_FLAGS "-I${DIR}") endforeach() - set(BLOB_NAME "${OUTPUT_NAME}Blob${BLOB_SUFFIX}") - - # Build entry point args (library shaders don't have entry points) + # Entry point args (libraries don't have entry points) if(ENTRY_POINT STREQUAL "") - set(ENTRY_POINT_ARGS "") + set(ENTRY_ARGS "") else() - set(ENTRY_POINT_ARGS "-E" "${ENTRY_POINT}") + set(ENTRY_ARGS "-E" "${ENTRY_POINT}") endif() - # Compile using DXC add_custom_command( - OUTPUT "${SHADER_OUTPUT}" - COMMAND ${DXC_CMD} ${PLUME_DXC_COMMON_OPTS} ${INCLUDE_FLAGS} ${ENTRY_POINT_ARGS} -T ${SHADER_PROFILE} - ${FORMAT_FLAGS} ${DXC_TYPE_ARGS} ${ARG_EXTRA_ARGS} -Fo "${SHADER_OUTPUT}" "${SHADER_SOURCE}" - DEPENDS "${SHADER_SOURCE}" - COMMENT "Compiling ${SHADER_TYPE} shader ${OUTPUT_NAME} to ${OUTPUT_FORMAT}" + OUTPUT "${OUTPUT}" + COMMAND ${DXC_CMD} + ${PLUME_DXC_COMMON_OPTS} + ${INCLUDE_FLAGS} + ${ENTRY_ARGS} + -T ${PROFILE} + ${FORMAT_FLAGS} + ${TYPE_FLAGS} + ${ARG_EXTRA_ARGS} + -Fo "${OUTPUT}" + "${SOURCE}" + DEPENDS "${SOURCE}" + COMMENT "DXC: ${SOURCE} -> ${FORMAT}" VERBATIM ) +endfunction() - # Generate C header +# Run spirv-cross to convert SPIRV to Metal source +# Arguments: +# TARGET - CMake target (for dependencies) +# INPUT - SPIRV binary file +# OUTPUT - Metal source file path +function(_plume_spirv_cross TARGET INPUT OUTPUT) add_custom_command( - OUTPUT "${C_OUTPUT}" "${H_OUTPUT}" - COMMAND plume_file_to_c "${SHADER_OUTPUT}" "${BLOB_NAME}" "${C_OUTPUT}" "${H_OUTPUT}" - DEPENDS "${SHADER_OUTPUT}" plume_file_to_c - COMMENT "Generating C header for ${OUTPUT_NAME} ${OUTPUT_FORMAT}" + OUTPUT "${OUTPUT}" + COMMAND plume_spirv_cross_msl "${INPUT}" "${OUTPUT}" + DEPENDS "${INPUT}" plume_spirv_cross_msl + COMMENT "SPIRV-Cross: ${INPUT} -> Metal" VERBATIM ) - - target_sources(${TARGET_NAME} PRIVATE "${C_OUTPUT}") - target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_BINARY_DIR}") endfunction() -# Internal: Compile SPIR-V to Metal via spirv-cross -# Optional args: OUTPUT_DIR -# Note: For HLSL sources, OUTPUT_NAME should include .hlsl suffix for proper naming -function(_plume_compile_spirv_to_metal_impl TARGET_NAME SPIRV_FILE OUTPUT_NAME) - cmake_parse_arguments(PARSE_ARGV 3 ARG "" "OUTPUT_DIR" "") - - # Use custom output directory if provided - if(ARG_OUTPUT_DIR) - set(OUT_DIR "${ARG_OUTPUT_DIR}") - else() - set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") - endif() - file(MAKE_DIRECTORY "${OUT_DIR}") - - # Use OUTPUT_NAME.hlsl for naming to match RT64's expected paths - set(METAL_SOURCE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metal") - set(IR_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.ir") - set(METALLIB_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metallib") - set(C_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metal.c") - set(H_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metal.h") - - # Get deployment target for Metal compilation +# Run Metal compiler to create metallib from source +# Arguments: +# TARGET - CMake target (for dependencies) +# INPUT - Metal source file +# OUTPUT - Metallib output file +function(_plume_metal_compile TARGET INPUT OUTPUT) + get_filename_component(OUTPUT_DIR "${OUTPUT}" DIRECTORY) + get_filename_component(OUTPUT_NAME "${OUTPUT}" NAME_WE) + set(IR_FILE "${OUTPUT_DIR}/${OUTPUT_NAME}.ir") + + # Get deployment target if(CMAKE_OSX_DEPLOYMENT_TARGET) - set(METAL_VERSION_FLAG "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") + set(VERSION_FLAG "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") else() - set(METAL_VERSION_FLAG "") + set(VERSION_FLAG "") endif() - # Convert SPIR-V to Metal source + # Compile to IR add_custom_command( - OUTPUT "${METAL_SOURCE}" - COMMAND plume_spirv_cross_msl "${SPIRV_FILE}" "${METAL_SOURCE}" - DEPENDS "${SPIRV_FILE}" plume_spirv_cross_msl - COMMENT "Converting ${OUTPUT_NAME} SPIR-V to Metal" + OUTPUT "${IR_FILE}" + COMMAND xcrun -sdk macosx metal ${VERSION_FLAG} -o "${IR_FILE}" -c "${INPUT}" + DEPENDS "${INPUT}" + COMMENT "Metal: ${INPUT} -> IR" VERBATIM ) - # Compile Metal to IR + # Link to metallib add_custom_command( - OUTPUT "${IR_OUTPUT}" - COMMAND xcrun -sdk macosx metal ${METAL_VERSION_FLAG} -o "${IR_OUTPUT}" -c "${METAL_SOURCE}" - DEPENDS "${METAL_SOURCE}" - COMMENT "Compiling Metal shader ${OUTPUT_NAME} to IR" + OUTPUT "${OUTPUT}" + COMMAND xcrun -sdk macosx metallib "${IR_FILE}" -o "${OUTPUT}" + DEPENDS "${IR_FILE}" + COMMENT "Metallib: ${IR_FILE} -> ${OUTPUT}" VERBATIM ) +endfunction() - # Link IR to metallib - add_custom_command( - OUTPUT "${METALLIB_OUTPUT}" - COMMAND xcrun -sdk macosx metallib "${IR_OUTPUT}" -o "${METALLIB_OUTPUT}" - DEPENDS "${IR_OUTPUT}" - COMMENT "Linking ${OUTPUT_NAME} to metallib" - VERBATIM +# Run metal-shaderconverter to convert DXIL to Metal (for RT shaders) +# Arguments: +# TARGET - CMake target (for dependencies) +# INPUT - DXIL binary file +# OUTPUT - Metallib output file +# ROOT_SIGNATURE - Root signature JSON file +# Options: +# SYNTHESIZE_DISPATCH - Generate indirect ray dispatch kernel +function(_plume_metal_shader_converter TARGET INPUT OUTPUT ROOT_SIGNATURE) + cmake_parse_arguments(ARG "SYNTHESIZE_DISPATCH" "" "" ${ARGN}) + + find_program(METAL_SHADER_CONVERTER metal-shaderconverter + PATHS /usr/local/bin ENV PATH + DOC "Apple Metal Shader Converter" ) + if(NOT METAL_SHADER_CONVERTER) + message(FATAL_ERROR "metal-shaderconverter not found. Install from: https://developer.apple.com/metal/shader-converter/") + endif() + + set(SYNTH_FLAGS "") + if(ARG_SYNTHESIZE_DISPATCH) + set(SYNTH_FLAGS "--synthesize-indirect-ray-dispatch" "--synthesize-indirect-intersection-function") + endif() - # Generate C header add_custom_command( - OUTPUT "${C_OUTPUT}" "${H_OUTPUT}" - COMMAND plume_file_to_c "${METALLIB_OUTPUT}" "${OUTPUT_NAME}BlobMSL" "${C_OUTPUT}" "${H_OUTPUT}" - DEPENDS "${METALLIB_OUTPUT}" plume_file_to_c - COMMENT "Generating C header for Metal shader ${OUTPUT_NAME}" + OUTPUT "${OUTPUT}" + COMMAND ${METAL_SHADER_CONVERTER} + "${INPUT}" + -o "${OUTPUT}" + --deployment-os=macOS + --minimum-gpu-family=Metal3 + --root-signature=${ROOT_SIGNATURE} + ${SYNTH_FLAGS} + DEPENDS "${INPUT}" "${ROOT_SIGNATURE}" + COMMENT "MetalShaderConverter: ${INPUT} -> ${OUTPUT}" VERBATIM ) - - target_sources(${TARGET_NAME} PRIVATE "${C_OUTPUT}") - target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_BINARY_DIR}") endfunction() -# Internal: Compile native Metal shader to metallib (for handwritten .metal files) -function(_plume_compile_metal_impl TARGET_NAME SHADER_SOURCE OUTPUT_NAME) - set(IR_OUTPUT "${CMAKE_BINARY_DIR}/shaders/${OUTPUT_NAME}.ir") - set(METALLIB_OUTPUT "${CMAKE_BINARY_DIR}/shaders/${OUTPUT_NAME}.metallib") - set(C_OUTPUT "${CMAKE_BINARY_DIR}/shaders/${OUTPUT_NAME}.metal.c") - set(H_OUTPUT "${CMAKE_BINARY_DIR}/shaders/${OUTPUT_NAME}.metal.h") +# Embed binary file as C header +# Arguments: +# TARGET - CMake target to add source to +# INPUT - Binary file to embed +# VAR_NAME - C variable name for the data +# C_OUTPUT - Output .c file path +# H_OUTPUT - Output .h file path +# Options: +# TEXT - Embed as text (char[]) instead of binary (uint8_t[]) +function(_plume_embed TARGET INPUT VAR_NAME C_OUTPUT H_OUTPUT) + cmake_parse_arguments(ARG "TEXT" "" "" ${ARGN}) - # Get deployment target for Metal compilation - if(CMAKE_OSX_DEPLOYMENT_TARGET) - set(METAL_VERSION_FLAG "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") + if(ARG_TEXT) + set(TEXT_FLAG "--text") else() - set(METAL_VERSION_FLAG "") + set(TEXT_FLAG "") endif() - # Compile Metal to IR add_custom_command( - OUTPUT "${IR_OUTPUT}" - COMMAND xcrun -sdk macosx metal ${METAL_VERSION_FLAG} -o "${IR_OUTPUT}" -c "${SHADER_SOURCE}" - DEPENDS "${SHADER_SOURCE}" - COMMENT "Compiling Metal shader ${OUTPUT_NAME} to IR" + OUTPUT "${C_OUTPUT}" "${H_OUTPUT}" + COMMAND plume_file_to_c "${INPUT}" "${VAR_NAME}" "${C_OUTPUT}" "${H_OUTPUT}" ${TEXT_FLAG} + DEPENDS "${INPUT}" plume_file_to_c + COMMENT "Embed: ${INPUT} -> ${VAR_NAME}" VERBATIM ) - # Link IR to metallib + target_sources(${TARGET} PRIVATE "${C_OUTPUT}") + target_include_directories(${TARGET} PRIVATE "${CMAKE_BINARY_DIR}") +endfunction() + +# Generate root signature JSON from HLSL shader reflection (for Metal RT) +# Arguments: +# TARGET - CMake target (for dependencies) +# SOURCE - HLSL source file +# OUTPUT - Root signature JSON output file +# Options: +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +function(_plume_generate_root_signature TARGET SOURCE OUTPUT) + cmake_parse_arguments(ARG "" "" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) + + plume_get_dxc_command(DXC_CMD) + + get_filename_component(OUTPUT_DIR "${OUTPUT}" DIRECTORY) + get_filename_component(OUTPUT_NAME "${OUTPUT}" NAME_WE) + set(REFLECTION_FILE "${OUTPUT_DIR}/${OUTPUT_NAME}_reflection.txt") + + # Include directories + set(INCLUDE_FLAGS "") + foreach(DIR ${ARG_INCLUDE_DIRS}) + list(APPEND INCLUDE_FLAGS "-I${DIR}") + endforeach() + add_custom_command( - OUTPUT "${METALLIB_OUTPUT}" - COMMAND xcrun -sdk macosx metallib "${IR_OUTPUT}" -o "${METALLIB_OUTPUT}" - DEPENDS "${IR_OUTPUT}" - COMMENT "Linking ${OUTPUT_NAME} to metallib" + OUTPUT "${OUTPUT}" + COMMAND ${DXC_CMD} + ${PLUME_DXC_COMMON_OPTS} + ${INCLUDE_FLAGS} + -T lib_6_3 + -D RT_SHADER + ${ARG_EXTRA_ARGS} + -Fc "${REFLECTION_FILE}" + "${SOURCE}" + COMMAND plume_generate_root_signature "${REFLECTION_FILE}" "${OUTPUT}" + DEPENDS "${SOURCE}" plume_generate_root_signature + COMMENT "RootSig: ${SOURCE}" VERBATIM ) +endfunction() + +# ============================================================================ +# Layer 2: Pipelines - Platform-aware composition of primitives +# ============================================================================ + +# Get shader profile from stage type +function(_plume_get_profile TYPE SHADER_MODEL OUT_VAR) + if(TYPE STREQUAL "vertex") + set(${OUT_VAR} "vs_${SHADER_MODEL}" PARENT_SCOPE) + elseif(TYPE STREQUAL "pixel" OR TYPE STREQUAL "fragment") + set(${OUT_VAR} "ps_${SHADER_MODEL}" PARENT_SCOPE) + elseif(TYPE STREQUAL "compute") + set(${OUT_VAR} "cs_${SHADER_MODEL}" PARENT_SCOPE) + elseif(TYPE STREQUAL "geometry") + set(${OUT_VAR} "gs_${SHADER_MODEL}" PARENT_SCOPE) + else() + message(FATAL_ERROR "Unknown shader type: ${TYPE}") + endif() +endfunction() + +# Compile a stage shader (vertex, pixel, compute, geometry) to all platform formats +# This is the main pipeline for regular shaders. +# +# Arguments: +# TARGET - CMake target to add shader to +# SOURCE - HLSL source file +# TYPE - Shader type: vertex, pixel, compute, geometry +# OUTPUT_NAME - Base name for output files +# ENTRY_POINT - Shader entry point function +# Options: +# SPEC_CONSTANTS - Skip DXIL, only SPIRV + Metal (for specialization constants) +# SHADER_MODEL - Shader model version (default: 6_0) +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +# OUTPUT_DIR - Custom output directory +function(_plume_compile_stage_shader TARGET SOURCE TYPE OUTPUT_NAME ENTRY_POINT) + cmake_parse_arguments(ARG "SPEC_CONSTANTS" "SHADER_MODEL;OUTPUT_DIR" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) - # Generate C header + # Defaults + if(NOT ARG_SHADER_MODEL) + set(ARG_SHADER_MODEL "6_0") + endif() + if(ARG_OUTPUT_DIR) + set(OUT_DIR "${ARG_OUTPUT_DIR}") + else() + set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") + endif() + file(MAKE_DIRECTORY "${OUT_DIR}") + + # Get profile + _plume_get_profile(${TYPE} ${ARG_SHADER_MODEL} PROFILE) + + # Type-specific flags + set(DXC_OPTS "") + if(TYPE STREQUAL "vertex") + list(APPEND DXC_OPTS INVERT_Y) + endif() + if(ARG_INCLUDE_DIRS) + list(APPEND DXC_OPTS INCLUDE_DIRS ${ARG_INCLUDE_DIRS}) + endif() + if(ARG_EXTRA_ARGS) + list(APPEND DXC_OPTS EXTRA_ARGS ${ARG_EXTRA_ARGS}) + endif() + + # === SPIRV (always) === + set(SPIRV_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spv") + _plume_dxc(${TARGET} "${SOURCE}" "${SPIRV_FILE}" ${PROFILE} ${ENTRY_POINT} "spirv" ${DXC_OPTS}) + _plume_embed(${TARGET} "${SPIRV_FILE}" "${OUTPUT_NAME}BlobSPIRV" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spirv.c" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spirv.h") + + # === DXIL (Windows, unless SPEC_CONSTANTS) === + if(WIN32 AND NOT ARG_SPEC_CONSTANTS) + set(DXIL_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.dxil") + _plume_dxc(${TARGET} "${SOURCE}" "${DXIL_FILE}" ${PROFILE} ${ENTRY_POINT} "dxil" ${DXC_OPTS}) + _plume_embed(${TARGET} "${DXIL_FILE}" "${OUTPUT_NAME}BlobDXIL" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.dxil.c" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.dxil.h") + endif() + + # === Metal (Apple, via SPIRV-Cross) === + if(APPLE AND TARGET plume_spirv_cross_msl) + set(METAL_SOURCE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metal") + set(METALLIB_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metallib") + + _plume_spirv_cross(${TARGET} "${SPIRV_FILE}" "${METAL_SOURCE}") + _plume_metal_compile(${TARGET} "${METAL_SOURCE}" "${METALLIB_FILE}") + _plume_embed(${TARGET} "${METALLIB_FILE}" "${OUTPUT_NAME}BlobMSL" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metal.c" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.metal.h") + endif() +endfunction() + +# Compile a library shader (RT or general) for the current platform +# Libraries contain multiple exported functions without a single entry point. +# +# Arguments: +# TARGET - CMake target to add shader to +# SOURCE - HLSL source file +# OUTPUT_NAME - Base name for output files +# Options: +# RAYTRACING - Enable ray tracing extensions for SPIRV +# SHADER_MODEL - Shader model version (default: 6_3) +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +# OUTPUT_DIR - Custom output directory +function(_plume_compile_library_shader_impl TARGET SOURCE OUTPUT_NAME) + cmake_parse_arguments(ARG "RAYTRACING" "SHADER_MODEL;OUTPUT_DIR" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) + + # Defaults + if(NOT ARG_SHADER_MODEL) + set(ARG_SHADER_MODEL "6_3") + endif() + if(ARG_OUTPUT_DIR) + set(OUT_DIR "${ARG_OUTPUT_DIR}") + else() + set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") + endif() + file(MAKE_DIRECTORY "${OUT_DIR}") + + set(PROFILE "lib_${ARG_SHADER_MODEL}") + + # Common DXC options + set(DXC_OPTS "") + if(ARG_INCLUDE_DIRS) + list(APPEND DXC_OPTS INCLUDE_DIRS ${ARG_INCLUDE_DIRS}) + endif() + if(ARG_EXTRA_ARGS) + list(APPEND DXC_OPTS EXTRA_ARGS ${ARG_EXTRA_ARGS}) + endif() + + if(WIN32) + # === Windows: DXIL library (D3D12) === + set(DXIL_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.dxil") + _plume_dxc(${TARGET} "${SOURCE}" "${DXIL_FILE}" ${PROFILE} "" "dxil" ${DXC_OPTS}) + _plume_embed(${TARGET} "${DXIL_FILE}" "${OUTPUT_NAME}BlobDXIL" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.dxil.c" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.dxil.h") + + # === Windows: SPIRV library (Vulkan) === + set(SPIRV_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spv") + if(ARG_RAYTRACING) + list(APPEND DXC_OPTS SPIRV_RT) + endif() + _plume_dxc(${TARGET} "${SOURCE}" "${SPIRV_FILE}" ${PROFILE} "" "spirv" ${DXC_OPTS}) + _plume_embed(${TARGET} "${SPIRV_FILE}" "${OUTPUT_NAME}BlobSPIRV" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spirv.c" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spirv.h") + + elseif(APPLE) + # === Apple: DXIL -> Metal via metal-shaderconverter === + # This path is for RT shaders; regular libraries don't make sense on Metal + if(ARG_RAYTRACING) + _plume_compile_rt_metal(${TARGET} "${SOURCE}" ${OUTPUT_NAME} + OUTPUT_DIR "${OUT_DIR}" + INCLUDE_DIRS ${ARG_INCLUDE_DIRS} + EXTRA_ARGS ${ARG_EXTRA_ARGS}) + endif() + + else() + # === Linux: SPIRV with optional RT extensions === + set(SPIRV_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spv") + if(ARG_RAYTRACING) + list(APPEND DXC_OPTS SPIRV_RT) + endif() + _plume_dxc(${TARGET} "${SOURCE}" "${SPIRV_FILE}" ${PROFILE} "" "spirv" ${DXC_OPTS}) + _plume_embed(${TARGET} "${SPIRV_FILE}" "${OUTPUT_NAME}BlobSPIRV" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spirv.c" + "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spirv.h") + endif() +endfunction() + +# Compile RT shader to Metal (Apple only) +# Handles the complex Metal RT pipeline: DXIL -> root signature -> visible functions + dispatch +# +# Arguments: +# TARGET - CMake target to add shader to +# SOURCE - HLSL source file +# OUTPUT_NAME - Base name for output files +# Options: +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +# OUTPUT_DIR - Custom output directory +function(_plume_compile_rt_metal TARGET SOURCE OUTPUT_NAME) + cmake_parse_arguments(ARG "" "OUTPUT_DIR" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) + + if(NOT APPLE) + return() + endif() + + # Check for metal-shaderconverter + find_program(METAL_SHADER_CONVERTER metal-shaderconverter + PATHS /usr/local/bin ENV PATH + DOC "Apple Metal Shader Converter" + ) + if(NOT METAL_SHADER_CONVERTER) + message(WARNING "metal-shaderconverter not found. RT shaders will not be compiled for Metal. " + "Install from: https://developer.apple.com/metal/shader-converter/") + return() + endif() + + if(ARG_OUTPUT_DIR) + set(OUT_DIR "${ARG_OUTPUT_DIR}") + else() + set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") + endif() + file(MAKE_DIRECTORY "${OUT_DIR}") + + # Step 1: Compile HLSL to DXIL + set(DXIL_FILE "${OUT_DIR}/${OUTPUT_NAME}.dxil") + set(DXC_OPTS "") + if(ARG_INCLUDE_DIRS) + list(APPEND DXC_OPTS INCLUDE_DIRS ${ARG_INCLUDE_DIRS}) + endif() + if(ARG_EXTRA_ARGS) + list(APPEND DXC_OPTS EXTRA_ARGS "-D" "RT_SHADER" ${ARG_EXTRA_ARGS}) + else() + list(APPEND DXC_OPTS EXTRA_ARGS "-D" "RT_SHADER") + endif() + _plume_dxc(${TARGET} "${SOURCE}" "${DXIL_FILE}" "lib_6_3" "" "dxil" ${DXC_OPTS}) + + # Step 2: Generate root signature + set(ROOT_SIG_FILE "${OUT_DIR}/${OUTPUT_NAME}_root_signature.json") + _plume_generate_root_signature(${TARGET} "${SOURCE}" "${ROOT_SIG_FILE}" + INCLUDE_DIRS ${ARG_INCLUDE_DIRS} + EXTRA_ARGS ${ARG_EXTRA_ARGS}) + + # Step 3a: Convert to visible functions + set(VISIBLE_FUNCS_METALLIB "${OUT_DIR}/${OUTPUT_NAME}_functions.metallib") + _plume_metal_shader_converter(${TARGET} "${DXIL_FILE}" "${VISIBLE_FUNCS_METALLIB}" "${ROOT_SIG_FILE}") + + # Step 3b: Convert to dispatch kernel + set(DISPATCH_METALLIB "${OUT_DIR}/${OUTPUT_NAME}_dispatch.metallib") + _plume_metal_shader_converter(${TARGET} "${DXIL_FILE}" "${DISPATCH_METALLIB}" "${ROOT_SIG_FILE}" + SYNTHESIZE_DISPATCH) + + # Step 4: Combine both metallibs + set(COMBINED_METALLIB "${OUT_DIR}/${OUTPUT_NAME}.metallib") add_custom_command( - OUTPUT "${C_OUTPUT}" "${H_OUTPUT}" - COMMAND plume_file_to_c "${METALLIB_OUTPUT}" "${OUTPUT_NAME}BlobMSL" "${C_OUTPUT}" "${H_OUTPUT}" - DEPENDS "${METALLIB_OUTPUT}" plume_file_to_c - COMMENT "Generating C header for Metal shader ${OUTPUT_NAME}" + OUTPUT "${COMBINED_METALLIB}" + COMMAND plume_combine_rt_metallibs "${VISIBLE_FUNCS_METALLIB}" "${DISPATCH_METALLIB}" "${COMBINED_METALLIB}" "${ROOT_SIG_FILE}" + DEPENDS "${VISIBLE_FUNCS_METALLIB}" "${DISPATCH_METALLIB}" plume_combine_rt_metallibs + COMMENT "Combine RT metallibs: ${OUTPUT_NAME}" VERBATIM ) - target_sources(${TARGET_NAME} PRIVATE "${C_OUTPUT}") - target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_BINARY_DIR}") + # Step 5: Embed as C header + _plume_embed(${TARGET} "${COMBINED_METALLIB}" "${OUTPUT_NAME}BlobMetalLib" + "${OUT_DIR}/${OUTPUT_NAME}.metallib.c" + "${OUT_DIR}/${OUTPUT_NAME}.metallib.h") +endfunction() + +# Compile native Metal shader to metallib (for handwritten .metal files) +# +# Arguments: +# TARGET - CMake target to add shader to +# SOURCE - Metal source file +# OUTPUT_NAME - Base name for output files +# Options: +# OUTPUT_DIR - Custom output directory +function(_plume_compile_native_metal TARGET SOURCE OUTPUT_NAME) + cmake_parse_arguments(ARG "" "OUTPUT_DIR" "" ${ARGN}) + + if(NOT APPLE) + return() + endif() + + if(ARG_OUTPUT_DIR) + set(OUT_DIR "${ARG_OUTPUT_DIR}") + else() + set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") + endif() + file(MAKE_DIRECTORY "${OUT_DIR}") + + set(METALLIB_FILE "${OUT_DIR}/${OUTPUT_NAME}.metallib") + _plume_metal_compile(${TARGET} "${SOURCE}" "${METALLIB_FILE}") + _plume_embed(${TARGET} "${METALLIB_FILE}" "${OUTPUT_NAME}BlobMSL" + "${OUT_DIR}/${OUTPUT_NAME}.metal.c" + "${OUT_DIR}/${OUTPUT_NAME}.metal.h") endfunction() # ============================================================================ -# Public API +# Layer 3: Public API - User-facing functions # ============================================================================ # Compile a shader and add it to a target # Usage: plume_compile_shader(TARGET SOURCE TYPE OUTPUT_NAME ENTRY_POINT [options]) # TARGET - CMake target to add shader to # SOURCE - Path to shader source file (.hlsl or .metal) -# TYPE - Shader type: vertex, pixel, compute, geometry, ray, or library +# TYPE - Shader type: vertex, pixel, compute, geometry # OUTPUT_NAME - Base name for output files (e.g., "mainVert") # ENTRY_POINT - Shader entry point function name (e.g., "VSMain") # # Options: -# SPEC_CONSTANTS - Only compile SPIRV + Metal (no DXIL), for specialization constants +# SPEC_CONSTANTS - Only compile SPIRV + Metal (no DXIL) # SHADER_MODEL - Shader model version (default: 6_0) -# INCLUDE_DIRS - Additional include directories for DXC -# EXTRA_ARGS - Additional DXC arguments (e.g., -D MULTISAMPLING -O0) -# OUTPUT_DIR - Custom output directory (default: ${CMAKE_BINARY_DIR}/shaders) +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +# OUTPUT_DIR - Custom output directory function(plume_compile_shader TARGET_NAME SHADER_SOURCE SHADER_TYPE OUTPUT_NAME ENTRY_POINT) - # Parse optional arguments - cmake_parse_arguments(ARG "SPEC_CONSTANTS" "SHADER_MODEL;OUTPUT_DIR" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) - - get_filename_component(SHADER_EXT "${SHADER_SOURCE}" EXT) + get_filename_component(EXT "${SHADER_SOURCE}" EXT) - if(SHADER_EXT MATCHES "\\.metal$") + if(EXT MATCHES "\\.metal$") if(APPLE) - _plume_compile_metal_impl(${TARGET_NAME} "${SHADER_SOURCE}" ${OUTPUT_NAME}) - endif() - elseif(SHADER_EXT MATCHES "\\.hlsl$") - # Build optional args to pass to impl - set(IMPL_ARGS "") - if(ARG_SHADER_MODEL) - list(APPEND IMPL_ARGS SHADER_MODEL "${ARG_SHADER_MODEL}") - endif() - if(ARG_INCLUDE_DIRS) - list(APPEND IMPL_ARGS INCLUDE_DIRS ${ARG_INCLUDE_DIRS}) - endif() - if(ARG_EXTRA_ARGS) - list(APPEND IMPL_ARGS EXTRA_ARGS ${ARG_EXTRA_ARGS}) - endif() - if(ARG_OUTPUT_DIR) - list(APPEND IMPL_ARGS OUTPUT_DIR "${ARG_OUTPUT_DIR}") - set(OUT_DIR "${ARG_OUTPUT_DIR}") - else() - set(OUT_DIR "${CMAKE_BINARY_DIR}/shaders") - endif() - - # Always compile to SPIR-V - _plume_compile_hlsl_impl(${TARGET_NAME} "${SHADER_SOURCE}" ${SHADER_TYPE} ${OUTPUT_NAME} "spirv" ${ENTRY_POINT} ${IMPL_ARGS}) - - # Compile to DXIL on Windows (unless SPEC_CONSTANTS mode) - if(WIN32 AND NOT ARG_SPEC_CONSTANTS) - _plume_compile_hlsl_impl(${TARGET_NAME} "${SHADER_SOURCE}" ${SHADER_TYPE} ${OUTPUT_NAME} "dxil" ${ENTRY_POINT} ${IMPL_ARGS}) - endif() - - # Compile SPIR-V to Metal on Apple (if spirv-cross is available) - if(APPLE AND TARGET plume_spirv_cross_msl) - set(SPIRV_FILE "${OUT_DIR}/${OUTPUT_NAME}.hlsl.spv") - _plume_compile_spirv_to_metal_impl(${TARGET_NAME} "${SPIRV_FILE}" ${OUTPUT_NAME} OUTPUT_DIR "${OUT_DIR}") + _plume_compile_native_metal(${TARGET_NAME} "${SHADER_SOURCE}" ${OUTPUT_NAME} ${ARGN}) endif() + elseif(EXT MATCHES "\\.hlsl$") + _plume_compile_stage_shader(${TARGET_NAME} "${SHADER_SOURCE}" ${SHADER_TYPE} ${OUTPUT_NAME} ${ENTRY_POINT} ${ARGN}) else() - message(WARNING "Unsupported shader extension '${SHADER_EXT}' for ${SHADER_SOURCE}. Use .hlsl or .metal") + message(WARNING "Unsupported shader extension '${EXT}' for ${SHADER_SOURCE}. Use .hlsl or .metal") endif() endfunction() @@ -366,64 +644,66 @@ function(plume_compile_geometry_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME ENT plume_compile_shader(${TARGET_NAME} "${SHADER_SOURCE}" "geometry" ${OUTPUT_NAME} ${ENTRY_POINT} ${ARGN}) endfunction() -# Compile a ray tracing shader -# Usage: plume_compile_ray_shader(TARGET SOURCE OUTPUT_NAME ENTRY_POINT [options]) -# Options: SHADER_MODEL, INCLUDE_DIRS, EXTRA_ARGS (see plume_compile_shader) -function(plume_compile_ray_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME ENTRY_POINT) - plume_compile_shader(${TARGET_NAME} "${SHADER_SOURCE}" "ray" ${OUTPUT_NAME} ${ENTRY_POINT} ${ARGN}) -endfunction() - -# Compile a library shader (DXIL only, Windows) -# Usage: plume_compile_library_shader(TARGET SOURCE OUTPUT_NAME [options]) -# Options: SHADER_MODEL, INCLUDE_DIRS, EXTRA_ARGS, OUTPUT_DIR -function(plume_compile_library_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) - # Parse optional arguments - cmake_parse_arguments(ARG "" "SHADER_MODEL;OUTPUT_DIR" "INCLUDE_DIRS;EXTRA_ARGS" ${ARGN}) - - if(NOT WIN32) - return() - endif() - - # Build optional args to pass to impl - set(IMPL_ARGS "") - if(ARG_SHADER_MODEL) - list(APPEND IMPL_ARGS SHADER_MODEL "${ARG_SHADER_MODEL}") - else() - list(APPEND IMPL_ARGS SHADER_MODEL "6_3") # Library shaders default to 6_3 - endif() - if(ARG_INCLUDE_DIRS) - list(APPEND IMPL_ARGS INCLUDE_DIRS ${ARG_INCLUDE_DIRS}) - endif() - if(ARG_EXTRA_ARGS) - list(APPEND IMPL_ARGS EXTRA_ARGS ${ARG_EXTRA_ARGS}) - endif() - if(ARG_OUTPUT_DIR) - list(APPEND IMPL_ARGS OUTPUT_DIR "${ARG_OUTPUT_DIR}") - endif() - - # Library shaders don't have an entry point - use empty string - _plume_compile_hlsl_impl(${TARGET_NAME} "${SHADER_SOURCE}" "library" ${OUTPUT_NAME} "dxil" "" ${IMPL_ARGS}) -endfunction() - # Compile a native Metal shader (Apple only, no-op on other platforms) # Use this for handwritten .metal files, not for cross-compiled HLSL # Usage: plume_compile_metal_shader(TARGET SOURCE OUTPUT_NAME) function(plume_compile_metal_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) if(APPLE) - _plume_compile_metal_impl(${TARGET_NAME} "${SHADER_SOURCE}" ${OUTPUT_NAME}) + _plume_compile_native_metal(${TARGET_NAME} "${SHADER_SOURCE}" ${OUTPUT_NAME} ${ARGN}) endif() endfunction() +# Compile a ray tracing shader library +# Usage: plume_compile_rt_shader(TARGET SOURCE OUTPUT_NAME [options]) +# TARGET - CMake target to add shader to +# SOURCE - Path to HLSL shader source file +# OUTPUT_NAME - Base name for output files +# +# Options: +# SHADER_MODEL - Shader model version (default: 6_3) +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +# OUTPUT_DIR - Custom output directory +# Compile a ray tracing shader library +# +# Outputs: +# Windows: {OUTPUT_NAME}BlobDXIL in shaders/{OUTPUT_NAME}.hlsl.dxil.h +# {OUTPUT_NAME}BlobSPIRV in shaders/{OUTPUT_NAME}.hlsl.spirv.h +# Apple: {OUTPUT_NAME}BlobMetalLib in shaders/{OUTPUT_NAME}.metallib.h +# Linux: {OUTPUT_NAME}BlobSPIRV in shaders/{OUTPUT_NAME}.hlsl.spirv.h +function(plume_compile_rt_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) + _plume_compile_library_shader_impl(${TARGET_NAME} "${SHADER_SOURCE}" ${OUTPUT_NAME} RAYTRACING ${ARGN}) +endfunction() + +# Compile a general library shader (non-RT) +# Usage: plume_compile_library_shader(TARGET SOURCE OUTPUT_NAME [options]) +# TARGET - CMake target to add shader to +# SOURCE - Path to HLSL shader source file +# OUTPUT_NAME - Base name for output files +# +# Options: +# SHADER_MODEL - Shader model version (default: 6_3) +# INCLUDE_DIRS - Additional include directories +# EXTRA_ARGS - Additional DXC arguments +# OUTPUT_DIR - Custom output directory +# +# Output: +# Windows: {OUTPUT_NAME}BlobDXIL in shaders/{OUTPUT_NAME}.hlsl.dxil.h +# {OUTPUT_NAME}BlobSPIRV in shaders/{OUTPUT_NAME}.hlsl.spirv.h +# Linux: {OUTPUT_NAME}BlobSPIRV in shaders/{OUTPUT_NAME}.hlsl.spirv.h +function(plume_compile_library_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) + _plume_compile_library_shader_impl(${TARGET_NAME} "${SHADER_SOURCE}" ${OUTPUT_NAME} ${ARGN}) +endfunction() + # Preprocess a shader header file and embed it as text # Useful for runtime shader compilation where you need the preprocessed source -# Usage: plume_preprocess_shader(TARGET SOURCE OUTPUT_NAME [INCLUDE_DIRS dirs] [OUTPUT_DIR dir] [VAR_NAME name]) +# Usage: plume_preprocess_shader(TARGET SOURCE OUTPUT_NAME [options]) # VAR_NAME - Optional variable name for the embedded data (defaults to OUTPUT_NAME) function(plume_preprocess_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) cmake_parse_arguments(ARG "" "OUTPUT_DIR;VAR_NAME" "INCLUDE_DIRS" ${ARGN}) get_filename_component(SHADER_NAME "${SHADER_SOURCE}" NAME) - # Use custom output directory if provided if(ARG_OUTPUT_DIR) set(OUT_DIR "${ARG_OUTPUT_DIR}") else() @@ -431,7 +711,6 @@ function(plume_preprocess_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) endif() file(MAKE_DIRECTORY "${OUT_DIR}") - # Variable name for embedded data (defaults to OUTPUT_NAME) if(ARG_VAR_NAME) set(VAR_NAME "${ARG_VAR_NAME}") else() @@ -439,8 +718,6 @@ function(plume_preprocess_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) endif() set(PREPROCESSED_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.rw") - set(C_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.rw.c") - set(H_OUTPUT "${OUT_DIR}/${OUTPUT_NAME}.rw.h") # Build include directory flags set(INCLUDE_FLAGS "") @@ -477,15 +754,8 @@ function(plume_preprocess_shader TARGET_NAME SHADER_SOURCE OUTPUT_NAME) ) endif() - # Generate C header with text content (use --text for char type compatibility) - add_custom_command( - OUTPUT "${C_OUTPUT}" "${H_OUTPUT}" - COMMAND plume_file_to_c "${PREPROCESSED_OUTPUT}" "${VAR_NAME}Text" "${C_OUTPUT}" "${H_OUTPUT}" --text - DEPENDS "${PREPROCESSED_OUTPUT}" plume_file_to_c - COMMENT "Generating C header for preprocessed shader ${OUTPUT_NAME}" - VERBATIM - ) - - target_sources(${TARGET_NAME} PRIVATE "${C_OUTPUT}") - target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_BINARY_DIR}") + _plume_embed(${TARGET_NAME} "${PREPROCESSED_OUTPUT}" "${VAR_NAME}Text" + "${OUT_DIR}/${OUTPUT_NAME}.rw.c" + "${OUT_DIR}/${OUTPUT_NAME}.rw.h" + TEXT) endfunction() diff --git a/cmake/modules/PlumeCombineRTMetallibs.cmake b/cmake/modules/PlumeCombineRTMetallibs.cmake new file mode 100644 index 0000000..2ddd214 --- /dev/null +++ b/cmake/modules/PlumeCombineRTMetallibs.cmake @@ -0,0 +1,29 @@ +# PlumeCombineRTMetallibs.cmake +# Builds the combine_rt_metallibs tool for packaging Metal RT shader libraries + +# Build the combine_rt_metallibs tool for the host system +function(plume_build_combine_rt_metallibs) + if(TARGET plume_combine_rt_metallibs) + return() + endif() + + # Find the source file relative to this module + set(TOOL_SOURCE "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../tools/combine_rt_metallibs.cpp") + + if(NOT EXISTS "${TOOL_SOURCE}") + message(FATAL_ERROR "plume combine_rt_metallibs.cpp not found at ${TOOL_SOURCE}") + endif() + + add_executable(plume_combine_rt_metallibs ${TOOL_SOURCE}) + set_target_properties(plume_combine_rt_metallibs PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plume_tools" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + ) + + if(APPLE) + set_target_properties(plume_combine_rt_metallibs PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY "-" + ) + endif() +endfunction() diff --git a/cmake/modules/PlumeRootSignature.cmake b/cmake/modules/PlumeRootSignature.cmake new file mode 100644 index 0000000..e1f4a91 --- /dev/null +++ b/cmake/modules/PlumeRootSignature.cmake @@ -0,0 +1,29 @@ +# PlumeRootSignature.cmake +# Builds the generate_root_signature tool for creating Metal Shader Converter root signature JSON + +# Build the generate_root_signature tool for the host system +function(plume_build_generate_root_signature) + if(TARGET plume_generate_root_signature) + return() + endif() + + # Find the source file relative to this module + set(GEN_ROOT_SIG_SOURCE "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../tools/generate_root_signature.cpp") + + if(NOT EXISTS "${GEN_ROOT_SIG_SOURCE}") + message(FATAL_ERROR "plume generate_root_signature.cpp not found at ${GEN_ROOT_SIG_SOURCE}") + endif() + + add_executable(plume_generate_root_signature ${GEN_ROOT_SIG_SOURCE}) + set_target_properties(plume_generate_root_signature PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plume_tools" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + ) + + if(APPLE) + set_target_properties(plume_generate_root_signature PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY "-" + ) + endif() +endfunction() diff --git a/cmake/tools/combine_rt_metallibs.cpp b/cmake/tools/combine_rt_metallibs.cpp new file mode 100644 index 0000000..284a7eb --- /dev/null +++ b/cmake/tools/combine_rt_metallibs.cpp @@ -0,0 +1,121 @@ +// +// plume +// +// Copyright (c) 2024 renderbag and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file for details. +// + +// Combines two Metal RT shader libraries (visible functions + dispatch kernel) +// into a single blob with a header for easy loading. +// +// Format: +// [4 bytes] Magic: "PLRT" (Plume Ray Tracing) +// [4 bytes] Version: 1 +// [4 bytes] Functions metallib size (little-endian) +// [4 bytes] Dispatch metallib size (little-endian) +// [4 bytes] Root signature JSON size +// [N bytes] Functions metallib data +// [M bytes] Dispatch metallib data +// [K bytes] Root signature JSON + +#include +#include +#include +#include +#include +#include + +static const char MAGIC[4] = {'P', 'L', 'R', 'T'}; +std::vector read_file(const char* path) { + std::ifstream input_file{path, std::ios::binary}; + std::vector ret{}; + + if (!input_file.good()) { + return ret; + } + + input_file.seekg(0, std::ios::end); + ret.resize(input_file.tellg()); + + input_file.seekg(0, std::ios::beg); + input_file.read(ret.data(), ret.size()); + + return ret; +} + +void create_parent_if_needed(const char* path) { + std::filesystem::path parent_path = std::filesystem::path{path}.parent_path(); + if (!parent_path.empty()) { + std::filesystem::create_directories(parent_path); + } +} + +void write_uint32_le(std::ofstream& out, uint32_t value) { + char bytes[4]; + bytes[0] = static_cast(value & 0xFF); + bytes[1] = static_cast((value >> 8) & 0xFF); + bytes[2] = static_cast((value >> 16) & 0xFF); + bytes[3] = static_cast((value >> 24) & 0xFF); + out.write(bytes, 4); +} + +int main(int argc, const char** argv) { + if (argc != 5) { + printf("Usage: %s \n", argv[0]); + printf("\nCombines two Metal RT shader libraries into a single blob.\n"); + return EXIT_FAILURE; + } + + const char* functions_path = argv[1]; + const char* dispatch_path = argv[2]; + const char* output_path = argv[3]; + const char* root_signature_path = argv[4]; + + // Read both input files + std::vector functions_data = read_file(functions_path); + if (functions_data.empty()) { + fprintf(stderr, "Failed to read functions metallib: %s\n", functions_path); + return EXIT_FAILURE; + } + + std::vector dispatch_data = read_file(dispatch_path); + if (dispatch_data.empty()) { + fprintf(stderr, "Failed to read dispatch metallib: %s\n", dispatch_path); + return EXIT_FAILURE; + } + + std::vector root_signature_data = read_file(root_signature_path); + if (root_signature_data.empty()) { + fprintf(stderr, "Failed to read root signature JSON: %s\n", root_signature_path); + return EXIT_FAILURE; + } + + // Create output directory if needed + create_parent_if_needed(output_path); + + // Write combined output + std::ofstream output{output_path, std::ios::binary}; + if (!output.good()) { + fprintf(stderr, "Failed to create output file: %s\n", output_path); + return EXIT_FAILURE; + } + + // Write header + output.write(MAGIC, 4); + const uint32_t version = 1; + write_uint32_le(output, version); + write_uint32_le(output, static_cast(functions_data.size())); + write_uint32_le(output, static_cast(dispatch_data.size())); + write_uint32_le(output, static_cast(root_signature_data.size())); + + // Write data + output.write(functions_data.data(), functions_data.size()); + output.write(dispatch_data.data(), dispatch_data.size()); + output.write(root_signature_data.data(), root_signature_data.size()); + + printf("Combined RT metallib: %zu + %zu = %zu bytes\n", + functions_data.size(), dispatch_data.size(), + 20 + functions_data.size() + dispatch_data.size() + root_signature_data.size()); + + return EXIT_SUCCESS; +} diff --git a/cmake/tools/generate_root_signature.cpp b/cmake/tools/generate_root_signature.cpp new file mode 100644 index 0000000..f942ce1 --- /dev/null +++ b/cmake/tools/generate_root_signature.cpp @@ -0,0 +1,301 @@ +// +// plume +// +// Copyright (c) 2024 renderbag and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file for details. +// +// Generates a Metal Shader Converter root signature JSON from DXC disassembly output. +// Parses the "Resource Bindings:" section to extract shader resource bindings. +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct Resource { + std::string name; + std::string type; // cbuffer, texture, UAV, sampler + std::string dim; // NA, 2d, ras (raytracing acceleration structure), etc. + std::string bindType; // cb, t, u, s + int reg = 0; + int space = 0; + int count = 1; +}; + +// Trim whitespace from both ends +std::string trim(const std::string& s) { + size_t start = s.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) return ""; + size_t end = s.find_last_not_of(" \t\r\n"); + return s.substr(start, end - start + 1); +} + +// Split string by whitespace +std::vector splitWhitespace(const std::string& s) { + std::vector tokens; + std::istringstream iss(s); + std::string token; + while (iss >> token) { + tokens.push_back(token); + } + return tokens; +} + +// Parse HLSL bind like "cb0", "t0", "u0", "s0" +bool parseHlslBind(const std::string& bind, std::string& bindType, int& reg, int& space) { + std::string bindPart = bind; + std::string spacePart; + const size_t comma = bind.find(','); + if (comma != std::string::npos) { + bindPart = bind.substr(0, comma); + spacePart = bind.substr(comma + 1); + } + + size_t i = 0; + while (i < bindPart.size() && std::isalpha(static_cast(bindPart[i]))) { + i++; + } + if (i == 0 || i >= bindPart.size()) return false; + + bindType = bindPart.substr(0, i); + reg = std::atoi(bindPart.substr(i).c_str()); + + if (!spacePart.empty()) { + spacePart = trim(spacePart); + if (spacePart.rfind("space", 0) == 0) { + space = std::atoi(spacePart.substr(5).c_str()); + } + } + return true; +} + +std::vector parseReflection(const std::string& content) { + std::vector resources; + std::istringstream stream(content); + std::string line; + bool inBindings = false; + + while (std::getline(stream, line)) { + line = trim(line); + + if (line.find("Resource Bindings:") != std::string::npos) { + inBindings = true; + continue; + } + + if (!inBindings) continue; + + // Skip header lines + if (line.find("; Name") != std::string::npos || + line.find("; ----") != std::string::npos) { + continue; + } + + // End of bindings section (non-comment line or empty) + if (line.empty() || line[0] != ';') { + break; + } + + // Remove leading semicolon and trim + line = trim(line.substr(1)); + if (line.empty()) continue; + + // Parse: Name Type Format Dim ID HLSLBind [Space] Count + auto parts = splitWhitespace(line); + if (parts.size() < 7) continue; + + Resource res; + res.name = parts[0]; + res.type = parts[1]; + res.dim = parts[3]; + res.count = std::atoi(parts.back().c_str()); + size_t bindIndex = parts.size() - 2; + if (bindIndex > 0 && parts[bindIndex].rfind("space", 0) == 0) { + res.space = std::atoi(parts[bindIndex].substr(5).c_str()); + bindIndex--; + } + std::string hlslBind = parts[bindIndex]; + + if (!parseHlslBind(hlslBind, res.bindType, res.reg, res.space)) { + continue; + } + + resources.push_back(res); + } + + return resources; +} + +// Sort by type priority: UAV (u), SRV (t), CBV (cb/b), Sampler (s) +int typePriority(const std::string& bindType) { + if (bindType == "u") return 0; + if (bindType == "t") return 1; + if (bindType == "cb" || bindType == "b") return 2; + if (bindType == "s") return 3; + return 99; +} + +void writeRootSignature(std::ofstream& out, const std::vector& resources) { + // Sort resources by type priority + std::vector sorted = resources; + std::sort(sorted.begin(), sorted.end(), [](const Resource& a, const Resource& b) { + int pa = typePriority(a.bindType); + int pb = typePriority(b.bindType); + if (pa != pb) return pa < pb; + return a.reg < b.reg; + }); + + out << "{\n"; + out << " \"version\": \"IRRootSignatureVersion_1_1\",\n"; + out << " \"RootSignature\": {\n"; + out << " \"Flags\": \"IRRootSignatureFlagNone\",\n"; + out << " \"NumParameters\": " << sorted.size() << ",\n"; + out << " \"Parameters\": [\n"; + + for (size_t i = 0; i < sorted.size(); i++) { + const Resource& res = sorted[i]; + + if (res.bindType == "u") { + // UAV -> descriptor table + out << " {\n"; + out << " \"ParameterType\": \"IRRootParameterTypeDescriptorTable\",\n"; + out << " \"ShaderVisibility\": \"IRShaderVisibilityAll\",\n"; + out << " \"DescriptorTable\": {\n"; + out << " \"NumDescriptorRanges\": 1,\n"; + out << " \"DescriptorRanges\": [\n"; + out << " {\n"; + out << " \"RangeType\": \"IRDescriptorRangeTypeUAV\",\n"; + out << " \"NumDescriptors\": " << res.count << ",\n"; + out << " \"BaseShaderRegister\": " << res.reg << ",\n"; + out << " \"RegisterSpace\": " << res.space << ",\n"; + out << " \"OffsetInDescriptorsFromTableStart\": 0,\n"; + out << " \"Flags\": \"IRDescriptorRangeFlagNone\"\n"; + out << " }\n"; + out << " ]\n"; + out << " }\n"; + out << " }"; + } else if (res.bindType == "t") { + if (res.dim == "ras") { + // Acceleration structure -> root SRV + out << " {\n"; + out << " \"ParameterType\": \"IRRootParameterTypeSRV\",\n"; + out << " \"ShaderVisibility\": \"IRShaderVisibilityAll\",\n"; + out << " \"Descriptor\": {\n"; + out << " \"ShaderRegister\": " << res.reg << ",\n"; + out << " \"RegisterSpace\": " << res.space << ",\n"; + out << " \"Flags\": \"IRRootDescriptorFlagNone\"\n"; + out << " }\n"; + out << " }"; + } else { + // Regular texture -> descriptor table + out << " {\n"; + out << " \"ParameterType\": \"IRRootParameterTypeDescriptorTable\",\n"; + out << " \"ShaderVisibility\": \"IRShaderVisibilityAll\",\n"; + out << " \"DescriptorTable\": {\n"; + out << " \"NumDescriptorRanges\": 1,\n"; + out << " \"DescriptorRanges\": [\n"; + out << " {\n"; + out << " \"RangeType\": \"IRDescriptorRangeTypeSRV\",\n"; + out << " \"NumDescriptors\": " << res.count << ",\n"; + out << " \"BaseShaderRegister\": " << res.reg << ",\n"; + out << " \"RegisterSpace\": " << res.space << ",\n"; + out << " \"OffsetInDescriptorsFromTableStart\": 0,\n"; + out << " \"Flags\": \"IRDescriptorRangeFlagNone\"\n"; + out << " }\n"; + out << " ]\n"; + out << " }\n"; + out << " }"; + } + } else if (res.bindType == "cb" || res.bindType == "b") { + // Constant buffer -> root CBV + out << " {\n"; + out << " \"ParameterType\": \"IRRootParameterTypeCBV\",\n"; + out << " \"ShaderVisibility\": \"IRShaderVisibilityAll\",\n"; + out << " \"Descriptor\": {\n"; + out << " \"ShaderRegister\": " << res.reg << ",\n"; + out << " \"RegisterSpace\": " << res.space << ",\n"; + out << " \"Flags\": \"IRRootDescriptorFlagNone\"\n"; + out << " }\n"; + out << " }"; + } else if (res.bindType == "s") { + // Sampler -> descriptor table + out << " {\n"; + out << " \"ParameterType\": \"IRRootParameterTypeDescriptorTable\",\n"; + out << " \"ShaderVisibility\": \"IRShaderVisibilityAll\",\n"; + out << " \"DescriptorTable\": {\n"; + out << " \"NumDescriptorRanges\": 1,\n"; + out << " \"DescriptorRanges\": [\n"; + out << " {\n"; + out << " \"RangeType\": \"IRDescriptorRangeTypeSampler\",\n"; + out << " \"NumDescriptors\": " << res.count << ",\n"; + out << " \"BaseShaderRegister\": " << res.reg << ",\n"; + out << " \"RegisterSpace\": " << res.space << ",\n"; + out << " \"OffsetInDescriptorsFromTableStart\": 0,\n"; + out << " \"Flags\": \"IRDescriptorRangeFlagNone\"\n"; + out << " }\n"; + out << " ]\n"; + out << " }\n"; + out << " }"; + } + + if (i < sorted.size() - 1) { + out << ","; + } + out << "\n"; + } + + out << " ],\n"; + out << " \"NumStaticSamplers\": 0,\n"; + out << " \"StaticSamplers\": []\n"; + out << " }\n"; + out << "}\n"; +} + +int main(int argc, const char** argv) { + if (argc != 3) { + fprintf(stderr, "Usage: %s \n", argv[0]); + fprintf(stderr, "\nGenerates Metal Shader Converter root signature JSON from DXC disassembly.\n"); + return EXIT_FAILURE; + } + + const char* inputPath = argv[1]; + const char* outputPath = argv[2]; + + // Read input file + std::ifstream inputFile(inputPath); + if (!inputFile) { + fprintf(stderr, "Error: Cannot open input file: %s\n", inputPath); + return EXIT_FAILURE; + } + + std::stringstream buffer; + buffer << inputFile.rdbuf(); + std::string content = buffer.str(); + inputFile.close(); + + // Parse resources + std::vector resources = parseReflection(content); + + if (resources.empty()) { + fprintf(stderr, "Error: No resource bindings found in input\n"); + return EXIT_FAILURE; + } + + // Write output + std::ofstream outputFile(outputPath); + if (!outputFile) { + fprintf(stderr, "Error: Cannot open output file: %s\n", outputPath); + return EXIT_FAILURE; + } + + writeRootSignature(outputFile, resources); + outputFile.close(); + + return EXIT_SUCCESS; +} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e1c2a18..13422e9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,7 +1,9 @@ set(CMAKE_CXX_STANDARD 17) -# Enable SDL Vulkan integration -set(PLUME_SDL_VULKAN_ENABLED ON CACHE BOOL "Enable SDL Vulkan integration" FORCE) +# Enable SDL Vulkan integration (Linux only) +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + set(PLUME_SDL_VULKAN_ENABLED ON CACHE BOOL "Enable SDL Vulkan integration" FORCE) +endif() # Find SDL2 (required for examples) find_package(SDL2 REQUIRED) @@ -13,3 +15,15 @@ plume_shaders_init() # Add example subdirectories add_subdirectory(triangle) add_subdirectory(cube) + +# Raytracing example requires Metal Shader Converter on Apple platforms +if(APPLE) + if(METAL_SHADER_CONVERTER_INCLUDE) + add_subdirectory(raytracing) + else() + message(STATUS "Skipping raytracing example: Metal Shader Converter not found") + endif() +else() + # D3D12 raytracing doesn't need special tools + add_subdirectory(raytracing) +endif() diff --git a/examples/cube/main.cpp b/examples/cube/main.cpp index a2aeae8..f31e2b1 100644 --- a/examples/cube/main.cpp +++ b/examples/cube/main.cpp @@ -507,36 +507,28 @@ namespace plume { ctx.m_commandQueue->waitForCommandFence(ctx.m_fence.get()); } - void CubeExample(RenderInterface* renderInterface, const std::string& apiName) { - if (SDL_Init(SDL_INIT_VIDEO) != 0) { - fprintf(stderr, "SDL_Init Error: %s\n", SDL_GetError()); - return; - } - - uint32_t flags = SDL_WINDOW_RESIZABLE; -#if defined(__APPLE__) - flags |= SDL_WINDOW_METAL; -#endif - + void CubeExample(RenderInterface* renderInterface, SDL_Window* window, const std::string& apiName) { std::string windowTitle = "Plume Cube Texture Example (" + apiName + ")"; - SDL_Window* window = SDL_CreateWindow(windowTitle.c_str(), SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED, 1280, 720, flags); - if (!window) { - fprintf(stderr, "SDL_CreateWindow Error: %s\n", SDL_GetError()); - SDL_Quit(); - return; - } + SDL_SetWindowTitle(window, windowTitle.c_str()); + CubeContext ctx; +#if PLUME_SDL_VULKAN_ENABLED + createContext(ctx, renderInterface, window, apiName); +#elif defined(__linux__) SDL_SysWMinfo wmInfo; SDL_VERSION(&wmInfo.version); SDL_GetWindowWMInfo(window, &wmInfo); - - CubeContext ctx; -#if defined(__linux__) createContext(ctx, renderInterface, { wmInfo.info.x11.display, wmInfo.info.x11.window }, apiName); #elif defined(__APPLE__) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); SDL_MetalView view = SDL_Metal_CreateView(window); createContext(ctx, renderInterface, { wmInfo.info.cocoa.window, SDL_Metal_GetLayer(view) }, apiName); #elif defined(WIN32) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); createContext(ctx, renderInterface, { wmInfo.info.win.window }, apiName); #endif @@ -577,17 +569,19 @@ namespace plume { #if defined(__APPLE__) SDL_Metal_DestroyView(view); #endif - SDL_DestroyWindow(window); - SDL_Quit(); } } -std::unique_ptr CreateRenderInterface(std::string& apiName) { +std::unique_ptr CreateRenderInterface(SDL_Window* window, std::string& apiName) { const bool useVulkan = false; #if defined(_WIN32) if (useVulkan) { apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return plume::CreateVulkanInterface(window); +#else return plume::CreateVulkanInterface(); +#endif } else { apiName = "D3D12"; return plume::CreateD3D12Interface(); @@ -595,20 +589,57 @@ std::unique_ptr CreateRenderInterface(std::string& apiNa #elif defined(__APPLE__) if (useVulkan) { apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return plume::CreateVulkanInterface(window); +#else return plume::CreateVulkanInterface(); +#endif } else { apiName = "Metal"; return plume::CreateMetalInterface(); } #else apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return plume::CreateVulkanInterface(window); +#else return plume::CreateVulkanInterface(); #endif +#endif } int main(int argc, char* argv[]) { + if (SDL_Init(SDL_INIT_VIDEO) != 0) { + fprintf(stderr, "SDL_Init Error: %s\n", SDL_GetError()); + return 1; + } + + uint32_t flags = SDL_WINDOW_RESIZABLE; +#if PLUME_SDL_VULKAN_ENABLED + flags |= SDL_WINDOW_VULKAN; +#elif defined(__APPLE__) + flags |= SDL_WINDOW_METAL; +#endif + + SDL_Window* window = SDL_CreateWindow("Plume Cube Texture Example", SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED, 1280, 720, flags); + if (!window) { + fprintf(stderr, "SDL_CreateWindow Error: %s\n", SDL_GetError()); + SDL_Quit(); + return 1; + } + std::string apiName = "Unknown"; - auto renderInterface = CreateRenderInterface(apiName); - plume::CubeExample(renderInterface.get(), apiName); + auto renderInterface = CreateRenderInterface(window, apiName); + if (!renderInterface) { + fprintf(stderr, "Failed to create render interface\n"); + SDL_DestroyWindow(window); + SDL_Quit(); + return 1; + } + + plume::CubeExample(renderInterface.get(), window, apiName); + + SDL_DestroyWindow(window); + SDL_Quit(); return 0; } diff --git a/examples/cube/shaders/cube.frag.hlsl b/examples/cube/shaders/cube.frag.hlsl index c66f3db..182e489 100644 --- a/examples/cube/shaders/cube.frag.hlsl +++ b/examples/cube/shaders/cube.frag.hlsl @@ -2,7 +2,7 @@ // Samples from a cubemap texture [[vk::binding(0, 0)]] TextureCube cubeTexture : register(t0); -[[vk::binding(1, 0)]] SamplerState cubeSampler : register(s0); +[[vk::binding(1, 0)]] SamplerState cubeSampler : register(s1); struct PSInput { float4 position : SV_POSITION; diff --git a/examples/raytracing/CMakeLists.txt b/examples/raytracing/CMakeLists.txt new file mode 100644 index 0000000..c32e0ea --- /dev/null +++ b/examples/raytracing/CMakeLists.txt @@ -0,0 +1,41 @@ +# Ray tracing "Hello Triangle" example + +add_executable(plume_raytracing + main.cpp +) + +target_link_libraries(plume_raytracing PRIVATE plume ${SDL2_LIBRARIES}) + +# Platform-specific libraries +if(APPLE) + target_link_libraries(plume_raytracing PRIVATE "-framework Metal -framework QuartzCore -framework CoreGraphics -framework Foundation -framework IOKit") +elseif(WIN32) + target_link_libraries(plume_raytracing PRIVATE d3d12 dxgi) +elseif(CMAKE_SYSTEM_NAME MATCHES "Linux") + find_package(X11 REQUIRED) + target_include_directories(plume_raytracing PUBLIC ${X11_INCLUDE_DIR} ${X11_Xrandr_INCLUDE_PATH}) + target_link_libraries(plume_raytracing PRIVATE ${X11_LIBRARIES} ${X11_Xrandr_LIB}) +endif() + +target_include_directories(plume_raytracing PRIVATE + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/examples + ${SDL2_INCLUDE_DIRS} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/shaders +) + +# Compile ray tracing shaders for the current platform +# - D3D12 (Windows): HLSL -> DXIL library shader +# - Metal (Apple): HLSL -> DXIL -> metallib via Metal Shader Converter +# Produces two metallibs: visible functions + dispatch kernel +plume_compile_rt_shader(plume_raytracing + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/raytracing.hlsl" + rtShaders + SHADER_MODEL 6_3 +) + +# Set output directory +set_target_properties(plume_raytracing PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) diff --git a/examples/raytracing/main.cpp b/examples/raytracing/main.cpp new file mode 100644 index 0000000..ef75adf --- /dev/null +++ b/examples/raytracing/main.cpp @@ -0,0 +1,678 @@ +// Ray tracing "Hello Triangle" example +// Demonstrates basic ray tracing with Plume RHI + +#include "plume_render_interface.h" + +#include +#include + +#include +#include +#include +#include +#include + +// Shader blobs +#if defined(_WIN64) +#include "shaders/rtShaders.hlsl.dxil.h" +#include "shaders/rtShaders.hlsl.spirv.h" +#elif defined(__APPLE__) +#include "shaders/rtShaders.metallib.h" +#elif defined(__linux__) +#include "shaders/rtShaders.hlsl.spirv.h" +#endif + +namespace plume { + // Forward declarations for interface creation + extern std::unique_ptr CreateMetalInterface(); + extern std::unique_ptr CreateD3D12Interface(); + #if PLUME_SDL_VULKAN_ENABLED + extern std::unique_ptr CreateVulkanInterface(RenderWindow sdlWindow); + #else + extern std::unique_ptr CreateVulkanInterface(); + #endif +} + +using namespace plume; + +static const uint32_t BufferCount = 2; +static const RenderFormat SwapchainFormat = RenderFormat::B8G8R8A8_UNORM; + +// Camera constants matching shader +struct CameraConstants { + float viewInverse[16]; + float projInverse[16]; + uint32_t width; + uint32_t height; + uint32_t frameIndex; + uint32_t padding; +}; + +// Triangle vertex data (simple triangle in front of camera) +struct Vertex { + float position[3]; +}; + +// Simple identity matrix helper +void setIdentity(float* m) { + memset(m, 0, 16 * sizeof(float)); + m[0] = m[5] = m[10] = m[15] = 1.0f; +} + +// Inverse perspective projection matrix for ray generation +// Given NDC (x,y) and z=1 (looking into screen), produces view-space direction +void setPerspectiveInverse(float* m, float fovY, float aspect, float nearZ, float farZ) { + (void)nearZ; (void)farZ; // Not needed for ray generation inverse + + float tanHalfFov = tanf(fovY * 0.5f); + memset(m, 0, 16 * sizeof(float)); + + // For ray generation, we just need to scale NDC to view-space angles + // The shader does: target = projInverse * float4(ndc.x, ndc.y, 1, 1) + // then normalizes target.xyz / target.w + // + // We want (ndc.x, ndc.y, 1, 1) -> (view_x, view_y, view_z, 1) + // where view_x = ndc.x * tan(fov/2) * aspect + // view_y = ndc.y * tan(fov/2) + // view_z = 1 (forward into screen) + // view_w = 1 + m[0] = tanHalfFov * aspect; // scale x by aspect and fov + m[5] = tanHalfFov; // scale y by fov + m[10] = 1.0f; // z passes through + m[15] = 1.0f; // w = 1 +} + +// Simple look-at inverse (view inverse) matrix +void setViewInverse(float* m, float eyeX, float eyeY, float eyeZ) { + setIdentity(m); + m[12] = eyeX; + m[13] = eyeY; + m[14] = eyeZ; +} + +struct RTContext { + const RenderInterface* renderInterface = nullptr; + std::string apiName; + RenderWindow renderWindow = {}; + std::unique_ptr device; + std::unique_ptr commandQueue; + std::unique_ptr commandList; + std::unique_ptr swapChain; + std::unique_ptr acquireSemaphore; + std::vector> releaseSemaphores; + std::unique_ptr commandFence; + std::vector> framebuffers; + + // Ray tracing resources + bool rtSupported = false; + std::unique_ptr vertexBuffer; + std::unique_ptr indexBuffer; + std::unique_ptr blasBuffer; + std::unique_ptr blasScratchBuffer; + std::unique_ptr tlasBuffer; + std::unique_ptr tlasScratchBuffer; + std::unique_ptr instanceBuffer; + std::unique_ptr sbtBuffer; + std::unique_ptr blas; + std::unique_ptr tlas; + std::unique_ptr rtShader; + std::unique_ptr rtPipeline; + std::unique_ptr rtPipelineLayout; + std::unique_ptr outputTexture; + std::unique_ptr outputTextureView; + std::unique_ptr descriptorSet; + std::unique_ptr constantBuffer; + + RenderShaderBindingTableInfo sbtInfo; + uint32_t frameIndex = 0; +}; + +void createFramebuffers(RTContext& ctx) { + ctx.framebuffers.clear(); + for (uint32_t i = 0; i < ctx.swapChain->getTextureCount(); i++) { + const RenderTexture* colorAttachment = ctx.swapChain->getTexture(i); + RenderFramebufferDesc fbDesc; + fbDesc.colorAttachments = &colorAttachment; + fbDesc.colorAttachmentsCount = 1; + fbDesc.depthAttachment = nullptr; + auto framebuffer = ctx.device->createFramebuffer(fbDesc); + ctx.framebuffers.push_back(std::move(framebuffer)); + } +} + +void createOutputTexture(RTContext& ctx, uint32_t width, uint32_t height) { + RenderTextureDesc texDesc = RenderTextureDesc::Texture2D( + width, height, 1, RenderFormat::R8G8B8A8_UNORM, + RenderTextureFlag::STORAGE | RenderTextureFlag::UNORDERED_ACCESS + ); + + ctx.outputTexture = ctx.device->createTexture(texDesc); + ctx.outputTexture->setName("RT Output Texture"); + + RenderTextureViewDesc viewDesc; + viewDesc.format = RenderFormat::R8G8B8A8_UNORM; + viewDesc.dimension = RenderTextureViewDimension::TEXTURE_2D; + viewDesc.mipLevels = 1; + viewDesc.mipSlice = 0; + ctx.outputTextureView = ctx.outputTexture->createTextureView(viewDesc); +} + +void initializeRayTracing(RTContext& ctx) { + const auto& caps = ctx.device->getCapabilities(); + ctx.rtSupported = caps.raytracing; + + if (!ctx.rtSupported) { + std::cerr << "WARNING: Ray tracing is not supported on this device" << std::endl; + std::cerr << "The example will show a placeholder color instead" << std::endl; + return; + } + + // 1. Create geometry buffers - simple triangle + Vertex vertices[] = { + {{ 0.0f, 0.5f, 2.0f}}, // Top + {{-0.5f, -0.5f, 2.0f}}, // Bottom left + {{ 0.5f, -0.5f, 2.0f}} // Bottom right + }; + + uint32_t indices[] = {0, 1, 2}; + + // Create vertex buffer + RenderBufferDesc vbDesc = RenderBufferDesc::VertexBuffer(sizeof(vertices), RenderHeapType::UPLOAD); + vbDesc.flags |= RenderBufferFlag::ACCELERATION_STRUCTURE_INPUT; + ctx.vertexBuffer = ctx.device->createBuffer(vbDesc); + ctx.vertexBuffer->setName("Triangle Vertices"); + + void* vbData = ctx.vertexBuffer->map(); + memcpy(vbData, vertices, sizeof(vertices)); + ctx.vertexBuffer->unmap(); + + // Create index buffer + RenderBufferDesc ibDesc = RenderBufferDesc::IndexBuffer(sizeof(indices), RenderHeapType::UPLOAD); + ibDesc.flags |= RenderBufferFlag::ACCELERATION_STRUCTURE_INPUT; + ctx.indexBuffer = ctx.device->createBuffer(ibDesc); + ctx.indexBuffer->setName("Triangle Indices"); + + void* ibData = ctx.indexBuffer->map(); + memcpy(ibData, indices, sizeof(indices)); + ctx.indexBuffer->unmap(); + + // 2. Create BLAS + RenderBottomLevelASMesh mesh; + mesh.vertexBuffer = ctx.vertexBuffer->at(0); + mesh.vertexFormat = RenderFormat::R32G32B32_FLOAT; + mesh.vertexStride = sizeof(Vertex); + mesh.vertexCount = 3; + mesh.indexBuffer = ctx.indexBuffer->at(0); + mesh.indexFormat = RenderFormat::R32_UINT; + mesh.indexCount = 3; + mesh.isOpaque = true; + + RenderBottomLevelASBuildInfo blasBuildInfo; + ctx.device->setBottomLevelASBuildInfo(blasBuildInfo, &mesh, 1, false, true); + + // Create buffer to back the BLAS (required for Vulkan/D3D12, ignored by Metal) + RenderBufferDesc blasBufDesc = RenderBufferDesc::DefaultBuffer(blasBuildInfo.accelerationStructureSize); + blasBufDesc.flags = RenderBufferFlag::ACCELERATION_STRUCTURE; + ctx.blasBuffer = ctx.device->createBuffer(blasBufDesc); + ctx.blasBuffer->setName("BLAS Buffer"); + + RenderAccelerationStructureDesc blasDesc; + blasDesc.type = RenderAccelerationStructureType::BOTTOM_LEVEL; + blasDesc.buffer = ctx.blasBuffer->at(0); + blasDesc.size = blasBuildInfo.accelerationStructureSize; + ctx.blas = ctx.device->createAccelerationStructure(blasDesc); + + // Create scratch buffer for BLAS build + RenderBufferDesc scratchDesc = RenderBufferDesc::DefaultBuffer(blasBuildInfo.scratchSize); + scratchDesc.flags = RenderBufferFlag::STORAGE | RenderBufferFlag::ACCELERATION_STRUCTURE_SCRATCH; + ctx.blasScratchBuffer = ctx.device->createBuffer(scratchDesc); + ctx.blasScratchBuffer->setName("BLAS Scratch"); + + // Build BLAS + ctx.commandList->begin(); + ctx.commandList->buildBottomLevelAS(ctx.blas.get(), ctx.blasScratchBuffer->at(0), blasBuildInfo); + ctx.commandList->end(); + + const RenderCommandList* cmdList = ctx.commandList.get(); + ctx.commandQueue->executeCommandLists(&cmdList, 1, nullptr, 0, nullptr, 0, ctx.commandFence.get()); + ctx.commandQueue->waitForCommandFence(ctx.commandFence.get()); + + // 3. Create TLAS with single instance + // Identity transform (row-major 3x4) + RenderAffineTransform transform; + transform.m[0][0] = 1.0f; transform.m[0][1] = 0.0f; transform.m[0][2] = 0.0f; transform.m[0][3] = 0.0f; + transform.m[1][0] = 0.0f; transform.m[1][1] = 1.0f; transform.m[1][2] = 0.0f; transform.m[1][3] = 0.0f; + transform.m[2][0] = 0.0f; transform.m[2][1] = 0.0f; transform.m[2][2] = 1.0f; transform.m[2][3] = 0.0f; + + RenderTopLevelASInstance instance; + instance.bottomLevelAS = ctx.blas.get(); + instance.transform = transform; + instance.instanceID = 0; + instance.instanceMask = 0xFF; + instance.instanceContributionToHitGroupIndex = 0; + instance.cullDisable = false; + + RenderTopLevelASBuildInfo tlasBuildInfo; + ctx.device->setTopLevelASBuildInfo(tlasBuildInfo, &instance, 1, false, true); + + // Create buffer to back the TLAS (required for Vulkan/D3D12, ignored by Metal) + RenderBufferDesc tlasBufDesc = RenderBufferDesc::DefaultBuffer(tlasBuildInfo.accelerationStructureSize); + tlasBufDesc.flags = RenderBufferFlag::ACCELERATION_STRUCTURE; + ctx.tlasBuffer = ctx.device->createBuffer(tlasBufDesc); + ctx.tlasBuffer->setName("TLAS Buffer"); + + RenderAccelerationStructureDesc tlasDesc; + tlasDesc.type = RenderAccelerationStructureType::TOP_LEVEL; + tlasDesc.buffer = ctx.tlasBuffer->at(0); + tlasDesc.size = tlasBuildInfo.accelerationStructureSize; + ctx.tlas = ctx.device->createAccelerationStructure(tlasDesc); + + // Create instance buffer + RenderBufferDesc instanceBufDesc = RenderBufferDesc::UploadBuffer(tlasBuildInfo.instancesBufferData.size()); + instanceBufDesc.flags |= RenderBufferFlag::ACCELERATION_STRUCTURE_INPUT; + ctx.instanceBuffer = ctx.device->createBuffer(instanceBufDesc); + ctx.instanceBuffer->setName("Instance Buffer"); + + // Copy instance data + void* instanceData = ctx.instanceBuffer->map(); + memcpy(instanceData, tlasBuildInfo.instancesBufferData.data(), tlasBuildInfo.instancesBufferData.size()); + ctx.instanceBuffer->unmap(); + + // Create scratch buffer for TLAS build + RenderBufferDesc tlasScratchDesc = RenderBufferDesc::DefaultBuffer(tlasBuildInfo.scratchSize); + tlasScratchDesc.flags = RenderBufferFlag::STORAGE | RenderBufferFlag::ACCELERATION_STRUCTURE_SCRATCH; + ctx.tlasScratchBuffer = ctx.device->createBuffer(tlasScratchDesc); + ctx.tlasScratchBuffer->setName("TLAS Scratch"); + + // Build TLAS + ctx.commandList->begin(); + ctx.commandList->buildTopLevelAS(ctx.tlas.get(), ctx.tlasScratchBuffer->at(0), ctx.instanceBuffer->at(0), tlasBuildInfo); + ctx.commandList->end(); + + ctx.commandQueue->executeCommandLists(&cmdList, 1, nullptr, 0, nullptr, 0, ctx.commandFence.get()); + ctx.commandQueue->waitForCommandFence(ctx.commandFence.get()); + + // 4. Create output texture + createOutputTexture(ctx, ctx.swapChain->getWidth(), ctx.swapChain->getHeight()); + + // 5. Create constant buffer + RenderBufferDesc cbDesc = RenderBufferDesc::UploadBuffer(sizeof(CameraConstants)); + cbDesc.flags |= RenderBufferFlag::CONSTANT; + ctx.constantBuffer = ctx.device->createBuffer(cbDesc); + ctx.constantBuffer->setName("Camera Constants"); + + // 6. Create RT shader + RenderShaderFormat shaderFormat = ctx.renderInterface->getCapabilities().shaderFormat; + switch (shaderFormat) { +#ifdef __APPLE__ + case RenderShaderFormat::METAL: + ctx.rtShader = ctx.device->createShader(rtShadersBlobMetalLib, rtShadersBlobMetalLib_size, nullptr, shaderFormat); + break; +#endif +#ifdef _WIN64 + case RenderShaderFormat::DXIL: + ctx.rtShader = ctx.device->createShader(rtShadersBlobDXIL, rtShadersBlobDXIL_size, nullptr, shaderFormat); + break; +#endif +#if defined(_WIN64) || defined(__linux__) + case RenderShaderFormat::SPIRV: + ctx.rtShader = ctx.device->createShader(rtShadersBlobSPIRV, rtShadersBlobSPIRV_size, nullptr, shaderFormat); + break; +#endif + default: + break; + } + ctx.rtShader->setName("RT Shader Library"); + + // 7. Create pipeline layout with descriptors for: output texture (UAV), TLAS (SRV), constants (CBV) + std::vector ranges; + ranges.push_back(RenderDescriptorRange(RenderDescriptorRangeType::READ_WRITE_TEXTURE, 0, 1)); // u0: output texture + ranges.push_back(RenderDescriptorRange(RenderDescriptorRangeType::ACCELERATION_STRUCTURE, 1, 1)); // t0: TLAS + ranges.push_back(RenderDescriptorRange(RenderDescriptorRangeType::CONSTANT_BUFFER, 2, 1)); // b0: constants + + RenderDescriptorSetDesc descSetDesc(ranges.data(), static_cast(ranges.size())); + ctx.descriptorSet = ctx.device->createDescriptorSet(descSetDesc); + + // Bind resources to descriptor set + ctx.descriptorSet->setTexture(0, ctx.outputTexture.get(), RenderTextureLayout::GENERAL, ctx.outputTextureView.get()); + ctx.descriptorSet->setAccelerationStructure(1, ctx.tlas.get()); + ctx.descriptorSet->setBuffer(2, ctx.constantBuffer.get(), sizeof(CameraConstants)); + + // Create pipeline layout + RenderPipelineLayoutDesc layoutDesc; + layoutDesc.descriptorSetDescsCount = 1; + layoutDesc.descriptorSetDescs = &descSetDesc; + layoutDesc.isLocal = true; + ctx.rtPipelineLayout = ctx.device->createPipelineLayout(layoutDesc); + + // 8. Create RT pipeline + RenderRaytracingPipelineLibrarySymbol functionsSymbols[] = { + RenderRaytracingPipelineLibrarySymbol("RayGen", RenderRaytracingPipelineLibrarySymbolType::RAYGEN, "RayGen"), + RenderRaytracingPipelineLibrarySymbol("ClosestHit", RenderRaytracingPipelineLibrarySymbolType::CLOSEST_HIT, "ClosestHit"), + RenderRaytracingPipelineLibrarySymbol("Miss", RenderRaytracingPipelineLibrarySymbolType::MISS, "Miss") + }; + + // Set up shader library + RenderRaytracingPipelineLibrary shaderLibrary; + shaderLibrary.shader = ctx.rtShader.get(); + shaderLibrary.symbols = functionsSymbols; + shaderLibrary.symbolsCount = 3; + + RenderRaytracingPipelineHitGroup hitGroup; + hitGroup.hitGroupName = "HitGroup"; + hitGroup.closestHitName = "ClosestHit"; + hitGroup.anyHitName = nullptr; + hitGroup.intersectionName = nullptr; + + RenderRaytracingPipelineDesc rtPipelineDesc; + rtPipelineDesc.libraries = &shaderLibrary; + rtPipelineDesc.librariesCount = 1; + rtPipelineDesc.hitGroups = &hitGroup; + rtPipelineDesc.hitGroupsCount = 1; + rtPipelineDesc.pipelineLayout = ctx.rtPipelineLayout.get(); + rtPipelineDesc.maxPayloadSize = sizeof(float) * 4; // RayPayload: float3 color + uint depth + rtPipelineDesc.maxAttributeSize = sizeof(float) * 2; // Barycentrics + rtPipelineDesc.maxRecursionDepth = 1; + + ctx.rtPipeline = ctx.device->createRaytracingPipeline(rtPipelineDesc); + ctx.rtPipeline->setName("RT Pipeline"); + + // 9. Build Shader Binding Table + RenderPipelineProgram raygenProgram = ctx.rtPipeline->getProgram("RayGen"); + RenderPipelineProgram missProgram = ctx.rtPipeline->getProgram("Miss"); + RenderPipelineProgram hitGroupProgram = ctx.rtPipeline->getProgram("HitGroup"); + + RenderShaderBindingGroup raygenGroup(&raygenProgram, 1); + RenderShaderBindingGroup missGroup(&missProgram, 1); + RenderShaderBindingGroup hitGroupGroup(&hitGroupProgram, 1); + + RenderShaderBindingGroups sbtGroups(raygenGroup, missGroup, hitGroupGroup); + RenderDescriptorSet* descriptorSets[] = { ctx.descriptorSet.get() }; + ctx.device->setShaderBindingTableInfo(ctx.sbtInfo, sbtGroups, ctx.rtPipeline.get(), descriptorSets, 1); + + // Create SBT buffer and upload data + RenderBufferDesc sbtBufDesc = RenderBufferDesc::UploadBuffer(ctx.sbtInfo.tableBufferData.size()); + sbtBufDesc.flags |= RenderBufferFlag::SHADER_BINDING_TABLE; + ctx.sbtBuffer = ctx.device->createBuffer(sbtBufDesc); + ctx.sbtBuffer->setName("Shader Binding Table"); + + void* sbtData = ctx.sbtBuffer->map(); + memcpy(sbtData, ctx.sbtInfo.tableBufferData.data(), ctx.sbtInfo.tableBufferData.size()); + ctx.sbtBuffer->unmap(); +} + +void initializeRenderResources(RTContext& ctx, RenderInterface* renderInterface) { + ctx.device = renderInterface->createDevice(); + ctx.commandQueue = ctx.device->createCommandQueue(RenderCommandListType::DIRECT); + ctx.commandFence = ctx.device->createCommandFence(); + ctx.swapChain = ctx.commandQueue->createSwapChain(ctx.renderWindow, BufferCount, SwapchainFormat, 2); + ctx.swapChain->resize(); + ctx.commandList = ctx.commandQueue->createCommandList(); + ctx.acquireSemaphore = ctx.device->createCommandSemaphore(); + + createFramebuffers(ctx); + initializeRayTracing(ctx); +} + +void createContext(RTContext& ctx, RenderInterface* renderInterface, RenderWindow window, const std::string& apiName) { + ctx.renderInterface = renderInterface; + ctx.renderWindow = window; + ctx.apiName = apiName; + initializeRenderResources(ctx, const_cast(renderInterface)); +} + +void resize(RTContext& ctx, int width, int height) { + std::cout << "Resizing to " << width << "x" << height << std::endl; + if (ctx.swapChain) { + ctx.framebuffers.clear(); + bool resized = ctx.swapChain->resize(); + if (!resized) { + std::cerr << "Failed to resize swap chain" << std::endl; + return; + } + createFramebuffers(ctx); + + // Recreate output texture at new size + if (ctx.rtSupported) { + createOutputTexture(ctx, ctx.swapChain->getWidth(), ctx.swapChain->getHeight()); + ctx.descriptorSet->setTexture(0, ctx.outputTexture.get(), RenderTextureLayout::GENERAL, ctx.outputTextureView.get()); + } + } +} + +void updateCameraConstants(RTContext& ctx) { + const uint32_t width = ctx.swapChain->getWidth(); + const uint32_t height = ctx.swapChain->getHeight(); + + CameraConstants constants; + + // Simple camera at origin looking down +Z + setViewInverse(constants.viewInverse, 0.0f, 0.0f, 0.0f); + + // Perspective with 60 degree FOV + float aspect = static_cast(width) / static_cast(height); + setPerspectiveInverse(constants.projInverse, 3.14159f / 3.0f, aspect, 0.1f, 1000.0f); + + constants.width = width; + constants.height = height; + constants.frameIndex = ctx.frameIndex++; + constants.padding = 0; + + void* cbData = ctx.constantBuffer->map(); + memcpy(cbData, &constants, sizeof(constants)); + ctx.constantBuffer->unmap(); +} + +void render(RTContext& ctx) { + static int counter = 0; + if (counter++ % 60 == 0) { + std::cout << "Rendering frame " << counter << " using " << ctx.apiName << " backend" << std::endl; + } + + // Acquire the next swapchain image + uint32_t imageIndex = 0; + ctx.swapChain->acquireTexture(ctx.acquireSemaphore.get(), &imageIndex); + + ctx.commandList->begin(); + + RenderTexture* swapChainTexture = ctx.swapChain->getTexture(imageIndex); + + if (ctx.rtSupported) { + // Update camera constants + updateCameraConstants(ctx); + + // Transition output texture to general for compute write + ctx.commandList->barriers(RenderBarrierStage::COMPUTE, + RenderTextureBarrier(ctx.outputTexture.get(), RenderTextureLayout::GENERAL)); + + // Set up raytracing state + ctx.commandList->setRaytracingPipelineLayout(ctx.rtPipelineLayout.get()); + ctx.commandList->setPipeline(ctx.rtPipeline.get()); + ctx.commandList->setRaytracingDescriptorSet(ctx.descriptorSet.get(), 0); + + // Trace rays + const uint32_t width = ctx.swapChain->getWidth(); + const uint32_t height = ctx.swapChain->getHeight(); + ctx.commandList->traceRays(width, height, 1, ctx.sbtBuffer->at(0), ctx.sbtInfo.groups); + + // Transition textures for copy + ctx.commandList->barriers(RenderBarrierStage::COPY, + RenderTextureBarrier(ctx.outputTexture.get(), RenderTextureLayout::COPY_SOURCE)); + ctx.commandList->barriers(RenderBarrierStage::COPY, + RenderTextureBarrier(swapChainTexture, RenderTextureLayout::COPY_DEST)); + + // Copy output to swapchain + ctx.commandList->copyTexture(swapChainTexture, ctx.outputTexture.get()); + + // Transition swapchain for present + ctx.commandList->barriers(RenderBarrierStage::NONE, + RenderTextureBarrier(swapChainTexture, RenderTextureLayout::PRESENT)); + } else { + // Fallback: just clear to purple + ctx.commandList->barriers(RenderBarrierStage::GRAPHICS, + RenderTextureBarrier(swapChainTexture, RenderTextureLayout::COLOR_WRITE)); + + const RenderFramebuffer* framebuffer = ctx.framebuffers[imageIndex].get(); + ctx.commandList->setFramebuffer(framebuffer); + + const uint32_t width = ctx.swapChain->getWidth(); + const uint32_t height = ctx.swapChain->getHeight(); + const RenderViewport viewport(0.0f, 0.0f, float(width), float(height)); + const RenderRect scissor(0, 0, width, height); + + ctx.commandList->setViewports(viewport); + ctx.commandList->setScissors(scissor); + + RenderColor clearColor(0.4f, 0.2f, 0.6f, 1.0f); + ctx.commandList->clearColor(0, clearColor); + + ctx.commandList->barriers(RenderBarrierStage::NONE, + RenderTextureBarrier(swapChainTexture, RenderTextureLayout::PRESENT)); + } + + ctx.commandList->end(); + + // Create semaphores if needed + while (ctx.releaseSemaphores.size() < ctx.swapChain->getTextureCount()) { + ctx.releaseSemaphores.emplace_back(ctx.device->createCommandSemaphore()); + } + + const RenderCommandList* cmdList = ctx.commandList.get(); + RenderCommandSemaphore* waitSemaphore = ctx.acquireSemaphore.get(); + RenderCommandSemaphore* signalSemaphore = ctx.releaseSemaphores[imageIndex].get(); + + ctx.commandQueue->executeCommandLists(&cmdList, 1, &waitSemaphore, 1, &signalSemaphore, 1, ctx.commandFence.get()); + ctx.swapChain->present(imageIndex, &signalSemaphore, 1); + ctx.commandQueue->waitForCommandFence(ctx.commandFence.get()); +} + +// Platform-specific ray tracing API selection: +// - macOS: Metal only (MoltenVK doesn't support ray tracing extensions) +// - Windows: D3D12 by default, Vulkan optional (set useVulkan=true) +// - Linux: Vulkan only +std::unique_ptr CreateRenderInterface(SDL_Window* window, std::string& apiName) { + const bool useVulkan = false; +#if defined(__APPLE__) + // macOS: Metal only (MVK doesn't support ray tracing) + apiName = "Metal"; + return CreateMetalInterface(); +#elif defined(_WIN32) + if (useVulkan) { + apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return CreateVulkanInterface(window); +#else + return CreateVulkanInterface(); +#endif + } else { + apiName = "D3D12"; + return CreateD3D12Interface(); + } +#else + apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return CreateVulkanInterface(window); +#else + return CreateVulkanInterface(); +#endif +#endif +} + +int main(int argc, char* argv[]) { + if (SDL_Init(SDL_INIT_VIDEO) != 0) { + std::cerr << "SDL_Init Error: " << SDL_GetError() << std::endl; + return 1; + } + + uint32_t flags = SDL_WINDOW_RESIZABLE; +#if PLUME_SDL_VULKAN_ENABLED + flags |= SDL_WINDOW_VULKAN; +#elif defined(__APPLE__) + flags |= SDL_WINDOW_METAL; +#endif + + SDL_Window* window = SDL_CreateWindow( + "Plume Ray Tracing - Hello Triangle", + SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED, + 1280, 720, flags + ); + + if (!window) { + std::cerr << "SDL_CreateWindow Error: " << SDL_GetError() << std::endl; + SDL_Quit(); + return 1; + } + + std::string apiName; + auto renderInterface = CreateRenderInterface(window, apiName); + if (!renderInterface) { + std::cerr << "Failed to create render interface" << std::endl; + SDL_DestroyWindow(window); + SDL_Quit(); + return 1; + } + + RTContext ctx; +#if PLUME_SDL_VULKAN_ENABLED + createContext(ctx, renderInterface.get(), window, apiName); +#elif defined(__linux__) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); + createContext(ctx, renderInterface.get(), { wmInfo.info.x11.display, wmInfo.info.x11.window }, apiName); +#elif defined(__APPLE__) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); + SDL_MetalView view = SDL_Metal_CreateView(window); + createContext(ctx, renderInterface.get(), { wmInfo.info.cocoa.window, SDL_Metal_GetLayer(view) }, apiName); +#elif defined(_WIN32) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); + createContext(ctx, renderInterface.get(), { wmInfo.info.win.window }, apiName); +#endif + + bool running = true; + while (running) { + SDL_Event event; + while (SDL_PollEvent(&event)) { + switch (event.type) { + case SDL_QUIT: + running = false; + break; + case SDL_KEYDOWN: + if (event.key.keysym.sym == SDLK_ESCAPE) + running = false; + break; + case SDL_WINDOWEVENT: + if (event.window.event == SDL_WINDOWEVENT_RESIZED) { + resize(ctx, event.window.data1, event.window.data2); + } + break; + } + } + render(ctx); + } + + // Cleanup: transition swapchain out of present state + uint32_t imageIndex = 0; + if (!ctx.swapChain->isEmpty() && ctx.swapChain->acquireTexture(ctx.acquireSemaphore.get(), &imageIndex)) { + RenderTexture* swapChainTexture = ctx.swapChain->getTexture(imageIndex); + ctx.commandList->begin(); + ctx.commandList->barriers(RenderBarrierStage::NONE, RenderTextureBarrier(swapChainTexture, RenderTextureLayout::COLOR_WRITE)); + ctx.commandList->end(); + const RenderCommandList* cmdList = ctx.commandList.get(); + RenderCommandSemaphore* waitSemaphore = ctx.acquireSemaphore.get(); + ctx.commandQueue->executeCommandLists(&cmdList, 1, &waitSemaphore, 1, nullptr, 0, ctx.commandFence.get()); + ctx.commandQueue->waitForCommandFence(ctx.commandFence.get()); + } + +#if defined(__APPLE__) + SDL_Metal_DestroyView(view); +#endif + SDL_DestroyWindow(window); + SDL_Quit(); + + return 0; +} diff --git a/examples/raytracing/shaders/raytracing.hlsl b/examples/raytracing/shaders/raytracing.hlsl new file mode 100644 index 0000000..bda467a --- /dev/null +++ b/examples/raytracing/shaders/raytracing.hlsl @@ -0,0 +1,109 @@ +// Ray tracing "Hello Triangle" shader +// Compiles to DXIL for D3D12 and Metal via metal-shaderconverter + +// Output texture +[[vk::binding(0)]] +[[vk::image_format("rgba8")]] +RWTexture2D outputTexture : register(u0); + +// Acceleration structure +[[vk::binding(1)]] +RaytracingAccelerationStructure scene : register(t1); + +// Camera constants +[[vk::binding(2)]] +cbuffer CameraConstants : register(b2) +{ + float4x4 viewInverse; + float4x4 projInverse; + uint width; + uint height; + uint frameIndex; + uint padding; +}; + +// Ray payload structure +struct RayPayload +{ + float3 color; + uint depth; +}; + +// Built-in triangle intersection attributes +struct TriangleAttributes +{ + float2 barycentrics; +}; + +// Generate a ray for a given pixel +inline void generateCameraRay(uint2 pixelCoord, out float3 origin, out float3 direction) +{ + float2 pixelCenter = float2(pixelCoord) + 0.5; + float2 uv = pixelCenter / float2(width, height); + float2 ndc = uv * 2.0 - 1.0; + ndc.y = -ndc.y; // Flip Y for Vulkan/Metal convention + + float4 target = mul(projInverse, float4(ndc.x, ndc.y, 1.0, 1.0)); + target.xyz /= target.w; + + origin = mul(viewInverse, float4(0, 0, 0, 1)).xyz; + direction = normalize(mul(viewInverse, float4(target.xyz, 0)).xyz); +} + +[shader("raygeneration")] +void RayGen() +{ + uint2 launchIndex = DispatchRaysIndex().xy; + uint2 launchDim = DispatchRaysDimensions().xy; + + float3 origin; + float3 direction; + generateCameraRay(launchIndex, origin, direction); + + RayDesc ray; + ray.Origin = origin; + ray.Direction = direction; + ray.TMin = 0.001; + ray.TMax = 10000.0; + + RayPayload payload; + payload.color = float3(0, 0, 0); + payload.depth = 0; + + TraceRay( + scene, // Acceleration structure + RAY_FLAG_NONE, // Ray flags + 0xFF, // Instance inclusion mask + 0, // Hit group index (SBT offset) + 0, // Hit group stride (multiplier) + 0, // Miss shader index + ray, + payload + ); + + outputTexture[launchIndex] = float4(payload.color, 1.0); +} + +[shader("closesthit")] +void ClosestHit(inout RayPayload payload, in TriangleAttributes attribs) +{ + // Compute barycentric coordinates for coloring + float3 barycentrics = float3( + 1.0 - attribs.barycentrics.x - attribs.barycentrics.y, + attribs.barycentrics.x, + attribs.barycentrics.y + ); + + // Color based on barycentrics (RGB triangle) + payload.color = barycentrics; +} + +[shader("miss")] +void Miss(inout RayPayload payload) +{ + // Background gradient (sky color) + float2 uv = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy); + float3 topColor = float3(0.5, 0.7, 1.0); // Light blue + float3 bottomColor = float3(1.0, 1.0, 1.0); // White + payload.color = lerp(bottomColor, topColor, uv.y); +} diff --git a/examples/triangle/main.cpp b/examples/triangle/main.cpp index 503a18a..5d573c7 100644 --- a/examples/triangle/main.cpp +++ b/examples/triangle/main.cpp @@ -293,36 +293,28 @@ namespace plume { ctx.m_commandQueue->waitForCommandFence(ctx.m_fence.get()); } - void RenderInterfaceTest(RenderInterface* renderInterface, const std::string &apiName) { - if (SDL_Init(SDL_INIT_VIDEO) != 0) { - fprintf(stderr, "SDL_Init Error: %s\n", SDL_GetError()); - return; - } - - uint32_t flags = SDL_WINDOW_RESIZABLE; -#if defined(__APPLE__) - flags |= SDL_WINDOW_METAL; -#endif - + void RenderInterfaceTest(RenderInterface* renderInterface, SDL_Window* window, const std::string &apiName) { std::string windowTitle = "Plume Example (" + apiName + ")"; - SDL_Window* window = SDL_CreateWindow(windowTitle.c_str(), SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED, 1280, 720, flags); - if (!window) { - fprintf(stderr, "SDL_CreateWindow Error: %s\n", SDL_GetError()); - SDL_Quit(); - return; - } + SDL_SetWindowTitle(window, windowTitle.c_str()); + TestContext ctx; +#if PLUME_SDL_VULKAN_ENABLED + createContext(ctx, renderInterface, window, apiName); +#elif defined(__linux__) SDL_SysWMinfo wmInfo; SDL_VERSION(&wmInfo.version); SDL_GetWindowWMInfo(window, &wmInfo); - - TestContext ctx; -#if defined(__linux__) createContext(ctx, renderInterface, { wmInfo.info.x11.display, wmInfo.info.x11.window }, apiName); #elif defined(__APPLE__) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); SDL_MetalView view = SDL_Metal_CreateView(window); createContext(ctx, renderInterface, { wmInfo.info.cocoa.window, SDL_Metal_GetLayer(view) }, apiName); #elif defined(WIN32) + SDL_SysWMinfo wmInfo; + SDL_VERSION(&wmInfo.version); + SDL_GetWindowWMInfo(window, &wmInfo); createContext(ctx, renderInterface, { wmInfo.info.win.window }, apiName); #endif @@ -365,17 +357,19 @@ namespace plume { #if defined(__APPLE__) SDL_Metal_DestroyView(view); #endif - SDL_DestroyWindow(window); - SDL_Quit(); } } -std::unique_ptr CreateRenderInterface(std::string &apiName) { +std::unique_ptr CreateRenderInterface(SDL_Window* window, std::string &apiName) { const bool useVulkan = false; #if defined(_WIN32) if (useVulkan) { apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return plume::CreateVulkanInterface(window); +#else return plume::CreateVulkanInterface(); +#endif } else { apiName = "D3D12"; @@ -384,7 +378,11 @@ std::unique_ptr CreateRenderInterface(std::string &apiNa #elif defined(__APPLE__) if (useVulkan) { apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return plume::CreateVulkanInterface(window); +#else return plume::CreateVulkanInterface(); +#endif } else { apiName = "Metal"; @@ -392,13 +390,46 @@ std::unique_ptr CreateRenderInterface(std::string &apiNa } #else apiName = "Vulkan"; +#if PLUME_SDL_VULKAN_ENABLED + return plume::CreateVulkanInterface(window); +#else return plume::CreateVulkanInterface(); #endif +#endif } int main(int argc, char* argv[]) { + if (SDL_Init(SDL_INIT_VIDEO) != 0) { + fprintf(stderr, "SDL_Init Error: %s\n", SDL_GetError()); + return 1; + } + + uint32_t flags = SDL_WINDOW_RESIZABLE; +#if PLUME_SDL_VULKAN_ENABLED + flags |= SDL_WINDOW_VULKAN; +#elif defined(__APPLE__) + flags |= SDL_WINDOW_METAL; +#endif + + SDL_Window* window = SDL_CreateWindow("Plume Example", SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED, 1280, 720, flags); + if (!window) { + fprintf(stderr, "SDL_CreateWindow Error: %s\n", SDL_GetError()); + SDL_Quit(); + return 1; + } + std::string apiName = "Unknown"; - auto renderInterface = CreateRenderInterface(apiName); - plume::RenderInterfaceTest(renderInterface.get(), apiName); + auto renderInterface = CreateRenderInterface(window, apiName); + if (!renderInterface) { + fprintf(stderr, "Failed to create render interface\n"); + SDL_DestroyWindow(window); + SDL_Quit(); + return 1; + } + + plume::RenderInterfaceTest(renderInterface.get(), window, apiName); + + SDL_DestroyWindow(window); + SDL_Quit(); return 0; } diff --git a/plume_d3d12.cpp b/plume_d3d12.cpp index db6cf33..4149503 100644 --- a/plume_d3d12.cpp +++ b/plume_d3d12.cpp @@ -2711,7 +2711,14 @@ namespace plume { D3D12_RESOURCE_DESC resourceDesc = {}; resourceDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; - resourceDesc.Width = desc.size; + + // Constant buffers must be aligned to D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT (256 bytes) + // to allow creating CBVs that cover the full aligned size. + if (desc.flags & RenderBufferFlag::CONSTANT) { + resourceDesc.Width = roundUp(desc.size, D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT); + } else { + resourceDesc.Width = desc.size; + } resourceDesc.Height = 1; resourceDesc.DepthOrArraySize = 1; resourceDesc.MipLevels = 1; @@ -3997,15 +4004,16 @@ namespace plume { D3D12_RAYTRACING_INSTANCE_DESC *instanceDescs = reinterpret_cast(buildInfo.instancesBufferData.data()); for (uint32_t i = 0; i < instanceCount; i++) { const RenderTopLevelASInstance &instance = instances[i]; - const D3D12Buffer *interfaceBottomLevelAS = static_cast(instance.bottomLevelAS.ref); - assert(interfaceBottomLevelAS != nullptr); + const D3D12AccelerationStructure *blasAS = static_cast(instance.bottomLevelAS); + assert(blasAS != nullptr); + assert(blasAS->buffer != nullptr); D3D12_RAYTRACING_INSTANCE_DESC &instanceDesc = instanceDescs[i]; instanceDesc.InstanceID = instance.instanceID; instanceDesc.InstanceMask = instance.instanceMask; instanceDesc.InstanceContributionToHitGroupIndex = instance.instanceContributionToHitGroupIndex; instanceDesc.Flags = instance.cullDisable ? D3D12_RAYTRACING_INSTANCE_FLAG_TRIANGLE_CULL_DISABLE : D3D12_RAYTRACING_INSTANCE_FLAG_NONE; - instanceDesc.AccelerationStructure = interfaceBottomLevelAS->d3d->GetGPUVirtualAddress() + instance.bottomLevelAS.offset; + instanceDesc.AccelerationStructure = blasAS->buffer->d3d->GetGPUVirtualAddress() + blasAS->offset; memcpy(instanceDesc.Transform, instance.transform.m, sizeof(instanceDesc.Transform)); } diff --git a/plume_metal.cpp b/plume_metal.cpp index dc54782..5c9b228 100644 --- a/plume_metal.cpp +++ b/plume_metal.cpp @@ -14,8 +14,20 @@ #include #include +// Metal Shader Converter runtime header for ray tracing support. +// IR_RUNTIME_METALCPP enables metal-cpp compatibility mode. +// IR_PRIVATE_IMPLEMENTATION generates the implementation (define once). +#ifdef PLUME_METAL_RAYTRACING_ENABLED +#define IR_RUNTIME_METALCPP +#define IR_PRIVATE_IMPLEMENTATION +#include +#endif + #include +#include +#include #include +#include #include "plume_metal.h" #include "shaders/plume_clear.metal.h" @@ -549,6 +561,29 @@ namespace plume { } } + // Maps RenderFormat to MTL::AttributeFormat for acceleration structure vertex data. + MTL::AttributeFormat mapAttributeFormat(RenderFormat format) { + switch (format) { + case RenderFormat::R32G32B32_FLOAT: + return MTL::AttributeFormatFloat3; + case RenderFormat::R32G32B32A32_FLOAT: + return MTL::AttributeFormatFloat4; + case RenderFormat::R32G32_FLOAT: + return MTL::AttributeFormatFloat2; + case RenderFormat::R16G16B16A16_FLOAT: + return MTL::AttributeFormatHalf4; + case RenderFormat::R16G16_FLOAT: + return MTL::AttributeFormatHalf2; + case RenderFormat::R16G16B16A16_SNORM: + return MTL::AttributeFormatShort4Normalized; + case RenderFormat::R16G16_SNORM: + return MTL::AttributeFormatShort2Normalized; + default: + assert(false && "Format is not supported as an acceleration structure attribute format."); + return MTL::AttributeFormatFloat3; + } + } + MTL::TextureType mapTextureType(RenderTextureDimension dimension, RenderSampleCounts sampleCount, uint32_t arraySize) { switch (dimension) { case RenderTextureDimension::TEXTURE_1D: @@ -973,6 +1008,7 @@ namespace plume { // Initialize binding vector with -1 (invalid index) bindingToIndex.resize(MAX_BINDING_NUMBER, -1); + bindingDescriptorIndexBase.resize(MAX_BINDING_NUMBER, -1); // Pre-allocate vectors with known size const uint32_t totalDescriptors = desc.descriptorRangesCount + (desc.lastRangeIsBoundless ? desc.boundlessRangeSize : 0); @@ -988,6 +1024,7 @@ namespace plume { descriptorIndexBases.resize(descriptorIndexBases.size() + range.count, indexBase); descriptorBindingIndices.resize(descriptorBindingIndices.size() + range.count, range.binding); + bindingDescriptorIndexBase[range.binding] = static_cast(indexBase); } // Sort ranges by binding due to how spirv-cross orders them @@ -1117,6 +1154,13 @@ namespace plume { this->desc = desc; this->device = device; + // Metal manages acceleration structure storage internally, so buffers with + // ACCELERATION_STRUCTURE flag don't need actual GPU memory allocation. + if (desc.flags & RenderBufferFlag::ACCELERATION_STRUCTURE) { + this->mtl = nullptr; + return; + } + this->mtl = device->mtl->newBuffer(desc.size, mapResourceOption(desc.heapType)); if (desc.flags & RenderBufferFlag::DEVICE_ADDRESSABLE) { @@ -1132,6 +1176,10 @@ namespace plume { } MetalBuffer::~MetalBuffer() { + if (mtl == nullptr) { + return; + } + if (desc.flags & RenderBufferFlag::DEVICE_ADDRESSABLE) { std::lock_guard lock(device->gpuAddressableResourcesMutex); if (device->gpuAddressableResidencySet != nullptr) { @@ -1284,13 +1332,26 @@ namespace plume { MetalAccelerationStructure::MetalAccelerationStructure(MetalDevice *device, const RenderAccelerationStructureDesc &desc) { assert(device != nullptr); - assert(desc.buffer.ref != nullptr); + assert(desc.size > 0); this->device = device; this->type = desc.type; + this->size = desc.size; + + // Metal creates the acceleration structure internally with the specified size. + // Unlike Vulkan/D3D12, Metal does not use a user-provided buffer for AS storage. + mtl = device->mtl->newAccelerationStructure(desc.size); + if (mtl == nullptr) { + fprintf(stderr, "Failed to create Metal acceleration structure.\n"); + } } - MetalAccelerationStructure::~MetalAccelerationStructure() { } + MetalAccelerationStructure::~MetalAccelerationStructure() { + if (mtl != nullptr) { + mtl->release(); + mtl = nullptr; + } + } // MetalPool @@ -1314,6 +1375,30 @@ namespace plume { // MetalShader + // Combined RT metallib header format: + // [4 bytes] Magic: "PLRT" + // [4 bytes] Version: 1 + // [4 bytes] Functions metallib size (little-endian) + // [4 bytes] Dispatch metallib size (little-endian) + // [4 bytes] Root signature JSON size + // [N bytes] Functions metallib data + // [M bytes] Dispatch metallib data + // [K bytes] Root signature JSON + static const char PLRT_MAGIC[4] = {'P', 'L', 'R', 'T'}; + + static uint32_t readUint32LE(const uint8_t* data) { + return static_cast(data[0]) | + (static_cast(data[1]) << 8) | + (static_cast(data[2]) << 16) | + (static_cast(data[3]) << 24); + } + + static bool isCombinedRTMetallib(const void* data, uint64_t size) { + if (size < 16) return false; + const auto* bytes = static_cast(data); + return bytes[0] == 'P' && bytes[1] == 'L' && bytes[2] == 'R' && bytes[3] == 'T'; + } + MetalShader::MetalShader(const MetalDevice *device, const void *data, uint64_t size, const char *entryPointName, const RenderShaderFormat format) { assert(device != nullptr); assert(data != nullptr); @@ -1324,18 +1409,72 @@ namespace plume { this->functionName = (entryPointName != nullptr) ? NS::String::string(entryPointName, NS::UTF8StringEncoding) : MTLSTR(""); NS::Error *error = nullptr; - const dispatch_data_t dispatchData = dispatch_data_create(data, size, dispatch_get_main_queue(), ^{}); - library = device->mtl->newLibrary(dispatchData, &error); - if (error != nullptr) { - fprintf(stderr, "MTLDevice newLibraryWithSource: failed with error %s.\n", error->localizedDescription()->utf8String()); - return; + // Check if this is a combined RT metallib (functions + dispatch kernel) + if (isCombinedRTMetallib(data, size)) { + const auto* bytes = static_cast(data); + uint32_t version = readUint32LE(bytes + 4); + if (version != 1) { + fprintf(stderr, "Unsupported combined RT metallib version: %u\n", version); + return; + } + + const uint32_t functionsSize = readUint32LE(bytes + 8); + const uint32_t dispatchSize = readUint32LE(bytes + 12); + const uint32_t rootSignatureSize = readUint32LE(bytes + 16); + const uint32_t headerSize = 20; + + if (size < headerSize + functionsSize + dispatchSize + rootSignatureSize) { + fprintf(stderr, "Combined RT metallib is truncated.\n"); + return; + } + if (rootSignatureSize == 0) { + fprintf(stderr, "Combined RT metallib missing root signature JSON.\n"); + return; + } + + const void* functionsData = bytes + headerSize; + const void* dispatchData = bytes + headerSize + functionsSize; + const void* rootSignatureData = bytes + headerSize + functionsSize + dispatchSize; + + // Load functions library + dispatch_data_t functionsDispatchData = dispatch_data_create(functionsData, functionsSize, dispatch_get_main_queue(), ^{}); + library = device->mtl->newLibrary(functionsDispatchData, &error); + if (error != nullptr) { + fprintf(stderr, "Failed to load RT functions metallib: %s\n", error->localizedDescription()->utf8String()); + return; + } + + // Load dispatch library + error = nullptr; + dispatch_data_t dispatchDispatchData = dispatch_data_create(dispatchData, dispatchSize, dispatch_get_main_queue(), ^{}); + dispatchLibrary = device->mtl->newLibrary(dispatchDispatchData, &error); + if (error != nullptr) { + fprintf(stderr, "Failed to load RT dispatch metallib: %s\n", error->localizedDescription()->utf8String()); + return; + } + + rtRootSignatureJson.assign(static_cast(rootSignatureData), rootSignatureSize); + } else { + // Regular single metallib + const dispatch_data_t dispatchData = dispatch_data_create(data, size, dispatch_get_main_queue(), ^{}); + library = device->mtl->newLibrary(dispatchData, &error); + + if (error != nullptr) { + fprintf(stderr, "MTLDevice newLibraryWithSource: failed with error %s.\n", error->localizedDescription()->utf8String()); + return; + } } } MetalShader::~MetalShader() { functionName->release(); - library->release(); + if (library) { + library->release(); + } + if (dispatchLibrary) { + dispatchLibrary->release(); + } if (debugName) { debugName->release(); } @@ -1373,6 +1512,483 @@ namespace plume { return function; } + namespace { + enum class JsonTokenType { + End, + LBrace, + RBrace, + LBracket, + RBracket, + Colon, + Comma, + String, + Number + }; + + struct JsonToken { + JsonTokenType type = JsonTokenType::End; + std::string_view text; + }; + + class JsonTokenizer { + public: + explicit JsonTokenizer(std::string_view input) : input(input) {} + + JsonToken peek() { + if (!hasCached) { + cached = nextInternal(); + hasCached = true; + } + return cached; + } + + JsonToken next() { + if (hasCached) { + hasCached = false; + return cached; + } + return nextInternal(); + } + + private: + JsonToken nextInternal() { + skipWhitespace(); + if (pos >= input.size()) { + return {}; + } + + const char c = input[pos]; + switch (c) { + case '{': pos++; return {JsonTokenType::LBrace, {}}; + case '}': pos++; return {JsonTokenType::RBrace, {}}; + case '[': pos++; return {JsonTokenType::LBracket, {}}; + case ']': pos++; return {JsonTokenType::RBracket, {}}; + case ':': pos++; return {JsonTokenType::Colon, {}}; + case ',': pos++; return {JsonTokenType::Comma, {}}; + case '"': return parseString(); + default: + if (c == '-' || std::isdigit(static_cast(c))) { + return parseNumber(); + } + break; + } + + pos++; + return {}; + } + + JsonToken parseString() { + pos++; + const size_t start = pos; + while (pos < input.size()) { + const char c = input[pos]; + if (c == '\\') { + pos += 2; + continue; + } + if (c == '"') { + break; + } + pos++; + } + const size_t end = pos; + if (pos < input.size() && input[pos] == '"') { + pos++; + } + return {JsonTokenType::String, input.substr(start, end - start)}; + } + + JsonToken parseNumber() { + const size_t start = pos; + if (input[pos] == '-') { + pos++; + } + while (pos < input.size() && std::isdigit(static_cast(input[pos]))) { + pos++; + } + return {JsonTokenType::Number, input.substr(start, pos - start)}; + } + + void skipWhitespace() { + while (pos < input.size() && std::isspace(static_cast(input[pos]))) { + pos++; + } + } + + std::string_view input; + size_t pos = 0; + JsonToken cached{}; + bool hasCached = false; + }; + + bool skipValue(JsonTokenizer &tokenizer); + + bool expect(JsonTokenizer &tokenizer, JsonTokenType type) { + return tokenizer.next().type == type; + } + + bool skipObject(JsonTokenizer &tokenizer) { + while (true) { + JsonToken token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + return true; + } + if (token.type != JsonTokenType::String) { + return false; + } + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + if (!skipValue(tokenizer)) { + return false; + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool skipArray(JsonTokenizer &tokenizer) { + while (true) { + JsonToken token = tokenizer.peek(); + if (token.type == JsonTokenType::RBracket) { + tokenizer.next(); + return true; + } + if (!skipValue(tokenizer)) { + return false; + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool skipValue(JsonTokenizer &tokenizer) { + JsonToken token = tokenizer.next(); + switch (token.type) { + case JsonTokenType::LBrace: + return skipObject(tokenizer); + case JsonTokenType::LBracket: + return skipArray(tokenizer); + case JsonTokenType::String: + case JsonTokenType::Number: + return true; + default: + return false; + } + } + + MetalRTRootParameterType parseRootParameterType(std::string_view value) { + if (value == "IRRootParameterTypeDescriptorTable") { + return MetalRTRootParameterType::DescriptorTable; + } + if (value == "IRRootParameterTypeCBV") { + return MetalRTRootParameterType::RootCBV; + } + if (value == "IRRootParameterTypeSRV") { + return MetalRTRootParameterType::RootSRV; + } + if (value == "IRRootParameterTypeUAV") { + return MetalRTRootParameterType::RootUAV; + } + return MetalRTRootParameterType::Unknown; + } + + MetalRTDescriptorRangeType parseDescriptorRangeType(std::string_view value) { + if (value == "IRDescriptorRangeTypeSRV") { + return MetalRTDescriptorRangeType::SRV; + } + if (value == "IRDescriptorRangeTypeUAV") { + return MetalRTDescriptorRangeType::UAV; + } + if (value == "IRDescriptorRangeTypeSampler") { + return MetalRTDescriptorRangeType::Sampler; + } + return MetalRTDescriptorRangeType::Unknown; + } + + uint32_t parseUint32(std::string_view text) { + uint32_t value = 0; + for (char c : text) { + if (c < '0' || c > '9') { + break; + } + value = value * 10 + static_cast(c - '0'); + } + return value; + } + + bool parseDescriptorObject(JsonTokenizer &tokenizer, MetalRTRootParameter ¶m) { + if (!expect(tokenizer, JsonTokenType::LBrace)) { + return false; + } + while (true) { + JsonToken token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + return true; + } + if (token.type != JsonTokenType::String) { + return false; + } + std::string_view key = token.text; + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + JsonToken value = tokenizer.next(); + if (key == "ShaderRegister" && value.type == JsonTokenType::Number) { + param.shaderRegister = parseUint32(value.text); + } else if (key == "RegisterSpace" && value.type == JsonTokenType::Number) { + param.registerSpace = parseUint32(value.text); + } else { + if (value.type == JsonTokenType::LBrace) { + if (!skipObject(tokenizer)) return false; + } else if (value.type == JsonTokenType::LBracket) { + if (!skipArray(tokenizer)) return false; + } + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseDescriptorRangeObject(JsonTokenizer &tokenizer, MetalRTDescriptorRange &range) { + if (!expect(tokenizer, JsonTokenType::LBrace)) { + return false; + } + while (true) { + JsonToken token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + return true; + } + if (token.type != JsonTokenType::String) { + return false; + } + std::string_view key = token.text; + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + JsonToken value = tokenizer.next(); + if (key == "RangeType" && value.type == JsonTokenType::String) { + range.type = parseDescriptorRangeType(value.text); + } else if (key == "NumDescriptors" && value.type == JsonTokenType::Number) { + range.numDescriptors = parseUint32(value.text); + } else if (key == "BaseShaderRegister" && value.type == JsonTokenType::Number) { + range.baseRegister = parseUint32(value.text); + } else if (key == "RegisterSpace" && value.type == JsonTokenType::Number) { + range.registerSpace = parseUint32(value.text); + } else if (key == "OffsetInDescriptorsFromTableStart" && value.type == JsonTokenType::Number) { + range.offset = parseUint32(value.text); + } else { + if (value.type == JsonTokenType::LBrace) { + if (!skipObject(tokenizer)) return false; + } else if (value.type == JsonTokenType::LBracket) { + if (!skipArray(tokenizer)) return false; + } + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseDescriptorRanges(JsonTokenizer &tokenizer, std::vector &ranges) { + if (!expect(tokenizer, JsonTokenType::LBracket)) { + return false; + } + while (true) { + JsonToken token = tokenizer.peek(); + if (token.type == JsonTokenType::RBracket) { + tokenizer.next(); + return true; + } + MetalRTDescriptorRange range; + if (!parseDescriptorRangeObject(tokenizer, range)) { + return false; + } + ranges.push_back(range); + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseDescriptorTableObject(JsonTokenizer &tokenizer, MetalRTRootParameter ¶m) { + if (!expect(tokenizer, JsonTokenType::LBrace)) { + return false; + } + while (true) { + JsonToken token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + return true; + } + if (token.type != JsonTokenType::String) { + return false; + } + std::string_view key = token.text; + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + if (key == "DescriptorRanges") { + if (!parseDescriptorRanges(tokenizer, param.ranges)) { + return false; + } + } else { + if (!skipValue(tokenizer)) { + return false; + } + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseRootParameterObject(JsonTokenizer &tokenizer, MetalRTRootParameter ¶m) { + if (!expect(tokenizer, JsonTokenType::LBrace)) { + return false; + } + while (true) { + JsonToken token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + return true; + } + if (token.type != JsonTokenType::String) { + return false; + } + std::string_view key = token.text; + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + if (key == "ParameterType") { + JsonToken value = tokenizer.next(); + if (value.type == JsonTokenType::String) { + param.type = parseRootParameterType(value.text); + } else { + return false; + } + } else if (key == "DescriptorTable") { + if (!parseDescriptorTableObject(tokenizer, param)) { + return false; + } + } else if (key == "Descriptor") { + if (!parseDescriptorObject(tokenizer, param)) { + return false; + } + } else { + if (!skipValue(tokenizer)) { + return false; + } + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseParametersArray(JsonTokenizer &tokenizer, std::vector ¶ms) { + if (!expect(tokenizer, JsonTokenType::LBracket)) { + return false; + } + while (true) { + JsonToken token = tokenizer.peek(); + if (token.type == JsonTokenType::RBracket) { + tokenizer.next(); + return true; + } + MetalRTRootParameter param; + if (!parseRootParameterObject(tokenizer, param)) { + return false; + } + if (param.type != MetalRTRootParameterType::Unknown) { + params.push_back(std::move(param)); + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseRootSignatureObject(JsonTokenizer &tokenizer, std::vector ¶ms) { + if (!expect(tokenizer, JsonTokenType::LBrace)) { + return false; + } + while (true) { + JsonToken token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + return true; + } + if (token.type != JsonTokenType::String) { + return false; + } + std::string_view key = token.text; + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + if (key == "Parameters") { + if (!parseParametersArray(tokenizer, params)) { + return false; + } + } else { + if (!skipValue(tokenizer)) { + return false; + } + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + } + + bool parseRootSignatureJson(const std::string &json, std::vector ¶ms) { + if (json.empty()) { + return false; + } + JsonTokenizer tokenizer(json); + JsonToken token = tokenizer.next(); + if (token.type != JsonTokenType::LBrace) { + return false; + } + while (true) { + token = tokenizer.next(); + if (token.type == JsonTokenType::RBrace) { + break; + } + if (token.type != JsonTokenType::String) { + return false; + } + std::string_view key = token.text; + if (!expect(tokenizer, JsonTokenType::Colon)) { + return false; + } + if (key == "RootSignature") { + if (!parseRootSignatureObject(tokenizer, params)) { + return false; + } + } else { + if (!skipValue(tokenizer)) { + return false; + } + } + token = tokenizer.peek(); + if (token.type == JsonTokenType::Comma) { + tokenizer.next(); + } + } + return !params.empty(); + } + } + // MetalSampler MetalSampler::MetalSampler(const MetalDevice *device, const RenderSamplerDesc &desc) { @@ -1615,8 +2231,203 @@ namespace plume { // TODO: New - setting name happens at descriptor level - this would have to be reworked } - RenderPipelineProgram MetalGraphicsPipeline::getProgram(const std::string &name) const { - assert(false && "Graphics pipelines can't retrieve shader programs."); + RenderPipelineProgram MetalGraphicsPipeline::getProgram(const std::string &name) const { + assert(false && "Graphics pipelines can't retrieve shader programs."); + return RenderPipelineProgram(); + } + + // MetalRaytracingPipeline + // + // Metal RT pipeline creation with Metal Shader Converter output: + // + // Metal Shader Converter produces two metallibs: + // 1. Visible functions library: Contains RayGen, ClosestHit, Miss, etc. as visible functions + // 2. Dispatch kernel library: Contains RaygenIndirection compute kernel + // + // We need to: + // 1. Find the RaygenIndirection kernel function + // 2. Find all visible functions (RayGen, ClosestHit, Miss) + // 3. Link them together via MTL::LinkedFunctions + // 4. Create the compute pipeline + // 5. Create VFT and populate it with function handles from the pipeline + + MetalRaytracingPipeline::MetalRaytracingPipeline(MetalDevice *device, const RenderRaytracingPipelineDesc &desc, const RenderPipeline *previousPipeline) : MetalPipeline(device, Type::Raytracing) { + assert(device != nullptr); + assert(desc.pipelineLayout != nullptr); + assert(desc.librariesCount > 0); + + this->device = device; + this->pipelineLayout = static_cast(desc.pipelineLayout); + this->maxPayloadSize = desc.maxPayloadSize; + this->maxAttributeSize = desc.maxAttributeSize; + + NS::AutoreleasePool *releasePool = NS::AutoreleasePool::alloc()->init(); + + MTL::Function *raygenIndirectionFunction = nullptr; + std::vector visibleFunctions; + std::vector visibleFunctionNames; + std::string rootSignatureJson; + bool rootSignatureMismatch = false; + + // Scan all libraries to find the dispatch kernel and visible functions. + for (uint32_t i = 0; i < desc.librariesCount; i++) { + const RenderRaytracingPipelineLibrary &library = desc.libraries[i]; + assert(library.shader != nullptr); + + const MetalShader *shader = static_cast(library.shader); + if (!shader->rtRootSignatureJson.empty()) { + if (rootSignatureJson.empty()) { + rootSignatureJson = shader->rtRootSignatureJson; + } else if (rootSignatureJson != shader->rtRootSignatureJson) { + rootSignatureMismatch = true; + } + } + + // Get the RaygenIndirection kernel from the dispatch library. + if (raygenIndirectionFunction == nullptr && shader->dispatchLibrary != nullptr) { + NS::String *indirectionName = NS::String::string("RaygenIndirection", NS::UTF8StringEncoding); + raygenIndirectionFunction = shader->dispatchLibrary->newFunction(indirectionName); + } + + // Find visible functions from symbols. + for (uint32_t j = 0; j < library.symbolsCount; j++) { + const RenderRaytracingPipelineLibrarySymbol &symbol = library.symbols[j]; + const char *functionName = symbol.importName; + + NS::String *funcName = NS::String::string(functionName, NS::UTF8StringEncoding); + MTL::Function *func = shader->library->newFunction(funcName); + + if (func != nullptr) { + // Assign VFT index starting at 1 (index 0 is reserved/null function). + uint32_t vftIndex = static_cast(visibleFunctions.size()) + 1; + visibleFunctions.push_back(func); + visibleFunctionNames.push_back(functionName); + + // Map both import and export names to the VFT index. + nameProgramMap[std::string(functionName)] = vftIndex; + if (symbol.exportName != nullptr && symbol.exportName != symbol.importName) { + nameProgramMap[std::string(symbol.exportName)] = vftIndex; + } + } else { + fprintf(stderr, "Warning: Could not find visible function '%s' in shader library.\n", functionName); + } + } + } + + // Map hit groups to their closest hit shader's VFT index. + for (uint32_t i = 0; i < desc.hitGroupsCount; i++) { + const RenderRaytracingPipelineHitGroup &hitGroup = desc.hitGroups[i]; + if (hitGroup.closestHitName != nullptr) { + auto it = nameProgramMap.find(std::string(hitGroup.closestHitName)); + if (it != nameProgramMap.end()) { + nameProgramMap[std::string(hitGroup.hitGroupName)] = it->second; + } + } + } + + if (rootSignatureMismatch) { + fprintf(stderr, "Warning: Raytracing pipeline libraries have mismatched root signatures. Using the first one.\n"); + } + if (!rootSignatureJson.empty()) { + if (!parseRootSignatureJson(rootSignatureJson, rootSignatureParameters)) { + fprintf(stderr, "Warning: Failed to parse Metal root signature JSON for raytracing pipeline.\n"); + rootSignatureParameters.clear(); + } + } + + if (raygenIndirectionFunction == nullptr) { + fprintf(stderr, "Failed to find RaygenIndirection function. Make sure the shader was compiled with Metal Shader Converter using --synthesize-indirect-ray-dispatch.\n"); + for (auto func : visibleFunctions) { func->release(); } + releasePool->release(); + return; + } + + if (visibleFunctions.empty()) { + fprintf(stderr, "Warning: No visible functions found. Make sure visible functions library is provided.\n"); + } + + // Create linked functions descriptor with all visible functions. + MTL::LinkedFunctions *linkedFunctions = MTL::LinkedFunctions::alloc()->init(); + if (!visibleFunctions.empty()) { + NS::Array *functionsArray = NS::Array::array((NS::Object **)visibleFunctions.data(), visibleFunctions.size()); + linkedFunctions->setFunctions(functionsArray); + } + + // Create the compute pipeline descriptor. + MTL::ComputePipelineDescriptor *pipelineDesc = MTL::ComputePipelineDescriptor::alloc()->init(); + pipelineDesc->setComputeFunction(raygenIndirectionFunction); + pipelineDesc->setLinkedFunctions(linkedFunctions); + pipelineDesc->setMaxCallStackDepth(desc.maxRecursionDepth + 1); + + // Create the compute pipeline state. + NS::Error *error = nullptr; + computePipeline = device->mtl->newComputePipelineState(pipelineDesc, MTL::PipelineOptionNone, nullptr, &error); + + if (error != nullptr || computePipeline == nullptr) { + fprintf(stderr, "Failed to create raytracing compute pipeline: %s\n", + error ? error->localizedDescription()->utf8String() : "unknown error"); + for (auto func : visibleFunctions) { func->release(); } + raygenIndirectionFunction->release(); + linkedFunctions->release(); + pipelineDesc->release(); + releasePool->release(); + return; + } + + // Create and populate the Visible Function Table. + // VFT indices: 0 = null/reserved, 1+ = visible functions + if (!visibleFunctions.empty()) { + MTL::VisibleFunctionTableDescriptor *vftDesc = MTL::VisibleFunctionTableDescriptor::alloc()->init(); + vftDesc->setFunctionCount(visibleFunctions.size() + 1); // +1 for null at index 0 + + visibleFunctionTable = computePipeline->newVisibleFunctionTable(vftDesc); + if (visibleFunctionTable == nullptr) { + fprintf(stderr, "Failed to create visible function table.\n"); + } else { + // Populate VFT with function handles from the pipeline. + for (size_t i = 0; i < visibleFunctions.size(); i++) { + MTL::FunctionHandle *handle = computePipeline->functionHandle(visibleFunctions[i]); + if (handle != nullptr) { + // VFT index is i+1 (index 0 is null). + visibleFunctionTable->setFunction(handle, i + 1); + } else { + fprintf(stderr, "Warning: Could not get function handle for '%s'.\n", visibleFunctionNames[i].c_str()); + } + } + } + vftDesc->release(); + } + + // Release resources. + for (auto func : visibleFunctions) { func->release(); } + raygenIndirectionFunction->release(); + linkedFunctions->release(); + pipelineDesc->release(); + releasePool->release(); + } + + MetalRaytracingPipeline::~MetalRaytracingPipeline() { + if (intersectionFunctionTable != nullptr) { + intersectionFunctionTable->release(); + } + if (visibleFunctionTable != nullptr) { + visibleFunctionTable->release(); + } + if (computePipeline != nullptr) { + computePipeline->release(); + } + } + + void MetalRaytracingPipeline::setName(const std::string &name) { + // ComputePipelineState doesn't support setLabel after creation. + // The label must be set on the descriptor before pipeline creation. + } + + RenderPipelineProgram MetalRaytracingPipeline::getProgram(const std::string &name) const { + auto it = nameProgramMap.find(name); + if (it != nameProgramMap.end()) { + return RenderPipelineProgram(it->second); + } return RenderPipelineProgram(); } @@ -1687,11 +2498,15 @@ namespace plume { if (entry.resource != nullptr) { entry.resource->release(); } + if (entry.sampler != nullptr) { + entry.sampler->release(); + } } if (argumentBuffer.mtl != nullptr) { argumentBuffer.mtl->release(); } + } void MetalDescriptorSet::bindImmutableSamplers() const { @@ -1746,7 +2561,12 @@ namespace plume { offset = bufferStructuredView->firstElement * bufferStructuredView->structureByteStride; } - const BufferDescriptor descriptor = { .buffer = interfaceBuffer->mtl, .offset = offset }; + uint64_t bufferViewSize = bufferSize; + if (bufferViewSize == 0) { + bufferViewSize = interfaceBuffer->desc.size; + } + + const BufferDescriptor descriptor = { .buffer = interfaceBuffer->mtl, .offset = offset, .size = bufferViewSize }; setDescriptor(descriptorIndex, &descriptor); } } @@ -1769,6 +2589,9 @@ namespace plume { const TextureDescriptor descriptor = { .texture = interfaceTexture->mtl }; setDescriptor(descriptorIndex, &descriptor); } + + // Mark RT buffers dirty in case this is a UAV texture used for raytracing output. + rtBuffersDirty = true; } void MetalDescriptorSet::setSampler(const uint32_t descriptorIndex, const RenderSampler *sampler) { @@ -1783,7 +2606,63 @@ namespace plume { } void MetalDescriptorSet::setAccelerationStructure(uint32_t descriptorIndex, const RenderAccelerationStructure *accelerationStructure) { - // TODO: Unimplemented. + if (accelerationStructure == nullptr) { + // Clear the entry + if (descriptorIndex < resourceEntries.size()) { + if (resourceEntries[descriptorIndex].resource != nullptr) { + resourceEntries[descriptorIndex].resource->release(); + resourceEntries[descriptorIndex].resource = nullptr; + } + if (resourceEntries[descriptorIndex].sampler != nullptr) { + resourceEntries[descriptorIndex].sampler->release(); + resourceEntries[descriptorIndex].sampler = nullptr; + } + resourceEntries[descriptorIndex].type = RenderDescriptorRangeType::UNKNOWN; + resourceEntries[descriptorIndex].instanceCount = 0; + resourceEntries[descriptorIndex].instanceContributions.clear(); + resourceEntries[descriptorIndex].bufferOffset = 0; + resourceEntries[descriptorIndex].bufferSize = 0; + } + return; + } + + const MetalAccelerationStructure *metalAS = static_cast(accelerationStructure); + + // Store the acceleration structure in resourceEntries for later use during traceRays. + if (descriptorIndex < resourceEntries.size()) { + // Release old resource if any. + if (resourceEntries[descriptorIndex].resource != nullptr) { + if (residencySet != nullptr) { + std::lock_guard lock(residencySetWriteMutex); + residencySet->removeAllocation(resourceEntries[descriptorIndex].resource); + needsCommit = true; + } + resourceEntries[descriptorIndex].resource->release(); + } + if (resourceEntries[descriptorIndex].sampler != nullptr) { + resourceEntries[descriptorIndex].sampler->release(); + resourceEntries[descriptorIndex].sampler = nullptr; + } + + // Store as MTL::Resource (MTL::AccelerationStructure inherits from MTL::Resource). + resourceEntries[descriptorIndex].resource = metalAS->mtl; + resourceEntries[descriptorIndex].type = RenderDescriptorRangeType::ACCELERATION_STRUCTURE; + resourceEntries[descriptorIndex].instanceCount = metalAS->instanceCount; + resourceEntries[descriptorIndex].instanceContributions = metalAS->instanceContributions; + resourceEntries[descriptorIndex].bufferOffset = 0; + resourceEntries[descriptorIndex].bufferSize = 0; + metalAS->mtl->retain(); + + // Add to residency set if available. + if (residencySet != nullptr) { + std::lock_guard lock(residencySetWriteMutex); + residencySet->addAllocation(metalAS->mtl); + needsCommit = true; + } + + // Mark RT buffers dirty so TLAB is rebuilt on next traceRays. + rtBuffersDirty = true; + } } void MetalDescriptorSet::setDescriptor(const uint32_t descriptorIndex, const Descriptor *descriptor) { @@ -1794,19 +2673,27 @@ namespace plume { const auto &setLayoutBinding = setLayout->setBindings[indexBase]; const MTL::DataType dtype = mapDataType(setLayoutBinding.descriptorType); MTL::Resource *nativeResource = nullptr; + MTL::SamplerState *nativeSampler = nullptr; RenderDescriptorRangeType descriptorType = getDescriptorType(bindingIndex); - if (dtype != MTL::DataTypeSampler) { - if (resourceEntries[descriptorIndex].resource != nullptr) { - if (residencySet != nullptr) { - std::lock_guard lock(residencySetWriteMutex); - residencySet->removeAllocation(resourceEntries[descriptorIndex].resource); - needsCommit = true; - } - resourceEntries[descriptorIndex].resource->release(); - resourceEntries[descriptorIndex].resource = nullptr; + auto &entry = resourceEntries[descriptorIndex]; + if (entry.resource != nullptr) { + if (residencySet != nullptr) { + std::lock_guard lock(residencySetWriteMutex); + residencySet->removeAllocation(entry.resource); + needsCommit = true; } + entry.resource->release(); + entry.resource = nullptr; + } + if (entry.sampler != nullptr) { + entry.sampler->release(); + entry.sampler = nullptr; } + entry.bufferOffset = 0; + entry.bufferSize = 0; + entry.instanceCount = 0; + entry.instanceContributions.clear(); if (descriptor != nullptr) { const uint32_t argumentIndex = descriptorIndex - indexBase + bindingIndex; @@ -1839,6 +2726,8 @@ namespace plume { case MTL::DataTypePointer: { const BufferDescriptor *bufferDescriptor = static_cast(descriptor); nativeResource = bufferDescriptor->buffer; + entry.bufferOffset = bufferDescriptor->offset; + entry.bufferSize = bufferDescriptor->size; MTL::Buffer *nativeBuffer = static_cast(nativeResource); if (residencySet != nullptr) { std::lock_guard lock(residencySetWriteMutex); @@ -1856,11 +2745,13 @@ namespace plume { } case MTL::DataTypeSampler: { const SamplerDescriptor *samplerDescriptor = static_cast(descriptor); + nativeSampler = samplerDescriptor->state; if (device->useArgumentBuffersTier2) { *reinterpret_cast(bufferPtr + argumentOffset) = samplerDescriptor->state->gpuResourceID(); } else { argumentBuffer.argumentEncoder->setSamplerState(samplerDescriptor->state, argumentIndex); } + samplerDescriptor->state->retain(); break; } @@ -1873,8 +2764,10 @@ namespace plume { argumentBuffer.mtl->didModifyRange(NS::Range(argumentBuffer.offset, argumentBuffer.mtl->length() - argumentBuffer.offset)); } - resourceEntries[descriptorIndex].resource = nativeResource; - resourceEntries[descriptorIndex].type = descriptorType; + entry.resource = nativeResource; + entry.sampler = nativeSampler; + entry.type = descriptorType; + rtBuffersDirty = true; } RenderDescriptorRangeType MetalDescriptorSet::getDescriptorType(const uint32_t binding) const { @@ -2241,6 +3134,20 @@ namespace plume { MetalCommandList::~MetalCommandList() { mtl->release(); + for (auto* buffer : rtDescriptorTableBuffers) { + if (buffer != nullptr) { + buffer->release(); + } + } + for (auto* buffer : rtASHeaderBuffers) { + if (buffer != nullptr) { + buffer->release(); + } + } + if (rtTLABBuffer != nullptr) { + rtTLABBuffer->release(); + } + for (auto& fenceSet : fences) { for (auto* fence : fenceSet) { fence->release(); @@ -2447,7 +3354,353 @@ namespace plume { } void MetalCommandList::traceRays(uint32_t width, uint32_t height, uint32_t depth, RenderBufferReference shaderBindingTable, const RenderShaderBindingGroupsInfo &shaderBindingGroupsInfo) { - // TODO: Support Metal RT +#ifdef PLUME_METAL_RAYTRACING_ENABLED + assert(activeRaytracingPipeline != nullptr && "Must set raytracing pipeline before traceRays"); + assert(activeRaytracingPipelineLayout != nullptr && "Must set raytracing pipeline layout before traceRays"); + + const MetalBuffer *sbtBuffer = static_cast(shaderBindingTable.ref); + assert(sbtBuffer != nullptr && "Shader binding table buffer is null"); + + // End other encoders and start compute encoder for raytracing. + endOtherEncoders(EncoderType::Compute); + activeType = EncoderType::Compute; + + if (activeComputeEncoder == nullptr) { + NS::AutoreleasePool *releasePool = NS::AutoreleasePool::alloc()->init(); + activeComputeEncoder = mtl->computeCommandEncoder(MTL::DispatchTypeConcurrent); + activeComputeEncoder->setLabel(MTLSTR("Raytracing Encoder")); + activeComputeEncoder->retain(); + releasePool->release(); + startedEncoding = true; + barrierWait(MetalBarrierStage::COMPUTE, activeComputeEncoder); + } + + // Set the raytracing compute pipeline. + activeComputeEncoder->setComputePipelineState(activeRaytracingPipeline->computePipeline); + + const auto &rootParams = activeRaytracingPipeline->rootSignatureParameters; + if (rootParams.empty()) { + fprintf(stderr, "Ray tracing root signature data missing. Rebuild shaders with updated Metal RT tools.\n"); + return; + } + + MTL::Device* device = queue->device->mtl; + + if (rtDescriptorTableBuffers.size() > rootParams.size()) { + for (size_t i = rootParams.size(); i < rtDescriptorTableBuffers.size(); i++) { + if (rtDescriptorTableBuffers[i] != nullptr) { + rtDescriptorTableBuffers[i]->release(); + } + } + rtDescriptorTableBuffers.resize(rootParams.size(), nullptr); + rtDescriptorTableBufferSizes.resize(rootParams.size(), 0); + } else if (rtDescriptorTableBuffers.size() < rootParams.size()) { + rtDescriptorTableBuffers.resize(rootParams.size(), nullptr); + rtDescriptorTableBufferSizes.resize(rootParams.size(), 0); + } + + if (rtASHeaderBuffers.size() > rootParams.size()) { + for (size_t i = rootParams.size(); i < rtASHeaderBuffers.size(); i++) { + if (rtASHeaderBuffers[i] != nullptr) { + rtASHeaderBuffers[i]->release(); + } + } + rtASHeaderBuffers.resize(rootParams.size(), nullptr); + rtASHeaderBufferSizes.resize(rootParams.size(), 0); + } else if (rtASHeaderBuffers.size() < rootParams.size()) { + rtASHeaderBuffers.resize(rootParams.size(), nullptr); + rtASHeaderBufferSizes.resize(rootParams.size(), 0); + } + + std::vector tlabAddresses(rootParams.size(), 0); + + auto setDescriptorTableEntry = [&](IRDescriptorTableEntry &tableEntry, const MetalDescriptorSet::ResourceEntry &entry, const MetalDescriptorSetLayout::DescriptorSetLayoutBinding *binding, uint32_t elementIndex) { + if (entry.type == RenderDescriptorRangeType::SAMPLER) { + MTL::SamplerState *sampler = entry.sampler; + if (sampler == nullptr && binding != nullptr && elementIndex < binding->immutableSamplers.size()) { + sampler = binding->immutableSamplers[elementIndex]; + } + if (sampler != nullptr) { + IRDescriptorTableSetSampler(&tableEntry, sampler, 0.0f); + } + return; + } + + if (entry.resource == nullptr) { + return; + } + + switch (entry.type) { + case RenderDescriptorRangeType::TEXTURE: + case RenderDescriptorRangeType::READ_WRITE_TEXTURE: + case RenderDescriptorRangeType::FORMATTED_BUFFER: + case RenderDescriptorRangeType::READ_WRITE_FORMATTED_BUFFER: { + auto *texture = static_cast(entry.resource); + IRDescriptorTableSetTexture(&tableEntry, texture, 0, 0); + break; + } + case RenderDescriptorRangeType::STRUCTURED_BUFFER: + case RenderDescriptorRangeType::READ_WRITE_STRUCTURED_BUFFER: + case RenderDescriptorRangeType::BYTE_ADDRESS_BUFFER: + case RenderDescriptorRangeType::READ_WRITE_BYTE_ADDRESS_BUFFER: + case RenderDescriptorRangeType::CONSTANT_BUFFER: { + auto *buffer = static_cast(entry.resource); + IRBufferView bufferView = {}; + bufferView.buffer = buffer; + bufferView.bufferOffset = entry.bufferOffset; + bufferView.bufferSize = entry.bufferSize; + bufferView.textureBufferView = nullptr; + bufferView.textureViewOffsetInElements = 0; + bufferView.typedBuffer = false; + + if (bufferView.bufferSize == 0) { + bufferView.bufferSize = buffer->length(); + } + + IRDescriptorTableSetBufferView(&tableEntry, &bufferView); + break; + } + default: + break; + } + }; + + auto useResourceEntry = [&](const MetalDescriptorSet::ResourceEntry &entry) { + if (entry.resource != nullptr) { + activeComputeEncoder->useResource(entry.resource, mapResourceUsage(entry.type)); + } + }; + + for (size_t paramIndex = 0; paramIndex < rootParams.size(); paramIndex++) { + const MetalRTRootParameter ¶m = rootParams[paramIndex]; + switch (param.type) { + case MetalRTRootParameterType::DescriptorTable: { + uint32_t totalDescriptors = 0; + for (const auto &range : param.ranges) { + totalDescriptors = std::max(totalDescriptors, range.offset + range.numDescriptors); + } + if (totalDescriptors == 0) { + break; + } + + const size_t tableSize = totalDescriptors * sizeof(IRDescriptorTableEntry); + if (rtDescriptorTableBuffers[paramIndex] == nullptr || rtDescriptorTableBufferSizes[paramIndex] < tableSize) { + if (rtDescriptorTableBuffers[paramIndex] != nullptr) { + rtDescriptorTableBuffers[paramIndex]->release(); + } + rtDescriptorTableBuffers[paramIndex] = device->newBuffer(tableSize, MTL::ResourceStorageModeShared); + rtDescriptorTableBufferSizes[paramIndex] = tableSize; + } + + auto *entries = static_cast(rtDescriptorTableBuffers[paramIndex]->contents()); + memset(entries, 0, tableSize); + + for (const auto &range : param.ranges) { + if (range.numDescriptors == 0) { + continue; + } + if (range.registerSpace >= MAX_DESCRIPTOR_SET_BINDINGS) { + continue; + } + if (range.baseRegister >= MAX_BINDING_NUMBER) { + continue; + } + + MetalDescriptorSet *descriptorSet = raytracingDescriptorSets[range.registerSpace]; + if (descriptorSet == nullptr) { + continue; + } + + const auto *binding = descriptorSet->setLayout->getBinding(range.baseRegister); + const int32_t baseIndex = descriptorSet->setLayout->bindingDescriptorIndexBase[range.baseRegister]; + if (binding == nullptr || baseIndex < 0) { + continue; + } + + const uint32_t rangeCount = std::min(range.numDescriptors, binding->descriptorCount); + for (uint32_t i = 0; i < rangeCount; i++) { + const uint32_t descriptorIndex = static_cast(baseIndex) + i; + if (descriptorIndex >= descriptorSet->resourceEntries.size()) { + continue; + } + auto &entry = descriptorSet->resourceEntries[descriptorIndex]; + setDescriptorTableEntry(entries[range.offset + i], entry, binding, i); + useResourceEntry(entry); + } + } + + tlabAddresses[paramIndex] = rtDescriptorTableBuffers[paramIndex]->gpuAddress(); + activeComputeEncoder->useResource(rtDescriptorTableBuffers[paramIndex], MTL::ResourceUsageRead); + break; + } + case MetalRTRootParameterType::RootCBV: + case MetalRTRootParameterType::RootSRV: + case MetalRTRootParameterType::RootUAV: { + if (param.registerSpace >= MAX_DESCRIPTOR_SET_BINDINGS) { + break; + } + if (param.shaderRegister >= MAX_BINDING_NUMBER) { + break; + } + MetalDescriptorSet *descriptorSet = raytracingDescriptorSets[param.registerSpace]; + if (descriptorSet == nullptr) { + break; + } + + const int32_t baseIndex = descriptorSet->setLayout->bindingDescriptorIndexBase[param.shaderRegister]; + if (baseIndex < 0 || static_cast(baseIndex) >= descriptorSet->resourceEntries.size()) { + break; + } + const auto &entry = descriptorSet->resourceEntries[baseIndex]; + if (entry.type == RenderDescriptorRangeType::ACCELERATION_STRUCTURE) { + auto *tlas = static_cast(entry.resource); + if (tlas == nullptr) { + break; + } + + uint32_t instanceCount = entry.instanceCount; + if (instanceCount == 0) { + instanceCount = 1; + } + + const size_t contributionsSize = instanceCount * sizeof(uint32_t); + const size_t asSize = sizeof(IRRaytracingAccelerationStructureGPUHeader) + contributionsSize; + if (rtASHeaderBuffers[paramIndex] == nullptr || rtASHeaderBufferSizes[paramIndex] < asSize) { + if (rtASHeaderBuffers[paramIndex] != nullptr) { + rtASHeaderBuffers[paramIndex]->release(); + } + rtASHeaderBuffers[paramIndex] = device->newBuffer(asSize, MTL::ResourceStorageModeShared); + rtASHeaderBufferSizes[paramIndex] = asSize; + } + + auto *header = static_cast(rtASHeaderBuffers[paramIndex]->contents()); + header->accelerationStructureID = tlas->gpuResourceID()._impl; + header->addressOfInstanceContributions = rtASHeaderBuffers[paramIndex]->gpuAddress() + sizeof(IRRaytracingAccelerationStructureGPUHeader); + + auto *instanceContributions = reinterpret_cast( + static_cast(rtASHeaderBuffers[paramIndex]->contents()) + sizeof(IRRaytracingAccelerationStructureGPUHeader)); + if (!entry.instanceContributions.empty()) { + const uint32_t copyCount = std::min(instanceCount, static_cast(entry.instanceContributions.size())); + memcpy(instanceContributions, entry.instanceContributions.data(), copyCount * sizeof(uint32_t)); + if (copyCount < instanceCount) { + memset(instanceContributions + copyCount, 0, (instanceCount - copyCount) * sizeof(uint32_t)); + } + } else { + memset(instanceContributions, 0, instanceCount * sizeof(uint32_t)); + } + + tlabAddresses[paramIndex] = rtASHeaderBuffers[paramIndex]->gpuAddress(); + activeComputeEncoder->useResource(tlas, MTL::ResourceUsageRead); + activeComputeEncoder->useResource(rtASHeaderBuffers[paramIndex], MTL::ResourceUsageRead); + } else { + auto *buffer = static_cast(entry.resource); + if (buffer != nullptr) { + tlabAddresses[paramIndex] = buffer->gpuAddress() + entry.bufferOffset; + useResourceEntry(entry); + } + } + break; + } + default: + break; + } + } + + const size_t tlabSize = tlabAddresses.size() * sizeof(uint64_t); + if (rtTLABBuffer == nullptr || rtTLABBufferSize < tlabSize) { + if (rtTLABBuffer != nullptr) { + rtTLABBuffer->release(); + } + rtTLABBuffer = device->newBuffer(tlabSize, MTL::ResourceStorageModeShared); + rtTLABBufferSize = tlabSize; + } + + memcpy(rtTLABBuffer->contents(), tlabAddresses.data(), tlabSize); + activeComputeEncoder->useResource(rtTLABBuffer, MTL::ResourceUsageRead); + activeComputeEncoder->useResource(sbtBuffer->mtl, MTL::ResourceUsageRead); + + // Bind push constants for raytracing. + for (const PushConstantData &pushConstant : pushConstants) { + if (pushConstant.stageFlags & RenderShaderStageFlag::RAYGEN) { + const uint32_t bindIndex = PUSH_CONSTANTS_BINDING_INDEX + pushConstant.binding; + activeComputeEncoder->setBytes(pushConstant.data.data(), pushConstant.size, bindIndex); + } + } + + // Build the IRDispatchRaysDescriptor from the shader binding groups info. + const uint64_t sbtBaseAddress = sbtBuffer->mtl->gpuAddress() + shaderBindingTable.offset; + + const RenderShaderBindingGroupInfo &rayGen = shaderBindingGroupsInfo.rayGen; + const RenderShaderBindingGroupInfo &miss = shaderBindingGroupsInfo.miss; + const RenderShaderBindingGroupInfo &hitGroup = shaderBindingGroupsInfo.hitGroup; + const RenderShaderBindingGroupInfo &callable = shaderBindingGroupsInfo.callable; + + IRDispatchRaysDescriptor dispatchDesc = {}; + + // Ray generation - single record, size equals stride. + dispatchDesc.RayGenerationShaderRecord.StartAddress = (rayGen.size > 0) ? (sbtBaseAddress + rayGen.offset + rayGen.startIndex * rayGen.stride) : 0; + dispatchDesc.RayGenerationShaderRecord.SizeInBytes = rayGen.stride; + + // Miss shader table. + dispatchDesc.MissShaderTable.StartAddress = (miss.size > 0) ? (sbtBaseAddress + miss.offset + miss.startIndex * miss.stride) : 0; + dispatchDesc.MissShaderTable.SizeInBytes = miss.size; + dispatchDesc.MissShaderTable.StrideInBytes = miss.stride; + + // Hit group table. + dispatchDesc.HitGroupTable.StartAddress = (hitGroup.size > 0) ? (sbtBaseAddress + hitGroup.offset + hitGroup.startIndex * hitGroup.stride) : 0; + dispatchDesc.HitGroupTable.SizeInBytes = hitGroup.size; + dispatchDesc.HitGroupTable.StrideInBytes = hitGroup.stride; + + // Callable shader table. + dispatchDesc.CallableShaderTable.StartAddress = (callable.size > 0) ? (sbtBaseAddress + callable.offset + callable.startIndex * callable.stride) : 0; + dispatchDesc.CallableShaderTable.SizeInBytes = callable.size; + dispatchDesc.CallableShaderTable.StrideInBytes = callable.stride; + + // Dispatch dimensions. + dispatchDesc.Width = width; + dispatchDesc.Height = height; + dispatchDesc.Depth = depth; + + // Build the IRDispatchRaysArgument structure. + IRDispatchRaysArgument dispatchArgs = {}; + dispatchArgs.DispatchRaysDesc = dispatchDesc; + dispatchArgs.GRS = rtTLABBuffer->gpuAddress(); // TLAB contains root parameter GPU addresses + dispatchArgs.ResDescHeap = 0; // Not using bindless resource heap + dispatchArgs.SmpDescHeap = 0; // Not using sampler heap + dispatchArgs.VisibleFunctionTable = activeRaytracingPipeline->visibleFunctionTable ? activeRaytracingPipeline->visibleFunctionTable->gpuResourceID() : MTL::ResourceID{0}; + dispatchArgs.IntersectionFunctionTable = activeRaytracingPipeline->intersectionFunctionTable ? activeRaytracingPipeline->intersectionFunctionTable->gpuResourceID() : MTL::ResourceID{0}; + dispatchArgs.IntersectionFunctionTables = 0; + + // Bind the dispatch arguments at the expected bind point. + activeComputeEncoder->setBytes(&dispatchArgs, sizeof(dispatchArgs), kIRRayDispatchArgumentsBindPoint); + + // Bind the visible function table. + if (activeRaytracingPipeline->visibleFunctionTable != nullptr) { + activeComputeEncoder->setVisibleFunctionTable(activeRaytracingPipeline->visibleFunctionTable, 0); + } + + // Bind the intersection function table if present. + if (activeRaytracingPipeline->intersectionFunctionTable != nullptr) { + activeComputeEncoder->setIntersectionFunctionTable(activeRaytracingPipeline->intersectionFunctionTable, 0); + } + + // Calculate thread group size from the pipeline. + NS::UInteger threadWidth = activeRaytracingPipeline->computePipeline->threadExecutionWidth(); + NS::UInteger threadHeight = activeRaytracingPipeline->computePipeline->maxTotalThreadsPerThreadgroup() / threadWidth; + MTL::Size threadGroupSize = MTL::Size(threadWidth, threadHeight, 1); + + // Calculate thread group count. + MTL::Size threadGroupCount = MTL::Size( + (width + threadGroupSize.width - 1) / threadGroupSize.width, + (height + threadGroupSize.height - 1) / threadGroupSize.height, + depth + ); + + activeComputeEncoder->dispatchThreadgroups(threadGroupCount, threadGroupSize); +#else + (void)width; (void)height; (void)depth; (void)shaderBindingTable; (void)shaderBindingGroupsInfo; + fprintf(stderr, "Ray tracing not supported: Metal Shader Converter runtime not available.\n"); +#endif } void MetalCommandList::prepareClearVertices(const RenderRect& rect, simd::float2* outVertices) { @@ -2518,6 +3771,11 @@ namespace plume { } break; } + case MetalPipeline::Type::Raytracing: { + const MetalRaytracingPipeline *rtPipeline = static_cast(interfacePipeline); + activeRaytracingPipeline = rtPipeline; + break; + } default: assert(false && "Unknown pipeline type."); break; @@ -2637,15 +3895,38 @@ namespace plume { } void MetalCommandList::setRaytracingPipelineLayout(const RenderPipelineLayout *pipelineLayout) { - // TODO: Metal RT + assert(pipelineLayout != nullptr); + + const MetalPipelineLayout *oldLayout = activeRaytracingPipelineLayout; + activeRaytracingPipelineLayout = static_cast(pipelineLayout); + + if (oldLayout != activeRaytracingPipelineLayout) { + // Clear descriptor set bindings since they're no longer valid with the new layout. + for (uint32_t i = 0; i < MAX_DESCRIPTOR_SET_BINDINGS; i++) { + raytracingDescriptorSets[i] = nullptr; + } + } } void MetalCommandList::setRaytracingPushConstants(uint32_t rangeIndex, const void *data, uint32_t offset, uint32_t size) { - // TODO: Metal RT + // Push constants for raytracing are handled similarly to compute. + // Store them for binding during traceRays. + assert(activeRaytracingPipelineLayout != nullptr); + assert(rangeIndex < activeRaytracingPipelineLayout->pushConstantRanges.size()); + + const RenderPushConstantRange &range = activeRaytracingPipelineLayout->pushConstantRanges[rangeIndex]; + const uint32_t rangeSize = (size > 0) ? size : range.size; + + if (pushConstants.size() < offset + rangeSize) { + pushConstants.resize(offset + rangeSize); + } + + memcpy(pushConstants.data() + offset, data, rangeSize); } void MetalCommandList::setRaytracingDescriptorSet(RenderDescriptorSet *descriptorSet, uint32_t setIndex) { - // TODO: Metal RT + assert(setIndex < MAX_DESCRIPTOR_SET_BINDINGS); + raytracingDescriptorSets[setIndex] = static_cast(descriptorSet); } void MetalCommandList::setIndexBuffer(const RenderIndexBufferView *view) { @@ -3064,10 +4345,11 @@ namespace plume { checkActiveBlitEncoder(); activeType = EncoderType::Blit; - const MetalTexture *dst = static_cast(dstTexture); - const MetalTexture *src = static_cast(srcTexture); + // Use getTexture() to support both MetalTexture and MetalDrawable (swapchain textures). + const ExtendedRenderTexture *dst = static_cast(dstTexture); + const ExtendedRenderTexture *src = static_cast(srcTexture); - activeBlitEncoder->copyFromTexture(src->mtl, dst->mtl); + activeBlitEncoder->copyFromTexture(src->getTexture(), dst->getTexture()); } void MetalCommandList::resolveTexture(const RenderTexture *dstTexture, const RenderTexture *srcTexture) { @@ -3169,11 +4451,85 @@ namespace plume { } void MetalCommandList::buildBottomLevelAS(const RenderAccelerationStructure *dstAccelerationStructure, RenderBufferReference scratchBuffer, const RenderBottomLevelASBuildInfo &buildInfo) { - // TODO: Unimplemented. + assert(dstAccelerationStructure != nullptr); + assert(scratchBuffer.ref != nullptr); + assert(!buildInfo.buildData.empty()); + + const MetalAccelerationStructure *dstAS = static_cast(dstAccelerationStructure); + const MetalBuffer *scratch = static_cast(scratchBuffer.ref); + + // Retrieve the descriptor stored during setBottomLevelASBuildInfo. + MTL::PrimitiveAccelerationStructureDescriptor *primDesc = nullptr; + memcpy(&primDesc, buildInfo.buildData.data(), sizeof(void *)); + assert(primDesc != nullptr); + + // End any other encoders and create an acceleration structure encoder. + endOtherEncoders(EncoderType::None); + + MTL::AccelerationStructureCommandEncoder *asEncoder = mtl->accelerationStructureCommandEncoder(); + if (asEncoder == nullptr) { + fprintf(stderr, "Failed to create acceleration structure command encoder.\n"); + return; + } + + // Build the acceleration structure. + asEncoder->buildAccelerationStructure(dstAS->mtl, primDesc, scratch->mtl, scratchBuffer.offset); + + asEncoder->endEncoding(); + + // Release the descriptor that was retained during setBottomLevelASBuildInfo. + primDesc->release(); } void MetalCommandList::buildTopLevelAS(const RenderAccelerationStructure *dstAccelerationStructure, RenderBufferReference scratchBuffer, RenderBufferReference instancesBuffer, const RenderTopLevelASBuildInfo &buildInfo) { - // TODO: Unimplemented. + assert(dstAccelerationStructure != nullptr); + assert(scratchBuffer.ref != nullptr); + assert(instancesBuffer.ref != nullptr); + assert(!buildInfo.buildData.empty()); + + MetalAccelerationStructure *dstAS = const_cast( + static_cast(dstAccelerationStructure)); + const MetalBuffer *scratch = static_cast(scratchBuffer.ref); + const MetalBuffer *instances = static_cast(instancesBuffer.ref); + + // Retrieve the descriptor stored during setTopLevelASBuildInfo. + MTL::InstanceAccelerationStructureDescriptor *instDesc = nullptr; + memcpy(&instDesc, buildInfo.buildData.data(), sizeof(void *)); + assert(instDesc != nullptr); + + // Store instance count and extract contributions from instancesBufferData. + // The contributions are stored in intersectionFunctionTableOffset of each instance descriptor. + dstAS->instanceCount = static_cast(instDesc->instanceCount()); + if (dstAS->instanceCount > 0 && !buildInfo.instancesBufferData.empty()) { + const auto* bufferInstances = reinterpret_cast( + buildInfo.instancesBufferData.data()); + dstAS->instanceContributions.resize(dstAS->instanceCount); + for (uint32_t i = 0; i < dstAS->instanceCount; i++) { + dstAS->instanceContributions[i] = bufferInstances[i].intersectionFunctionTableOffset; + } + } + + // Set the instance descriptor buffer. + instDesc->setInstanceDescriptorBuffer(instances->mtl); + instDesc->setInstanceDescriptorBufferOffset(instancesBuffer.offset); + instDesc->setInstanceDescriptorStride(sizeof(MTL::AccelerationStructureUserIDInstanceDescriptor)); + + // End any other encoders and create an acceleration structure encoder. + endOtherEncoders(EncoderType::None); + + MTL::AccelerationStructureCommandEncoder *asEncoder = mtl->accelerationStructureCommandEncoder(); + if (asEncoder == nullptr) { + fprintf(stderr, "Failed to create acceleration structure command encoder.\n"); + return; + } + + // Build the acceleration structure. + asEncoder->buildAccelerationStructure(dstAS->mtl, instDesc, scratch->mtl, scratchBuffer.offset); + + asEncoder->endEncoding(); + + // Release the descriptor that was retained during setTopLevelASBuildInfo. + instDesc->release(); } void MetalCommandList::discardTexture(const RenderTexture* texture) { @@ -3809,8 +5165,7 @@ namespace plume { // Fill capabilities. // https://developer.apple.com/documentation/metal/device-inspection - // TODO: Support Raytracing. - // capabilities.raytracing = mtl->supportsRaytracing(); + capabilities.raytracing = mtl->supportsRaytracing(); capabilities.maxTextureSize = mtl->supportsFamily(MTL::GPUFamilyApple3) ? 16384 : 8192; capabilities.sampleLocations = mtl->programmableSamplePositionsSupported(); capabilities.resolveModes = false; @@ -3887,8 +5242,7 @@ namespace plume { } std::unique_ptr MetalDevice::createRaytracingPipeline(const RenderRaytracingPipelineDesc &desc, const RenderPipeline *previousPipeline) { - // TODO: Unimplemented (Raytracing). - return nullptr; + return std::make_unique(this, desc, previousPipeline); } std::unique_ptr MetalDevice::createCommandQueue(RenderCommandListType type) { @@ -3932,15 +5286,246 @@ namespace plume { } void MetalDevice::setBottomLevelASBuildInfo(RenderBottomLevelASBuildInfo &buildInfo, const RenderBottomLevelASMesh *meshes, uint32_t meshCount, bool preferFastBuild, bool preferFastTrace) { - // TODO: Unimplemented (Raytracing). + assert(meshes != nullptr); + assert(meshCount > 0); + + // Create geometry descriptors for each mesh using a temporary vector. + std::vector geometryDescs; + geometryDescs.reserve(meshCount); + uint32_t primitiveCount = 0; + + for (uint32_t i = 0; i < meshCount; i++) { + const RenderBottomLevelASMesh &mesh = meshes[i]; + + auto *geometryDesc = MTL::AccelerationStructureTriangleGeometryDescriptor::alloc()->init(); + + // Set vertex buffer. + const MetalBuffer *vertexBuffer = static_cast(mesh.vertexBuffer.ref); + if (vertexBuffer != nullptr) { + geometryDesc->setVertexBuffer(vertexBuffer->mtl); + geometryDesc->setVertexBufferOffset(mesh.vertexBuffer.offset); + geometryDesc->setVertexStride(mesh.vertexStride); + geometryDesc->setVertexFormat(mapAttributeFormat(mesh.vertexFormat)); + } + + // Set index buffer if present. + const MetalBuffer *indexBuffer = static_cast(mesh.indexBuffer.ref); + uint32_t triangleCount = 0; + if (indexBuffer != nullptr) { + geometryDesc->setIndexBuffer(indexBuffer->mtl); + geometryDesc->setIndexBufferOffset(mesh.indexBuffer.offset); + geometryDesc->setIndexType(mapIndexFormat(mesh.indexFormat)); + triangleCount = mesh.indexCount / 3; + } else { + triangleCount = mesh.vertexCount / 3; + } + + geometryDesc->setTriangleCount(triangleCount); + geometryDesc->setOpaque(mesh.isOpaque); + + geometryDescs.push_back(geometryDesc); + primitiveCount += triangleCount; + } + + // Create NS::Array from the geometry descriptors. + NS::Array *geometryDescriptors = NS::Array::array( + reinterpret_cast(geometryDescs.data()), + static_cast(geometryDescs.size()) + ); + + // Create primitive acceleration structure descriptor. + auto *primDesc = MTL::PrimitiveAccelerationStructureDescriptor::alloc()->init(); + primDesc->setGeometryDescriptors(geometryDescriptors); + + // Set usage flags. + MTL::AccelerationStructureUsage usage = MTL::AccelerationStructureUsageNone; + if (preferFastBuild) { + usage |= MTL::AccelerationStructureUsagePreferFastBuild; + } + primDesc->setUsage(usage); + + // Query sizes from the device. + MTL::AccelerationStructureSizes sizes = mtl->accelerationStructureSizes(primDesc); + + // Fill build info. + buildInfo.meshCount = meshCount; + buildInfo.primitiveCount = primitiveCount; + buildInfo.preferFastBuild = preferFastBuild; + buildInfo.preferFastTrace = preferFastTrace; + buildInfo.scratchSize = sizes.buildScratchBufferSize; + buildInfo.accelerationStructureSize = sizes.accelerationStructureSize; + + // Store the descriptor in buildData for use during the actual build. + // We store the pointer as raw bytes (the descriptor is retained). + primDesc->retain(); + buildInfo.buildData.resize(sizeof(void *)); + memcpy(buildInfo.buildData.data(), &primDesc, sizeof(void *)); + + // Release geometry descriptors (the primDesc retains the array). + for (auto *desc : geometryDescs) { + desc->release(); + } } void MetalDevice::setTopLevelASBuildInfo(RenderTopLevelASBuildInfo &buildInfo, const RenderTopLevelASInstance *instances, uint32_t instanceCount, bool preferFastBuild, bool preferFastTrace) { - // TODO: Unimplemented (Raytracing). + assert(instances != nullptr); + assert(instanceCount > 0); + + // Build the instance descriptor buffer data. + // Metal uses MTL::AccelerationStructureInstanceDescriptor or the UserID variant. + buildInfo.instancesBufferData.resize(sizeof(MTL::AccelerationStructureUserIDInstanceDescriptor) * instanceCount, 0); + auto *bufferInstances = reinterpret_cast(buildInfo.instancesBufferData.data()); + + // Collect BLAS references for the instanced acceleration structures array. + std::vector blasArray; + blasArray.reserve(instanceCount); + + for (uint32_t i = 0; i < instanceCount; i++) { + const RenderTopLevelASInstance &instance = instances[i]; + MTL::AccelerationStructureUserIDInstanceDescriptor &desc = bufferInstances[i]; + + // Copy transform (3x4 row-major matrix). + // RenderAffineTransform is a 3x4 matrix stored as float[3][4]. + // Metal's PackedFloat4x3 is column-major, so we need to transpose. + for (int row = 0; row < 3; row++) { + for (int col = 0; col < 4; col++) { + desc.transformationMatrix.columns[col][row] = instance.transform.m[row][col]; + } + } + + // Set instance options. + MTL::AccelerationStructureInstanceOptions options = MTL::AccelerationStructureInstanceOptionNone; + if (instance.cullDisable) { + options |= MTL::AccelerationStructureInstanceOptionDisableTriangleCulling; + } + desc.options = options; + + desc.mask = instance.instanceMask; + desc.intersectionFunctionTableOffset = instance.instanceContributionToHitGroupIndex; + desc.userID = instance.instanceID; + + // The accelerationStructureIndex references into the BLAS array we build. + desc.accelerationStructureIndex = i; + + // Get the BLAS from the bottomLevelAS reference. + const MetalAccelerationStructure *blas = static_cast(instance.bottomLevelAS); + if (blas != nullptr && blas->mtl != nullptr) { + blasArray.push_back(blas->mtl); + } else { + blasArray.push_back(nullptr); + } + } + + // Create instance acceleration structure descriptor to query sizes. + auto *instDesc = MTL::InstanceAccelerationStructureDescriptor::alloc()->init(); + instDesc->setInstanceCount(instanceCount); + instDesc->setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorTypeUserID); + + // Set the instanced acceleration structures array. + if (!blasArray.empty()) { + NS::Array *blasNSArray = NS::Array::array( + reinterpret_cast(blasArray.data()), + static_cast(blasArray.size()) + ); + instDesc->setInstancedAccelerationStructures(blasNSArray); + } + + // Set usage flags. + MTL::AccelerationStructureUsage usage = MTL::AccelerationStructureUsageNone; + if (preferFastBuild) { + usage |= MTL::AccelerationStructureUsagePreferFastBuild; + } + instDesc->setUsage(usage); + + // Query sizes from the device. + MTL::AccelerationStructureSizes sizes = mtl->accelerationStructureSizes(instDesc); + + // Fill build info. + buildInfo.instanceCount = instanceCount; + buildInfo.preferFastBuild = preferFastBuild; + buildInfo.preferFastTrace = preferFastTrace; + buildInfo.scratchSize = sizes.buildScratchBufferSize; + buildInfo.accelerationStructureSize = sizes.accelerationStructureSize; + + // Store the descriptor pointer in buildData for use during the actual build. + instDesc->retain(); + buildInfo.buildData.resize(sizeof(void *)); + memcpy(buildInfo.buildData.data(), &instDesc, sizeof(void *)); } void MetalDevice::setShaderBindingTableInfo(RenderShaderBindingTableInfo &tableInfo, const RenderShaderBindingGroups &groups, const RenderPipeline *pipeline, RenderDescriptorSet **descriptorSets, uint32_t descriptorSetCount) { - // TODO: Unimplemented (Raytracing). +#ifdef PLUME_METAL_RAYTRACING_ENABLED + assert(pipeline != nullptr); + const MetalRaytracingPipeline *rtPipeline = static_cast(pipeline); + + // IRShaderIdentifier is 32 bytes (4 x uint64_t). + // Each shader record = IRShaderIdentifier + local root signature data. + // For simplicity, we use a fixed stride that accommodates the identifier. + // Local root signature data would follow the identifier if needed. + constexpr uint32_t shaderIdentifierSize = sizeof(IRShaderIdentifier); + + // Calculate stride for each group (must be aligned to D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT = 32). + constexpr uint32_t recordAlignment = 32; + auto alignStride = [](uint32_t size) -> uint32_t { + return (size + recordAlignment - 1) & ~(recordAlignment - 1); + }; + + // Helper to set up a group's info. + auto setGroup = [&](RenderShaderBindingGroupInfo &groupInfo, const RenderShaderBindingGroup &renderGroup, uint32_t ¤tOffset) { + if (renderGroup.pipelineProgramsCount == 0) { + groupInfo.offset = 0; + groupInfo.size = 0; + groupInfo.stride = 0; + groupInfo.startIndex = 0; + return; + } + + groupInfo.stride = alignStride(shaderIdentifierSize); + groupInfo.offset = currentOffset; + groupInfo.size = groupInfo.stride * renderGroup.pipelineProgramsCount; + groupInfo.startIndex = 0; + + currentOffset += groupInfo.size; + }; + + // Calculate total size and set up group info. + uint32_t currentOffset = 0; + setGroup(tableInfo.groups.rayGen, groups.rayGen, currentOffset); + setGroup(tableInfo.groups.miss, groups.miss, currentOffset); + setGroup(tableInfo.groups.hitGroup, groups.hitGroup, currentOffset); + setGroup(tableInfo.groups.callable, groups.callable, currentOffset); + + // Allocate table buffer data. + tableInfo.tableBufferData.resize(currentOffset); + memset(tableInfo.tableBufferData.data(), 0, currentOffset); + + // Helper to copy shader identifiers into the table. + auto copyGroupData = [&](const RenderShaderBindingGroupInfo &groupInfo, const RenderShaderBindingGroup &renderGroup) { + if (renderGroup.pipelineProgramsCount == 0) { + return; + } + + uint8_t *groupData = tableInfo.tableBufferData.data() + groupInfo.offset; + + for (uint32_t i = 0; i < renderGroup.pipelineProgramsCount; i++) { + const RenderPipelineProgram &program = renderGroup.pipelinePrograms[i]; + IRShaderIdentifier *identifier = reinterpret_cast(groupData + i * groupInfo.stride); + + // Initialize the shader identifier. + // program.programIndex is the index into the visible function table (1-based, 0 = null). + IRShaderIdentifierInit(identifier, program.programIndex); + } + }; + + // Copy shader identifiers for each group. + copyGroupData(tableInfo.groups.rayGen, groups.rayGen); + copyGroupData(tableInfo.groups.miss, groups.miss); + copyGroupData(tableInfo.groups.hitGroup, groups.hitGroup); + copyGroupData(tableInfo.groups.callable, groups.callable); +#else + (void)tableInfo; (void)groups; (void)pipeline; (void)descriptorSets; (void)descriptorSetCount; + fprintf(stderr, "Ray tracing not supported: Metal Shader Converter runtime not available.\n"); +#endif } const RenderDeviceCapabilities &MetalDevice::getCapabilities() const { diff --git a/plume_metal.h b/plume_metal.h index cf55279..f4008d9 100644 --- a/plume_metal.h +++ b/plume_metal.h @@ -54,6 +54,7 @@ namespace plume { struct MetalBufferFormattedView; struct MetalPipelineLayout; struct MetalGraphicsPipeline; + struct MetalRaytracingPipeline; struct MetalPool; struct MetalDrawable; @@ -146,6 +147,7 @@ namespace plume { struct BufferDescriptor: Descriptor { MTL::Buffer *buffer; uint32_t offset = 0; + uint64_t size = 0; }; struct TextureDescriptor: Descriptor { @@ -167,6 +169,7 @@ namespace plume { MetalDevice *device = nullptr; std::vector setBindings; std::vector bindingToIndex; + std::vector bindingDescriptorIndexBase; MTL::ArgumentEncoder *argumentEncoder = nullptr; std::vector argumentDescriptors; std::vector descriptorIndexBases; @@ -208,7 +211,12 @@ namespace plume { struct MetalDescriptorSet : RenderDescriptorSet { struct ResourceEntry { MTL::Resource* resource = nullptr; + MTL::SamplerState* sampler = nullptr; RenderDescriptorRangeType type = RenderDescriptorRangeType::UNKNOWN; + uint64_t bufferOffset = 0; + uint64_t bufferSize = 0; + uint32_t instanceCount = 0; // For acceleration structures: TLAS instance count + std::vector instanceContributions; // For acceleration structures: per-instance hit group offsets }; MetalDevice *device = nullptr; @@ -220,6 +228,8 @@ namespace plume { std::mutex residencySetWriteMutex; bool needsCommit = false; + bool rtBuffersDirty = true; // Set when resources change, cleared after TLAB rebuild + MetalDescriptorSet(MetalDevice *device, const RenderDescriptorSetDesc &desc); MetalDescriptorSet(MetalDevice *device, uint32_t entryCount); ~MetalDescriptorSet() override; @@ -446,10 +456,20 @@ namespace plume { const MetalFramebuffer *targetFramebuffer = nullptr; const MetalPipelineLayout *activeComputePipelineLayout = nullptr; const MetalPipelineLayout *activeGraphicsPipelineLayout = nullptr; + const MetalPipelineLayout *activeRaytracingPipelineLayout = nullptr; const MetalRenderState *activeRenderState = nullptr; const MetalComputeState *activeComputeState = nullptr; + const MetalRaytracingPipeline *activeRaytracingPipeline = nullptr; MetalDescriptorSet* renderDescriptorSets[MAX_DESCRIPTOR_SET_BINDINGS] = {}; MetalDescriptorSet* computeDescriptorSets[MAX_DESCRIPTOR_SET_BINDINGS] = {}; + MetalDescriptorSet* raytracingDescriptorSets[MAX_DESCRIPTOR_SET_BINDINGS] = {}; + + std::vector rtDescriptorTableBuffers; + std::vector rtDescriptorTableBufferSizes; + std::vector rtASHeaderBuffers; + std::vector rtASHeaderBufferSizes; + MTL::Buffer* rtTLABBuffer = nullptr; + size_t rtTLABBufferSize = 0; MTL::Fence *timestampQueryFence = nullptr; @@ -617,10 +637,11 @@ namespace plume { struct MetalAccelerationStructure : RenderAccelerationStructure { MetalDevice *device = nullptr; - const MetalBuffer *buffer = nullptr; - uint64_t offset = 0; + MTL::AccelerationStructure *mtl = nullptr; uint64_t size = 0; RenderAccelerationStructureType type = RenderAccelerationStructureType::UNKNOWN; + uint32_t instanceCount = 0; // For TLAS: number of instances (set during build) + std::vector instanceContributions; // For TLAS: per-instance hit group offsets MetalAccelerationStructure(MetalDevice *device, const RenderAccelerationStructureDesc &desc); ~MetalAccelerationStructure() override; @@ -640,7 +661,9 @@ namespace plume { NS::String *functionName = nullptr; RenderShaderFormat format = RenderShaderFormat::UNKNOWN; MTL::Library *library = nullptr; + MTL::Library *dispatchLibrary = nullptr; // For combined RT shaders: dispatch kernel library NS::String *debugName = nullptr; + std::string rtRootSignatureJson; MetalShader(const MetalDevice *device, const void *data, uint64_t size, const char *entryPointName, RenderShaderFormat format); ~MetalShader() override; @@ -648,6 +671,36 @@ namespace plume { MTL::Function* createFunction(const RenderSpecConstant *specConstants, uint32_t specConstantsCount) const; }; + enum class MetalRTDescriptorRangeType { + Unknown, + SRV, + UAV, + Sampler + }; + + enum class MetalRTRootParameterType { + Unknown, + DescriptorTable, + RootCBV, + RootSRV, + RootUAV + }; + + struct MetalRTDescriptorRange { + MetalRTDescriptorRangeType type = MetalRTDescriptorRangeType::Unknown; + uint32_t baseRegister = 0; + uint32_t registerSpace = 0; + uint32_t numDescriptors = 0; + uint32_t offset = 0; + }; + + struct MetalRTRootParameter { + MetalRTRootParameterType type = MetalRTRootParameterType::Unknown; + uint32_t shaderRegister = 0; + uint32_t registerSpace = 0; + std::vector ranges; + }; + struct MetalSampler : RenderSampler { MTL::SamplerState *state = nullptr; RenderBorderColor borderColor = RenderBorderColor::UNKNOWN; @@ -689,6 +742,23 @@ namespace plume { RenderPipelineProgram getProgram(const std::string &name) const override; }; + struct MetalRaytracingPipeline : MetalPipeline { + const MetalDevice *device = nullptr; + MTL::ComputePipelineState *computePipeline = nullptr; + MTL::VisibleFunctionTable *visibleFunctionTable = nullptr; + MTL::IntersectionFunctionTable *intersectionFunctionTable = nullptr; + std::unordered_map nameProgramMap; + const MetalPipelineLayout *pipelineLayout = nullptr; + std::vector rootSignatureParameters; + uint32_t maxPayloadSize = 0; + uint32_t maxAttributeSize = 0; + + MetalRaytracingPipeline(MetalDevice *device, const RenderRaytracingPipelineDesc &desc, const RenderPipeline *previousPipeline); + ~MetalRaytracingPipeline() override; + void setName(const std::string &name) override; + RenderPipelineProgram getProgram(const std::string &name) const override; + }; + struct MetalPipelineLayout : RenderPipelineLayout { std::vector pushConstantRanges; uint32_t setLayoutCount = 0; diff --git a/plume_render_interface_types.h b/plume_render_interface_types.h index e673b89..5d74d43 100644 --- a/plume_render_interface_types.h +++ b/plume_render_interface_types.h @@ -1648,8 +1648,10 @@ namespace plume { std::vector buildData; }; + struct RenderAccelerationStructure; + struct RenderTopLevelASInstance { - RenderBufferReference bottomLevelAS; + const RenderAccelerationStructure *bottomLevelAS = nullptr; uint32_t instanceID = 0; uint32_t instanceMask = 0; uint32_t instanceContributionToHitGroupIndex = 0; @@ -1658,7 +1660,7 @@ namespace plume { RenderTopLevelASInstance() = default; - RenderTopLevelASInstance(RenderBufferReference bottomLevelAS, uint32_t instanceID, uint32_t instanceMask, uint32_t instanceContributionToHitGroupIndex, bool cullDisable, RenderAffineTransform transform) { + RenderTopLevelASInstance(const RenderAccelerationStructure *bottomLevelAS, uint32_t instanceID, uint32_t instanceMask, uint32_t instanceContributionToHitGroupIndex, bool cullDisable, RenderAffineTransform transform) { this->bottomLevelAS = bottomLevelAS; this->instanceID = instanceID; this->instanceMask = instanceMask; diff --git a/plume_vulkan.cpp b/plume_vulkan.cpp index d73fb1c..64fcd7e 100644 --- a/plume_vulkan.cpp +++ b/plume_vulkan.cpp @@ -6,7 +6,7 @@ // #define VMA_IMPLEMENTATION -#define VOLK_IMPLEMENTATION +#define VOLK_IMPLEMENTATION #include "plume_vulkan.h" @@ -65,11 +65,11 @@ namespace plume { VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME, # endif }; - + static const std::unordered_set RequiredDeviceExtensions = { VK_KHR_SWAPCHAIN_EXTENSION_NAME, }; - + static const std::unordered_set OptionalDeviceExtensions = { VK_EXT_DESCRIPTOR_INDEXING_EXTENSION_NAME, VK_EXT_SCALAR_BLOCK_LAYOUT_EXTENSION_NAME, @@ -342,9 +342,9 @@ namespace plume { case RenderPrimitiveTopology::LINE_STRIP: return VK_PRIMITIVE_TOPOLOGY_LINE_STRIP; case RenderPrimitiveTopology::TRIANGLE_LIST: - return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST; + return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST; case RenderPrimitiveTopology::TRIANGLE_STRIP: - return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP; + return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP; case RenderPrimitiveTopology::TRIANGLE_FAN: return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_FAN; default: @@ -599,7 +599,7 @@ namespace plume { return VK_ACCELERATION_STRUCTURE_TYPE_MAX_ENUM_KHR; } } - + static VkPipelineStageFlags toStageFlags(RenderBarrierStages stages, bool geometrySupported, bool rtSupported) { VkPipelineStageFlags flags = 0; @@ -688,7 +688,7 @@ namespace plume { flags |= preferFastTrace ? VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_TRACE_BIT_KHR : 0; return flags; } - + static VkImageLayout toImageLayout(RenderTextureLayout layout) { switch (layout) { case RenderTextureLayout::UNKNOWN: @@ -918,7 +918,7 @@ namespace plume { } void VulkanBuffer::setName(const std::string &name) { - setObjectName(device->vk, VK_OBJECT_TYPE_IMAGE, uint64_t(vk), name); + setObjectName(device->vk, VK_OBJECT_TYPE_BUFFER, uint64_t(vk), name); } uint64_t VulkanBuffer::getDeviceAddress() const { @@ -1031,7 +1031,7 @@ namespace plume { vmaDestroyImage(device->allocator, vk, allocation); } } - + void VulkanTexture::createImageView(VkFormat format) { VkImageView view = VK_NULL_HANDLE; VkImageViewCreateInfo viewInfo = {}; @@ -1044,7 +1044,7 @@ namespace plume { viewInfo.components.b = VK_COMPONENT_SWIZZLE_IDENTITY; viewInfo.components.a = VK_COMPONENT_SWIZZLE_IDENTITY; viewInfo.subresourceRange = imageSubresourceRange; - + VkResult res = vkCreateImageView(device->vk, &viewInfo, nullptr, &imageView); if (res != VK_SUCCESS) { fprintf(stderr, "vkCreateImageView failed with error code 0x%X.\n", res); @@ -1119,6 +1119,9 @@ namespace plume { this->type = desc.type; const VulkanBuffer *interfaceBuffer = static_cast(desc.buffer.ref); + this->backingBuffer = interfaceBuffer; + this->backingBufferOffset = desc.buffer.offset; + VkAccelerationStructureCreateInfoKHR createInfo = {}; createInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR; createInfo.buffer = interfaceBuffer->vk; @@ -1160,7 +1163,7 @@ namespace plume { } } } - + // Create bindings. uint32_t immutableSamplerIndex = 0; for (uint32_t i = 0; i < descriptorSetDesc.descriptorRangesCount; i++) { @@ -1189,7 +1192,7 @@ namespace plume { setLayoutInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; setLayoutInfo.pBindings = !setBindings.empty() ? setBindings.data() : nullptr; setLayoutInfo.bindingCount = uint32_t(setBindings.size()); - + thread_local std::vector bindingFlags; VkDescriptorSetLayoutBindingFlagsCreateInfo flagsInfo = {}; if (descriptorSetDesc.lastRangeIsBoundless && (descriptorSetDesc.descriptorRangesCount > 0)) { @@ -1204,7 +1207,7 @@ namespace plume { setLayoutInfo.pNext = &flagsInfo; setLayoutInfo.flags = VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT; } - + VkResult res = vkCreateDescriptorSetLayout(device->vk, &setLayoutInfo, nullptr, &vk); if (res != VK_SUCCESS) { fprintf(stderr, "vkCreateDescriptorSetLayout failed with error code 0x%X.\n", res); @@ -1580,7 +1583,7 @@ namespace plume { colorBlend.logicOp = toVk(desc.logicOp); colorBlend.pAttachments = !colorBlendAttachments.empty() ? colorBlendAttachments.data() : nullptr; colorBlend.attachmentCount = uint32_t(colorBlendAttachments.size()); - + VkPipelineDepthStencilStateCreateInfo depthStencil = {}; depthStencil.sType = VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO; depthStencil.depthTestEnable = desc.depthEnabled; @@ -1860,7 +1863,7 @@ namespace plume { groupCount = pipelineInfo.groupCount; } - + VulkanRaytracingPipeline::~VulkanRaytracingPipeline() { if (vk != VK_NULL_HANDLE) { vkDestroyPipeline(device->vk, vk, nullptr); @@ -1886,7 +1889,7 @@ namespace plume { thread_local std::unordered_map typeCounts; typeCounts.clear(); - + uint32_t boundlessRangeSize = 0; uint32_t rangeCount = desc.descriptorRangesCount; if (desc.lastRangeIsBoundless) { @@ -1940,7 +1943,7 @@ namespace plume { delete setLayout; } - + void VulkanDescriptorSet::setBuffer(uint32_t descriptorIndex, const RenderBuffer *buffer, uint64_t bufferSize, const RenderBufferStructuredView *bufferStructuredView, const RenderBufferFormattedView *bufferFormattedView) { if (buffer == nullptr) { return; @@ -2132,7 +2135,7 @@ namespace plume { assert(renderWindow.view != 0); // Creates a wrapper around the window for storing and fetching sizes. this->windowWrapper = std::make_unique(renderWindow.window); - + VkMetalSurfaceCreateInfoEXT surfaceCreateInfo = {}; surfaceCreateInfo.sType = VK_STRUCTURE_TYPE_METAL_SURFACE_CREATE_INFO_EXT; surfaceCreateInfo.pLayer = renderWindow.view; @@ -2275,7 +2278,7 @@ namespace plume { presentId.swapchainCount = 1; presentInfo.pNext = &presentId; } - + VkResult res; { const std::scoped_lock queueLock(*commandQueue->queue->mutex); @@ -2314,17 +2317,23 @@ namespace plume { // Destroy any image view references to the current swap chain. releaseImageViews(); - // We don't actually need to query the surface capabilities but the validation layer seems to cache the valid extents from this call. + // Query surface capabilities to get the valid extent bounds. VkSurfaceCapabilitiesKHR surfaceCapabilities = {}; vkGetPhysicalDeviceSurfaceCapabilitiesKHR(commandQueue->device->physicalDevice, surface, &surfaceCapabilities); + // Clamp the extent to the surface capabilities' min/max bounds. + // This is required because the window size may differ from the valid surface extent + // (e.g., due to window decorations, compositor behavior, or timing issues). + uint32_t clampedWidth = std::clamp(width, surfaceCapabilities.minImageExtent.width, surfaceCapabilities.maxImageExtent.width); + uint32_t clampedHeight = std::clamp(height, surfaceCapabilities.minImageExtent.height, surfaceCapabilities.maxImageExtent.height); + createInfo.sType = VK_STRUCTURE_TYPE_SWAPCHAIN_CREATE_INFO_KHR; createInfo.surface = surface; createInfo.minImageCount = textureCount; createInfo.imageFormat = pickedSurfaceFormat.format; createInfo.imageColorSpace = pickedSurfaceFormat.colorSpace; - createInfo.imageExtent.width = width; - createInfo.imageExtent.height = height; + createInfo.imageExtent.width = clampedWidth; + createInfo.imageExtent.height = clampedHeight; createInfo.imageArrayLayers = 1; createInfo.imageUsage = VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT; createInfo.imageSharingMode = VK_SHARING_MODE_EXCLUSIVE; @@ -2371,12 +2380,16 @@ namespace plume { // Assign the swap chain images to the buffer resources. textures.resize(textureCount); + // Update the stored width/height to reflect the actual swapchain extent. + width = clampedWidth; + height = clampedHeight; + for (uint32_t i = 0; i < textureCount; i++) { textures[i] = VulkanTexture(commandQueue->device, images[i]); textures[i].desc.dimension = RenderTextureDimension::TEXTURE_2D; textures[i].desc.format = format; - textures[i].desc.width = width; - textures[i].desc.height = height; + textures[i].desc.width = clampedWidth; + textures[i].desc.height = clampedHeight; textures[i].desc.depth = 1; textures[i].desc.mipLevels = 1; textures[i].desc.arraySize = 1; @@ -2495,7 +2508,7 @@ namespace plume { } // VulkanFramebuffer - + VulkanFramebuffer::VulkanFramebuffer(VulkanDevice *device, const RenderFramebufferDesc &desc) { assert(device != nullptr); @@ -2614,7 +2627,7 @@ namespace plume { fprintf(stderr, "vkCreateRenderPass failed with error code 0x%X.\n", res); return; } - + VkFramebufferCreateInfo fbInfo = {}; fbInfo.sType = VK_STRUCTURE_TYPE_FRAMEBUFFER_CREATE_INFO; fbInfo.renderPass = renderPass; @@ -2673,13 +2686,13 @@ namespace plume { createInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; createInfo.queryType = VK_QUERY_TYPE_TIMESTAMP; createInfo.queryCount = queryCount; - + VkResult res = vkCreateQueryPool(device->vk, &createInfo, nullptr, &vk); if (res != VK_SUCCESS) { fprintf(stderr, "vkCreateQueryPool failed with error code 0x%X.\n", res); return; } - + results.resize(queryCount); } @@ -2701,16 +2714,16 @@ namespace plume { uint64_t t = (u1 * v1); uint64_t w3 = (t & 0xffffffff); uint64_t k = (t >> 32); - + u >>= 32; t = (u * v1) + k; k = (t & 0xffffffff); uint64_t w1 = (t >> 32); - + v >>= 32; t = (u1 * v) + k; k = (t >> 32); - + h = (u * v) + w1 + k; l = (t << 32) + w3; }; @@ -2861,7 +2874,7 @@ namespace plume { interfaceTexture->textureLayout = textureBarrier.layout; interfaceTexture->barrierStages = stages; } - + if (bufferMemoryBarriers.empty() && imageMemoryBarriers.empty()) { return; } @@ -2915,7 +2928,7 @@ namespace plume { vkCmdDraw(vk, vertexCountPerInstance, instanceCount, startVertexLocation, startInstanceLocation); } - + void VulkanCommandList::drawIndexedInstanced(uint32_t indexCountPerInstance, uint32_t instanceCount, uint32_t startIndexLocation, int32_t baseVertexLocation, uint32_t startInstanceLocation) { assert(activeGraphicsPipelineLayout != nullptr); checkActiveRenderPass(); @@ -2958,7 +2971,7 @@ namespace plume { void VulkanCommandList::setComputePushConstants(uint32_t rangeIndex, const void *data, uint32_t offset, uint32_t size) { assert(activeComputePipelineLayout != nullptr); assert(rangeIndex < activeComputePipelineLayout->pushConstantRanges.size()); - + const VkPushConstantRange &range = activeComputePipelineLayout->pushConstantRanges[rangeIndex]; vkCmdPushConstants(vk, activeComputePipelineLayout->vk, range.stageFlags & VK_SHADER_STAGE_COMPUTE_BIT, range.offset + offset, size == 0 ? range.size : size, data); } @@ -3184,7 +3197,7 @@ namespace plume { void VulkanCommandList::copyTextureRegion(const RenderTextureCopyLocation &dstLocation, const RenderTextureCopyLocation &srcLocation, uint32_t dstX, uint32_t dstY, uint32_t dstZ, const RenderBox *srcBox) { endActiveRenderPass(); - + assert(dstLocation.type != RenderTextureCopyType::UNKNOWN); assert(srcLocation.type != RenderTextureCopyType::UNKNOWN); @@ -3253,7 +3266,7 @@ namespace plume { assert(dstBuffer != nullptr); assert(srcBuffer != nullptr); - + const VulkanBuffer *interfaceDstBuffer = static_cast(dstBuffer); const VulkanBuffer *interfaceSrcBuffer = static_cast(srcBuffer); VkBufferCopy bufferCopy = {}; @@ -3348,7 +3361,7 @@ namespace plume { vkCmdResolveImage(vk, src->vk, srcLayout, dst->vk, dstLayout, uint32_t(imageResolves.size()), imageResolves.data()); } - + void VulkanCommandList::buildBottomLevelAS(const RenderAccelerationStructure *dstAccelerationStructure, RenderBufferReference scratchBuffer, const RenderBottomLevelASBuildInfo &buildInfo) { assert(dstAccelerationStructure != nullptr); assert(scratchBuffer.ref != nullptr); @@ -3452,7 +3465,7 @@ namespace plume { void VulkanCommandList::checkActiveRenderPass() { assert(targetFramebuffer != nullptr); - + if (activeRenderPass == VK_NULL_HANDLE) { VkRenderPassBeginInfo beginInfo = {}; beginInfo.sType = VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO; @@ -3614,7 +3627,7 @@ namespace plume { return; } } - + void VulkanCommandQueue::waitForCommandFence(RenderCommandFence *fence) { assert(fence != nullptr); @@ -3684,7 +3697,7 @@ namespace plume { std::unique_ptr VulkanPool::createTexture(const RenderTextureDesc &desc) { return std::make_unique(device, this, desc); } - + // VulkanQueueFamily void VulkanQueueFamily::add(VulkanCommandQueue *virtualQueue) { @@ -3717,7 +3730,7 @@ namespace plume { } // VulkanDevice - + VulkanDevice::VulkanDevice(VulkanInterface *renderInterface, const std::string &preferredDeviceName) { assert(renderInterface != nullptr); @@ -3729,7 +3742,7 @@ namespace plume { fprintf(stderr, "Unable to find devices that support Vulkan.\n"); return; } - + std::vector physicalDevices(deviceCount); vkEnumeratePhysicalDevices(renderInterface->instance, &deviceCount, physicalDevices.data()); @@ -3800,7 +3813,7 @@ namespace plume { } # endif } - + if (!missingRequiredExtensions.empty()) { for (const std::string &extension : missingRequiredExtensions) { fprintf(stderr, "Missing required extension: %s.\n", extension.c_str()); @@ -3948,7 +3961,7 @@ namespace plume { bufferDeviceAddressFeatures.pNext = createDeviceChain; createDeviceChain = &bufferDeviceAddressFeatures; } - + if (portabilityFound) { portabilityFeatures.pNext = createDeviceChain; createDeviceChain = &portabilityFeatures; @@ -4291,8 +4304,9 @@ namespace plume { VkAccelerationStructureInstanceKHR *bufferInstances = reinterpret_cast(buildInfo.instancesBufferData.data()); for (uint32_t i = 0; i < instanceCount; i++) { const RenderTopLevelASInstance &instance = instances[i]; - const VulkanBuffer *interfaceBottomLevelAS = static_cast(instance.bottomLevelAS.ref); - assert(interfaceBottomLevelAS != nullptr); + const VulkanAccelerationStructure *blasAS = static_cast(instance.bottomLevelAS); + assert(blasAS != nullptr); + assert(blasAS->backingBuffer != nullptr); VkAccelerationStructureInstanceKHR &bufferInstance = bufferInstances[i]; bufferInstance.instanceCustomIndex = instance.instanceID; @@ -4303,8 +4317,8 @@ namespace plume { VkBufferDeviceAddressInfo blasAddressInfo = {}; blasAddressInfo.sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO; - blasAddressInfo.buffer = interfaceBottomLevelAS->vk; - bufferInstance.accelerationStructureReference = vkGetBufferDeviceAddress(vk, &blasAddressInfo) + instance.bottomLevelAS.offset; + blasAddressInfo.buffer = blasAS->backingBuffer->vk; + bufferInstance.accelerationStructureReference = vkGetBufferDeviceAddress(vk, &blasAddressInfo) + blasAS->backingBufferOffset; } // Retrieve the size the TLAS will require. @@ -4333,7 +4347,7 @@ namespace plume { buildInfo.scratchSize = roundUp(buildSizesInfo.buildScratchSize, AccelerationStructureBufferAlignment); buildInfo.accelerationStructureSize = roundUp(buildSizesInfo.accelerationStructureSize, AccelerationStructureBufferAlignment); } - + void VulkanDevice::setShaderBindingTableInfo(RenderShaderBindingTableInfo &tableInfo, const RenderShaderBindingGroups &groups, const RenderPipeline *pipeline, RenderDescriptorSet **descriptorSets, uint32_t descriptorSetCount) { assert(pipeline != nullptr); assert((descriptorSets != nullptr) && "Vulkan doesn't require descriptor sets, but they should be passed to keep consistency with D3D12."); @@ -4351,7 +4365,7 @@ namespace plume { fprintf(stderr, "vkGetRayTracingShaderGroupHandlesKHR failed with error code 0x%X.\n", res); return; } - + const uint32_t handleSizeAligned = roundUp(handleSize, rtPipelineProperties.shaderGroupHandleAlignment); const uint32_t regionAlignment = roundUp(handleSizeAligned, rtPipelineProperties.shaderGroupBaseAlignment); uint64_t tableSize = 0; @@ -4485,7 +4499,7 @@ namespace plume { # if PLUME_SDL_VULKAN_ENABLED // Push the extensions specified by SDL as required. - // SDL2 has this awkward requirement for the window to pull the extensions from. + // SDL2 has this awkward requirement for the window to pull the extensions from. // This can be removed when upgrading to SDL3. if (sdlWindow != nullptr) { uint32_t sdlVulkanExtensionCount = 0; @@ -4544,7 +4558,7 @@ namespace plume { std::vector availableLayers(layerCount); vkEnumerateInstanceLayerProperties(&layerCount, availableLayers.data()); - + const char validationLayerName[] = "VK_LAYER_KHRONOS_validation"; const char *enabledLayerNames[] = { validationLayerName }; for (const VkLayerProperties &layerProperties : availableLayers) { @@ -4555,7 +4569,7 @@ namespace plume { } } # endif - + res = vkCreateInstance(&createInfo, nullptr, &instance); if (res != VK_SUCCESS) { fprintf(stderr, "vkCreateInstance failed with error code 0x%X.\n", res); diff --git a/plume_vulkan.h b/plume_vulkan.h index d3bdbd8..fa25249 100644 --- a/plume_vulkan.h +++ b/plume_vulkan.h @@ -112,6 +112,8 @@ namespace plume { VkAccelerationStructureKHR vk = VK_NULL_HANDLE; VulkanDevice *device = nullptr; RenderAccelerationStructureType type = RenderAccelerationStructureType::UNKNOWN; + const VulkanBuffer *backingBuffer = nullptr; + uint64_t backingBufferOffset = 0; VulkanAccelerationStructure(VulkanDevice *device, const RenderAccelerationStructureDesc &desc); ~VulkanAccelerationStructure() override;