diff --git a/MetalSplatter/Sources/MPSArgSort.swift b/MetalSplatter/Sources/MPSArgSort.swift new file mode 100644 index 00000000..b35fc819 --- /dev/null +++ b/MetalSplatter/Sources/MPSArgSort.swift @@ -0,0 +1,68 @@ +// Original source: https://gist.github.com/kemchenj/26e1dad40e5b89de2828bad36c81302f +// Assessed Feb 2, 2025. +import MetalPerformanceShaders +import MetalPerformanceShadersGraph + +public class MPSArgSort { + private let dataType: MPSDataType + private let graph: MPSGraph + private let graphExecutable: MPSGraphExecutable + private let inputTensor: MPSGraphTensor + private let outputTensor: MPSGraphTensor + + init(dataType: MPSDataType, descending: Bool = false) { + self.dataType = dataType + + let graph = MPSGraph() + let inputTensor = graph.placeholder(shape: nil, dataType: dataType, name: nil) + let outputTensor = graph.argSort(inputTensor, axis: 0, descending: descending, name: nil) + + self.graph = graph + self.inputTensor = inputTensor + self.outputTensor = outputTensor + self.graphExecutable = autoreleasepool { + let compilationDescriptor = MPSGraphCompilationDescriptor() + compilationDescriptor.waitForCompilationCompletion = true + compilationDescriptor.disableTypeInference() + return graph.compile(with: nil, + feeds: [inputTensor : MPSGraphShapedType(shape: nil, dataType: dataType)], + targetTensors: [outputTensor], + targetOperations: nil, + compilationDescriptor: compilationDescriptor) + } + } + + func callAsFunction( + commandQueue: any MTLCommandQueue, + input: any MTLBuffer, + output: any MTLBuffer, + count: Int + ) { + autoreleasepool { + let commandBuffer = commandQueue.makeCommandBuffer()! + callAsFunction(commandBuffer: commandBuffer, + input: input, + output: output, + count: count) + assert(commandBuffer.error == nil) + assert(commandBuffer.status == .completed) + } + } + + private func callAsFunction( + commandBuffer: any MTLCommandBuffer, + input: any MTLBuffer, + output: any MTLBuffer, + count: Int + ) { + let shape: [NSNumber] = [count as NSNumber] + let inputData = MPSGraphTensorData(input, shape: shape, dataType: dataType) + let outputData = MPSGraphTensorData(output, shape: shape, dataType: .int32) + let executionDescriptor = MPSGraphExecutableExecutionDescriptor() + executionDescriptor.waitUntilCompleted = true + graphExecutable.encode(to: MPSCommandBuffer(commandBuffer: commandBuffer), + inputs: [inputData], + results: [outputData], + executionDescriptor: executionDescriptor) + } +} diff --git a/MetalSplatter/Sources/SplatRenderer.swift b/MetalSplatter/Sources/SplatRenderer.swift index ea586ee7..f8de8ea5 100644 --- a/MetalSplatter/Sources/SplatRenderer.swift +++ b/MetalSplatter/Sources/SplatRenderer.swift @@ -30,7 +30,9 @@ public class SplatRenderer { private static let log = Logger(subsystem: Bundle.module.bundleIdentifier!, category: "SplatRenderer") - + + private var computeDepthsPipelineState: MTLComputePipelineState? + public struct ViewportDescriptor { public var viewport: MTLViewport public var projectionMatrix: simd_float4x4 @@ -570,54 +572,141 @@ public class SplatRenderer { } // Sort splatBuffer (read-only), storing the results in splatBuffer (write-only) then swap splatBuffer and splatBufferPrime - public func resort() { + public func resort(useGPU: Bool = true) { guard !sorting else { return } sorting = true onSortStart?() - let sortStartTime = Date() let splatCount = splatBuffer.count - + let cameraWorldForward = cameraWorldForward let cameraWorldPosition = cameraWorldPosition + +// // For benchmark. +// guard splatCount > 0 else { +// sorting = false +// let elapsed: TimeInterval = 0 +// Self.log.info("Sort time (\(useGPU ? "GPU" : "CPU")): \(elapsed) seconds") +// onSortComplete?(elapsed) +// return +// } + + if useGPU { + Task(priority: .high) { +// let startTime = Date() + + // Allocate a GPU buffer for storing distances. + guard let distanceBuffer = device.makeBuffer( + length: MemoryLayout.size * splatCount, + options: .storageModeShared + ) else { + Self.log.error("Failed to create distance buffer.") + self.sorting = false + return + } - Task(priority: .high) { - defer { - sorting = false - onSortComplete?(-sortStartTime.timeIntervalSinceNow) - } - - if orderAndDepthTempSort.count != splatCount { - orderAndDepthTempSort = Array(repeating: SplatIndexAndDepth(index: .max, depth: 0), count: splatCount) - } + // Compute distances on CPU then copy to distanceBuffer. + let distancePtr = distanceBuffer.contents().bindMemory(to: Float.self, capacity: splatCount) + if Constants.sortByDistance { + for i in 0 ..< splatCount { + let splatPos = splatBuffer.values[i].position.simd + distancePtr[i] = (splatPos - cameraWorldPosition).lengthSquared + } + } else { + for i in 0 ..< splatCount { + let splatPos = splatBuffer.values[i].position.simd + distancePtr[i] = dot(splatPos, cameraWorldForward) + } + } + + + // Allocate a GPU buffer for the ArgSort output indices + guard let indexOutputBuffer = device.makeBuffer( + length: MemoryLayout.size * splatCount, + options: .storageModeShared + ) else { + Self.log.error("Failed to create output indices buffer.") + self.sorting = false + return + } - if Constants.sortByDistance { - for i in 0.. $1.depth } + if Constants.sortByDistance { + for i in 0 ..< splatCount { + orderAndDepthTempSort[i].index = UInt32(i) + let splatPos = splatBuffer.values[i].position.simd + orderAndDepthTempSort[i].depth = (splatPos - cameraWorldPosition).lengthSquared + } + } else { + for i in 0 ..< splatCount { + orderAndDepthTempSort[i].index = UInt32(i) + let splatPos = splatBuffer.values[i].position.simd + orderAndDepthTempSort[i].depth = dot(splatPos, cameraWorldForward) + } + } - do { - try splatBufferPrime.setCapacity(splatCount) - splatBufferPrime.count = 0 - for newIndex in 0.. $1.depth } + + do { + try splatBufferPrime.setCapacity(splatCount) + splatBufferPrime.count = 0 + for newIndex in 0..