Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions MetalSplatter/Sources/MPSArgSort.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Original source: https://gist.github.com/kemchenj/26e1dad40e5b89de2828bad36c81302f
// Assessed Feb 2, 2025.
import MetalPerformanceShaders
import MetalPerformanceShadersGraph

public class MPSArgSort {
private let dataType: MPSDataType
private let graph: MPSGraph
private let graphExecutable: MPSGraphExecutable
private let inputTensor: MPSGraphTensor
private let outputTensor: MPSGraphTensor

init(dataType: MPSDataType, descending: Bool = false) {
self.dataType = dataType

let graph = MPSGraph()
let inputTensor = graph.placeholder(shape: nil, dataType: dataType, name: nil)
let outputTensor = graph.argSort(inputTensor, axis: 0, descending: descending, name: nil)

self.graph = graph
self.inputTensor = inputTensor
self.outputTensor = outputTensor
self.graphExecutable = autoreleasepool {
let compilationDescriptor = MPSGraphCompilationDescriptor()
compilationDescriptor.waitForCompilationCompletion = true
compilationDescriptor.disableTypeInference()
return graph.compile(with: nil,
feeds: [inputTensor : MPSGraphShapedType(shape: nil, dataType: dataType)],
targetTensors: [outputTensor],
targetOperations: nil,
compilationDescriptor: compilationDescriptor)
}
}

func callAsFunction(
commandQueue: any MTLCommandQueue,
input: any MTLBuffer,
output: any MTLBuffer,
count: Int
) {
autoreleasepool {
let commandBuffer = commandQueue.makeCommandBuffer()!
callAsFunction(commandBuffer: commandBuffer,
input: input,
output: output,
count: count)
assert(commandBuffer.error == nil)
assert(commandBuffer.status == .completed)
}
}

private func callAsFunction(
commandBuffer: any MTLCommandBuffer,
input: any MTLBuffer,
output: any MTLBuffer,
count: Int
) {
let shape: [NSNumber] = [count as NSNumber]
let inputData = MPSGraphTensorData(input, shape: shape, dataType: dataType)
let outputData = MPSGraphTensorData(output, shape: shape, dataType: .int32)
let executionDescriptor = MPSGraphExecutableExecutionDescriptor()
executionDescriptor.waitUntilCompleted = true
graphExecutable.encode(to: MPSCommandBuffer(commandBuffer: commandBuffer),
inputs: [inputData],
results: [outputData],
executionDescriptor: executionDescriptor)
}
}
155 changes: 122 additions & 33 deletions MetalSplatter/Sources/SplatRenderer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public class SplatRenderer {
private static let log =
Logger(subsystem: Bundle.module.bundleIdentifier!,
category: "SplatRenderer")


private var computeDepthsPipelineState: MTLComputePipelineState?

public struct ViewportDescriptor {
public var viewport: MTLViewport
public var projectionMatrix: simd_float4x4
Expand Down Expand Up @@ -570,54 +572,141 @@ public class SplatRenderer {
}

// Sort splatBuffer (read-only), storing the results in splatBuffer (write-only) then swap splatBuffer and splatBufferPrime
public func resort() {
public func resort(useGPU: Bool = true) {
guard !sorting else { return }
sorting = true
onSortStart?()
let sortStartTime = Date()

let splatCount = splatBuffer.count

let cameraWorldForward = cameraWorldForward
let cameraWorldPosition = cameraWorldPosition

// // For benchmark.
// guard splatCount > 0 else {
// sorting = false
// let elapsed: TimeInterval = 0
// Self.log.info("Sort time (\(useGPU ? "GPU" : "CPU")): \(elapsed) seconds")
// onSortComplete?(elapsed)
// return
// }

if useGPU {
Task(priority: .high) {
// let startTime = Date()

// Allocate a GPU buffer for storing distances.
guard let distanceBuffer = device.makeBuffer(
length: MemoryLayout<Float>.size * splatCount,
options: .storageModeShared
) else {
Self.log.error("Failed to create distance buffer.")
self.sorting = false
return
}

Task(priority: .high) {
defer {
sorting = false
onSortComplete?(-sortStartTime.timeIntervalSinceNow)
}

if orderAndDepthTempSort.count != splatCount {
orderAndDepthTempSort = Array(repeating: SplatIndexAndDepth(index: .max, depth: 0), count: splatCount)
}
// Compute distances on CPU then copy to distanceBuffer.
let distancePtr = distanceBuffer.contents().bindMemory(to: Float.self, capacity: splatCount)
if Constants.sortByDistance {
for i in 0 ..< splatCount {
let splatPos = splatBuffer.values[i].position.simd
distancePtr[i] = (splatPos - cameraWorldPosition).lengthSquared
}
} else {
for i in 0 ..< splatCount {
let splatPos = splatBuffer.values[i].position.simd
distancePtr[i] = dot(splatPos, cameraWorldForward)
}
}


// Allocate a GPU buffer for the ArgSort output indices
guard let indexOutputBuffer = device.makeBuffer(
length: MemoryLayout<Int32>.size * splatCount,
options: .storageModeShared
) else {
Self.log.error("Failed to create output indices buffer.")
self.sorting = false
return
}

if Constants.sortByDistance {
for i in 0..<splatCount {
orderAndDepthTempSort[i].index = UInt32(i)
let splatPosition = splatBuffer.values[i].position.simd
orderAndDepthTempSort[i].depth = (splatPosition - cameraWorldPosition).lengthSquared
// Create command queue for MPSArgSort.
guard let commandQueue = device.makeCommandQueue() else {
Self.log.error("Failed to create command queue for MPSArgSort.")
self.sorting = false
return
}
} else {
for i in 0..<splatCount {
orderAndDepthTempSort[i].index = UInt32(i)
let splatPosition = splatBuffer.values[i].position.simd
orderAndDepthTempSort[i].depth = dot(splatPosition, cameraWorldForward)

// Run argsort, in decending order.
let argSort = MPSArgSort(dataType: .float32, descending: true)
argSort(commandQueue: commandQueue,
input: distanceBuffer,
output: indexOutputBuffer,
count: splatCount)

// Read back the sorted indices and reorder splats on the CPU.
let sortedIndices = indexOutputBuffer.contents().bindMemory(to: Int32.self, capacity: splatCount)

do {
try self.splatBufferPrime.setCapacity(splatCount)
self.splatBufferPrime.count = 0
for newIndex in 0 ..< splatCount {
let oldIndex = Int(sortedIndices[newIndex])
splatBufferPrime.append(splatBuffer, fromIndex: oldIndex)
}
swap(&splatBuffer, &splatBufferPrime)
} catch {
Self.log.error("Failed to set capacity or reorder: \(error)")
}

// let elapsed = Date().timeIntervalSince(startTime)
// Self.log.info("Sort time (GPU): \(elapsed) seconds")
// self.onSortComplete?(elapsed)
self.sorting = false
}
} else {
Task(priority: .high) {
// let cpuStart = Date()
if orderAndDepthTempSort.count != splatCount {
orderAndDepthTempSort = Array(
repeating: SplatIndexAndDepth(index: .max, depth: 0),
count: splatCount
)
}

orderAndDepthTempSort.sort { $0.depth > $1.depth }
if Constants.sortByDistance {
for i in 0 ..< splatCount {
orderAndDepthTempSort[i].index = UInt32(i)
let splatPos = splatBuffer.values[i].position.simd
orderAndDepthTempSort[i].depth = (splatPos - cameraWorldPosition).lengthSquared
}
} else {
for i in 0 ..< splatCount {
orderAndDepthTempSort[i].index = UInt32(i)
let splatPos = splatBuffer.values[i].position.simd
orderAndDepthTempSort[i].depth = dot(splatPos, cameraWorldForward)
}
}

do {
try splatBufferPrime.setCapacity(splatCount)
splatBufferPrime.count = 0
for newIndex in 0..<orderAndDepthTempSort.count {
let oldIndex = Int(orderAndDepthTempSort[newIndex].index)
splatBufferPrime.append(splatBuffer, fromIndex: oldIndex)
orderAndDepthTempSort.sort { $0.depth > $1.depth }

do {
try splatBufferPrime.setCapacity(splatCount)
splatBufferPrime.count = 0
for newIndex in 0..<orderAndDepthTempSort.count {
let oldIndex = Int(orderAndDepthTempSort[newIndex].index)
splatBufferPrime.append(splatBuffer, fromIndex: oldIndex)
}

swap(&splatBuffer, &splatBufferPrime)
} catch {
Self.log.error("Failed to set capacity or reorder: \(error)")
}

swap(&splatBuffer, &splatBufferPrime)
} catch {
// TODO: report error
// let elapsedCPU = -cpuStart.timeIntervalSinceNow
// Self.log.info("Sort time (CPU): \(elapsedCPU) seconds")
// onSortComplete?(elapsedCPU)
self.sorting = false
}
}
}
Expand Down