From d2006298764c1b88ed95b1e90476158779e25703 Mon Sep 17 00:00:00 2001 From: Szymon Rodziewicz Date: Thu, 22 Jan 2026 18:01:40 +0100 Subject: [PATCH] LLama GPU Implementation --- .gitignore | 4 +- build.sbt | 7 +- .../io/computenode/cyfra/spirv/Context.scala | 11 +- .../io/computenode/cyfra/spirv/Opcodes.scala | 58 ++ .../cyfra/spirv/SpirvConstants.scala | 9 +- .../computenode/cyfra/spirv/SpirvTypes.scala | 42 +- .../cyfra/spirv/compilers/DSLCompiler.scala | 99 +- .../spirv/compilers/ExpressionCompiler.scala | 187 +++- .../cyfra/spirv/compilers/GIOCompiler.scala | 328 +++++-- .../cyfra/spirv/compilers/GSeqCompiler.scala | 5 +- .../compilers/SpirvProgramCompiler.scala | 154 +++- .../io/computenode/cyfra/dsl/Expression.scala | 35 + .../io/computenode/cyfra/dsl/Value.scala | 10 + .../cyfra/dsl/algebra/ScalarAlgebra.scala | 16 +- .../cyfra/dsl/algebra/VectorAlgebra.scala | 4 + .../cyfra/dsl/binding/GShared.scala | 39 + .../cyfra/dsl/binding/ReadShared.scala | 11 + .../cyfra/dsl/binding/WriteShared.scala | 14 + .../cyfra/dsl/collections/GSeq.scala | 15 +- .../io/computenode/cyfra/dsl/gio/GIO.scala | 210 ++++- .../cyfra/dsl/library/Functions.scala | 15 + .../e2e/dsl/WorkgroupPrimitivesE2eTest.scala | 280 ++++++ .../samples/examples/GFunctionExamples.scala | 31 + cyfra-llama/compare_incremental.py | 51 ++ cyfra-llama/compare_logits.py | 70 ++ cyfra-llama/compare_with_llama_cpp.py | 71 ++ .../io/computenode/cyfra/llama/Runner.scala | 281 ++++++ .../cyfra/llama/gguf/Dequantize.scala | 204 +++++ .../cyfra/llama/gguf/GGUFReader.scala | 477 ++++++++++ .../cyfra/llama/inference/CPUInference.scala | 232 +++++ .../llama/inference/LlamaInference.scala | 186 ++++ .../cyfra/llama/model/LlamaConfig.scala | 119 +++ .../cyfra/llama/model/LlamaModel.scala | 89 ++ .../llama/pipeline/LlamaF16Pipeline.scala | 865 ++++++++++++++++++ .../llama/pipeline/LlamaF32Pipeline.scala | 848 +++++++++++++++++ .../cyfra/llama/pipeline/LlamaPipeline.scala | 102 +++ .../llama/programs/f16/F16CopyProgram.scala | 30 + .../programs/f16/F16EmbeddingProgram.scala | 43 + .../f16/F16FusedGateUpSwiGLUProgram.scala | 129 +++ .../f16/F16FusedKVCacheWriteProgram.scala | 105 +++ .../f16/F16FusedQKVMatmulProgram.scala | 149 +++ .../programs/f16/F16FusedRoPEProgram.scala | 124 +++ .../llama/programs/f16/F16KVCacheWriteK.scala | 87 ++ .../llama/programs/f16/F16KVCacheWriteV.scala | 83 ++ .../programs/f16/F16KVCachedAttention.scala | 174 ++++ .../f16/F16MatmulResidualAddProgram.scala | 157 ++++ .../f16/F16MatmulVecHybridProgram.scala | 89 ++ .../programs/f16/F16OutputVec4Program.scala | 80 ++ .../programs/f16/F16RMSNormProgram.scala | 83 ++ .../programs/f16/F16ResidualAddProgram.scala | 33 + .../llama/programs/f16/F16RoPEProgram.scala | 89 ++ .../llama/programs/f16/F16SwiGLUProgram.scala | 39 + .../cyfra/llama/programs/f16/package.scala | 4 + .../llama/programs/f32/EmbeddingProgram.scala | 49 + .../llama/programs/f32/KVCacheWriteK.scala | 87 ++ .../llama/programs/f32/KVCacheWriteV.scala | 83 ++ .../programs/f32/KVCachedAttention.scala | 172 ++++ .../programs/f32/Q4KMatmulVecProgram.scala | 342 +++++++ .../programs/f32/Q6KMatmulVecProgram.scala | 405 ++++++++ .../llama/programs/f32/RMSNormProgram.scala | 88 ++ .../programs/f32/ResidualAddProgram.scala | 67 ++ .../llama/programs/f32/RoPEProgram.scala | 94 ++ .../llama/programs/f32/SwiGLUProgram.scala | 77 ++ .../programs/f32/TiledMatmulVecProgram.scala | 85 ++ .../cyfra/llama/programs/f32/package.scala | 21 + .../cyfra/llama/programs/package.scala | 16 + .../llama/tokenizer/LlamaTokenizer.scala | 
121 +++ .../computenode/cyfra/llama/util/Logger.scala | 12 + .../cyfra/llama/DequantizationTest.scala | 278 ++++++ .../cyfra/llama/DirectBenchmarkTest.scala | 103 +++ .../cyfra/llama/F16KVCacheTest.scala | 111 +++ .../io/computenode/cyfra/llama/GGUFTest.scala | 90 ++ .../cyfra/llama/ShaderDumpTest.scala | 217 +++++ .../cyfra/runtime/VkCyfraRuntime.scala | 8 +- 74 files changed, 9066 insertions(+), 137 deletions(-) create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala create mode 100644 cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala create mode 100644 cyfra-llama/compare_incremental.py create mode 100644 cyfra-llama/compare_logits.py create mode 100644 cyfra-llama/compare_with_llama_cpp.py create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/CPUInference.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF32Pipeline.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16CopyProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedGateUpSwiGLUProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteK.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteV.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCachedAttention.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulResidualAddProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala 
create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RoPEProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/EmbeddingProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteK.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteV.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCachedAttention.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q4KMatmulVecProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q6KMatmulVecProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RMSNormProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/ResidualAddProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RoPEProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/SwiGLUProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/TiledMatmulVecProgram.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/package.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala create mode 100644 cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala create mode 100644 cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala create mode 100644 cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala create mode 100644 cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala create mode 100644 cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala create mode 100644 cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala diff --git a/.gitignore b/.gitignore index 5e834f14..6aed21fe 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,6 @@ metals.sbt smoke julia - +# Model weight files +*.gguf +.lwjgl/ diff --git a/build.sbt b/build.sbt index eaa97261..a8c185c0 100644 --- a/build.sbt +++ b/build.sbt @@ -153,10 +153,15 @@ lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(publish / skip := true) .dependsOn(runtime, fs2interop, foton) +lazy val llama = (project in file("cyfra-llama")) + .settings(commonSettings, runnerSettings) + .settings(publish / skip := true) + .dependsOn(runtime, dsl, core, utility) + lazy val root = (project in file(".")) .settings(name := "Cyfra") .settings(publish / skip := true) - .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, fluids, analytics, utility, spirvTools, vscode) + .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, fluids, analytics, utility, spirvTools, vscode, llama) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala 
b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala index 96490071..bf3235bb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock +import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.{ArrayBufferBlock, SharedBlock} import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -13,15 +13,24 @@ private[cyfra] case class Context( funPointerTypeMap: Map[Int, Int] = Map(), uniformPointerMap: Map[Int, Int] = Map(), inputPointerMap: Map[Int, Int] = Map(), + workgroupPointerMap: Map[Int, Int] = Map(), funcTypeMap: Map[(LightTypeTag, List[LightTypeTag]), Int] = Map(), voidTypeRef: Int = -1, voidFuncTypeRef: Int = -1, workerIndexRef: Int = -1, + localInvocationIndexRef: Int = -1, + localInvocationIdRef: Int = -1, + workgroupIdRef: Int = -1, + numWorkgroupsRef: Int = -1, + subgroupIdRef: Int = -1, + subgroupLocalInvocationIdRef: Int = -1, + subgroupSizeRef: Int = -1, uniformVarRefs: Map[GUniform[?], Int] = Map.empty, bindingToStructType: Map[Int, Int] = Map.empty, constRefs: Map[(Tag[?], Any), Int] = Map(), exprRefs: Map[Int, Int] = Map(), bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(), + sharedVarRefs: Map[Int, SharedBlock] = Map(), nextResultId: Int = HEADER_REFS_TOP, nextBinding: Int = 0, exprNames: Map[Int, String] = Map(), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala index 1f8c4cb6..3e9c9ffd 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala @@ -516,6 +516,20 @@ private[cyfra] object Opcodes: val Reduce = Code("Reduce", 0) val InclusiveScan = Code("InclusiveScan", 1) val ExclusiveScan = Code("ExclusiveScan", 2) + val ClusteredReduce = Code("ClusteredReduce", 3) + + object MemorySemantics: + val None = Code("None", 0x0) + val Acquire = Code("Acquire", 0x2) + val Release = Code("Release", 0x4) + val AcquireRelease = Code("AcquireRelease", 0x8) + val SequentiallyConsistent = Code("SequentiallyConsistent", 0x10) + val UniformMemory = Code("UniformMemory", 0x40) + val SubgroupMemory = Code("SubgroupMemory", 0x80) + val WorkgroupMemory = Code("WorkgroupMemory", 0x100) + val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 0x200) + val AtomicCounterMemory = Code("AtomicCounterMemory", 0x400) + val ImageMemory = Code("ImageMemory", 0x800) object KernelEnqueueFlags: val NoWait = Code("NoWait", 0) @@ -589,6 +603,14 @@ private[cyfra] object Opcodes: val SubgroupDispatch = Code("SubgroupDispatch", 58) val NamedBarrier = Code("NamedBarrier", 59) val PipeStorage = Code("PipeStorage", 60) + val GroupNonUniform = Code("GroupNonUniform", 61) + val GroupNonUniformVote = Code("GroupNonUniformVote", 62) + val GroupNonUniformArithmetic = Code("GroupNonUniformArithmetic", 63) + val GroupNonUniformBallot = Code("GroupNonUniformBallot", 64) + val GroupNonUniformShuffle = Code("GroupNonUniformShuffle", 65) + val GroupNonUniformShuffleRelative = Code("GroupNonUniformShuffleRelative", 66) + 
val GroupNonUniformClustered = Code("GroupNonUniformClustered", 67) + val GroupNonUniformQuad = Code("GroupNonUniformQuad", 68) val SubgroupBallotKHR = Code("SubgroupBallotKHR", 4423) val DrawParameters = Code("DrawParameters", 4427) val SubgroupVoteKHR = Code("SubgroupVoteKHR", 4431) @@ -949,6 +971,42 @@ private[cyfra] object Opcodes: val OpSubgroupImageBlockReadINTEL = Code("OpSubgroupImageBlockReadINTEL", 5577) val OpSubgroupImageBlockWriteINTEL = Code("OpSubgroupImageBlockWriteINTEL", 5578) + // GroupNonUniform operations (Vulkan 1.1+) + val OpGroupNonUniformElect = Code("OpGroupNonUniformElect", 333) + val OpGroupNonUniformAll = Code("OpGroupNonUniformAll", 334) + val OpGroupNonUniformAny = Code("OpGroupNonUniformAny", 335) + val OpGroupNonUniformAllEqual = Code("OpGroupNonUniformAllEqual", 336) + val OpGroupNonUniformBroadcast = Code("OpGroupNonUniformBroadcast", 337) + val OpGroupNonUniformBroadcastFirst = Code("OpGroupNonUniformBroadcastFirst", 338) + val OpGroupNonUniformBallot = Code("OpGroupNonUniformBallot", 339) + val OpGroupNonUniformInverseBallot = Code("OpGroupNonUniformInverseBallot", 340) + val OpGroupNonUniformBallotBitExtract = Code("OpGroupNonUniformBallotBitExtract", 341) + val OpGroupNonUniformBallotBitCount = Code("OpGroupNonUniformBallotBitCount", 342) + val OpGroupNonUniformBallotFindLSB = Code("OpGroupNonUniformBallotFindLSB", 343) + val OpGroupNonUniformBallotFindMSB = Code("OpGroupNonUniformBallotFindMSB", 344) + val OpGroupNonUniformShuffle = Code("OpGroupNonUniformShuffle", 345) + val OpGroupNonUniformShuffleXor = Code("OpGroupNonUniformShuffleXor", 346) + val OpGroupNonUniformShuffleUp = Code("OpGroupNonUniformShuffleUp", 347) + val OpGroupNonUniformShuffleDown = Code("OpGroupNonUniformShuffleDown", 348) + val OpGroupNonUniformIAdd = Code("OpGroupNonUniformIAdd", 349) + val OpGroupNonUniformFAdd = Code("OpGroupNonUniformFAdd", 350) + val OpGroupNonUniformIMul = Code("OpGroupNonUniformIMul", 351) + val OpGroupNonUniformFMul = Code("OpGroupNonUniformFMul", 352) + val OpGroupNonUniformSMin = Code("OpGroupNonUniformSMin", 353) + val OpGroupNonUniformUMin = Code("OpGroupNonUniformUMin", 354) + val OpGroupNonUniformFMin = Code("OpGroupNonUniformFMin", 355) + val OpGroupNonUniformSMax = Code("OpGroupNonUniformSMax", 356) + val OpGroupNonUniformUMax = Code("OpGroupNonUniformUMax", 357) + val OpGroupNonUniformFMax = Code("OpGroupNonUniformFMax", 358) + val OpGroupNonUniformBitwiseAnd = Code("OpGroupNonUniformBitwiseAnd", 359) + val OpGroupNonUniformBitwiseOr = Code("OpGroupNonUniformBitwiseOr", 360) + val OpGroupNonUniformBitwiseXor = Code("OpGroupNonUniformBitwiseXor", 361) + val OpGroupNonUniformLogicalAnd = Code("OpGroupNonUniformLogicalAnd", 362) + val OpGroupNonUniformLogicalOr = Code("OpGroupNonUniformLogicalOr", 363) + val OpGroupNonUniformLogicalXor = Code("OpGroupNonUniformLogicalXor", 364) + val OpGroupNonUniformQuadBroadcast = Code("OpGroupNonUniformQuadBroadcast", 365) + val OpGroupNonUniformQuadSwap = Code("OpGroupNonUniformQuadSwap", 366) + object GlslOp: val Round = Code("Round", 1) val RoundEven = Code("RoundEven", 2) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala index ec3c4d0b..b6a98052 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala @@ -17,5 +17,12 @@ private[cyfra] object SpirvConstants: val 
GL_GLOBAL_INVOCATION_ID_REF = 5 val GL_WORKGROUP_SIZE_REF = 6 val DEBUG_PRINTF_REF = 7 + val GL_LOCAL_INVOCATION_ID_REF = 8 + val GL_LOCAL_INVOCATION_INDEX_REF = 9 + val GL_WORKGROUP_ID_REF = 10 + val GL_NUM_WORKGROUPS_REF = 11 + val GL_SUBGROUP_ID_REF = 12 + val GL_SUBGROUP_LOCAL_INVOCATION_ID_REF = 13 + val GL_SUBGROUP_SIZE_REF = 14 - val HEADER_REFS_TOP = 8 + val HEADER_REFS_TOP = 15 diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala index 7adeb972..c4fac1eb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala @@ -10,6 +10,7 @@ private[cyfra] object SpirvTypes: val Int32Tag = summon[Tag[Int32]] val UInt32Tag = summon[Tag[UInt32]] + val Float16Tag = summon[Tag[Float16]] val Float32Tag = summon[Tag[Float32]] val GBooleanTag = summon[Tag[GBoolean]] val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs @@ -22,6 +23,7 @@ private[cyfra] object SpirvTypes: val LInt32Tag = Int32Tag.tag val LUInt32Tag = UInt32Tag.tag + val LFloat16Tag = Float16Tag.tag val LFloat32Tag = Float32Tag.tag val LGBooleanTag = GBooleanTag.tag val LVec2TagWithoutArgs = Vec2TagWithoutArgs @@ -36,9 +38,38 @@ private[cyfra] object SpirvTypes: type Vec3C[T <: Value] = Vec3[T] type Vec4C[T <: Value] = Vec4[T] + /** Convert Float32 to Float16 (half precision) bits. + * Uses round-to-nearest-even rounding mode. + */ + def floatToFloat16(f: Float): Int = { + val bits = java.lang.Float.floatToIntBits(f) + val sign = (bits >>> 16) & 0x8000 + val exponent = ((bits >>> 23) & 0xFF) - 127 + 15 + val mantissa = bits & 0x007FFFFF + + if (exponent <= 0) { + // Denormalized or zero + if (exponent < -10) { + sign // Zero + } else { + // Denormalized + val m = mantissa | 0x00800000 + val shifted = m >>> (1 - exponent) + sign | (shifted >>> 13) + } + } else if (exponent >= 31) { + // Infinity or NaN + sign | 0x7C00 | (if (mantissa != 0) 0x200 else 0) + } else { + // Normalized + sign | (exponent << 10) | (mantissa >>> 13) + } + } + def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) + case Float16Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(16))) case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) @@ -50,6 +81,7 @@ private[cyfra] object SpirvTypes: def typeStride(tag: LightTypeTag): Int = tag match case LInt32Tag => 4 case LUInt32Tag => 4 + case LFloat16Tag => 2 case LFloat32Tag => 4 case LGBooleanTag => 4 case v if v <:< LVecTag => @@ -63,6 +95,14 @@ private[cyfra] object SpirvTypes: IntWord(value.asInstanceOf[Int]) case t if t == UInt32Tag => IntWord(value.asInstanceOf[Int]) + case t if t == Float16Tag => + val fl = value match + case fl: Float => fl + case dl: Double => dl.toFloat + case il: Int => il.toFloat + // Convert Float32 to Float16 (half precision) + val f16Bits = floatToFloat16(fl) + Word(intToBytes(f16Bits & 0xFFFF).reverse.toArray) case t if t == Float32Tag => val fl = value match case fl: Float => fl @@ -71,7 +111,7 @@ private[cyfra] object SpirvTypes: Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) def 
defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = - val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) + val basicTypes = List(Int32Tag, Float16Tag, Float32Tag, UInt32Tag, GBooleanTag) (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) => val typeDefIndex = ctx.nextResultId val code = List( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala index 8bdafb24..ed4d15d8 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.* import io.computenode.cyfra.dsl.* import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.Value.Scalar -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform} +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GShared, GUniform, ReadShared, WriteBuffer, WriteShared, WriteUniform} import io.computenode.cyfra.dsl.gio.GIO import io.computenode.cyfra.dsl.struct.GStruct.* import io.computenode.cyfra.dsl.struct.GStructSchema @@ -34,7 +34,7 @@ private[cyfra] object DSLCompiler: getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached) case GIO.FlatMap(v, n) :: tail => getAllExprsFlattened(v :: n :: tail, acc, visitDetached) - case GIO.Repeat(n, gio) :: tail => + case GIO.Repeat(n, gio, _) :: tail => val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached) case WriteBuffer(_, index, value) :: tail => @@ -47,6 +47,16 @@ private[cyfra] object DSLCompiler: case GIO.Printf(_, args*) :: tail => val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached) + case GIO.WorkgroupBarrier :: tail => + getAllExprsFlattened(tail, acc, visitDetached) + case WriteShared(_, index, value) :: tail => + val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached) + val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) + getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached) + case GIO.FoldRepeat(n, init, body, _, _) :: tail => + val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) + val initAllExprs = getAllExprsFlattened(init.tree, visitDetached) + getAllExprsFlattened(body :: tail, nAllExprs ::: initAllExprs ::: acc, visitDetached) // TODO: Not traverse same fn scopes for each fn call private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = @@ -71,22 +81,82 @@ private[cyfra] object DSLCompiler: allScopesCache(root.treeid) = result result + private def getAllShared(pending: List[GIO[?]], acc: Map[Int, GShared[?]]): Map[Int, GShared[?]] = + pending match + case Nil => acc + case GIO.FlatMap(v, n) :: tail => + getAllShared(v :: n :: tail, acc) + case GIO.Repeat(_, gio, _) :: tail => + getAllShared(gio :: tail, acc) + case GIO.FoldRepeat(_, _, gio, _, _) :: tail => + getAllShared(gio :: tail, acc) + case WriteShared(buffer, _, _) :: tail => + val impl = buffer.asInstanceOf[GShared.GSharedImpl[?]] + getAllShared(tail, acc + (impl.sharedId -> buffer)) + case _ :: tail => getAllShared(tail, acc) + + private def getAllSharedFromExprs(exprs: List[E[?]], acc: 
Map[Int, GShared[?]]): Map[Int, GShared[?]] = + exprs.foldLeft(acc): + case (a, ReadShared(buffer, _)) => + val impl = buffer.asInstanceOf[GShared.GSharedImpl[?]] + a + (impl.sharedId -> buffer) + case (a, _) => a + + private def createSharedVariables(sharedBuffers: Map[Int, GShared[?]], ctx: Context): (List[Words], Context) = + sharedBuffers.foldLeft((List.empty[Words], ctx)): + case ((insnsAcc, currentCtx), (sharedId, buffer)) => + val elementTypeRef = currentCtx.valueTypeMap(buffer.tag.tag) + val arraySizeConstRef = currentCtx.constRefs.getOrElse( + (Int32Tag, buffer.size), + throw new IllegalStateException(s"Missing constant for shared array size ${buffer.size}"), + ) + + // SPIR-V shared memory structure: + // 1. Array type: OpTypeArray %arrayType %elementType %size + // 2. Pointer to array: OpTypePointer %ptrArrayType Workgroup %arrayType + // 3. Variable: OpVariable %ptrArrayType %var Workgroup + // 4. Pointer to element: OpTypePointer %ptrElemType Workgroup %elementType (for OpAccessChain) + val arrayTypeRef = currentCtx.nextResultId + val ptrArrayTypeRef = currentCtx.nextResultId + 1 + val varRef = currentCtx.nextResultId + 2 + val ptrElemTypeRef = currentCtx.nextResultId + 3 + + val insns = List( + Instruction(Op.OpTypeArray, List(ResultRef(arrayTypeRef), ResultRef(elementTypeRef), ResultRef(arraySizeConstRef))), + Instruction(Op.OpTypePointer, List(ResultRef(ptrArrayTypeRef), StorageClass.Workgroup, ResultRef(arrayTypeRef))), + Instruction(Op.OpVariable, List(ResultRef(ptrArrayTypeRef), ResultRef(varRef), StorageClass.Workgroup)), + Instruction(Op.OpTypePointer, List(ResultRef(ptrElemTypeRef), StorageClass.Workgroup, ResultRef(elementTypeRef))), + ) + + val block = SharedBlock(arrayTypeRef, varRef, ptrElemTypeRef) + val newCtx = currentCtx.copy( + nextResultId = currentCtx.nextResultId + 4, + sharedVarRefs = currentCtx.sharedVarRefs + (sharedId -> block), + workgroupPointerMap = currentCtx.workgroupPointerMap + (elementTypeRef -> ptrElemTypeRef), + ) + (insnsAcc ::: insns, newCtx) + // So far only used for printf private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] = pending match case Nil => acc case GIO.FlatMap(v, n) :: tail => getAllStrings(v :: n :: tail, acc) - case GIO.Repeat(_, gio) :: tail => + case GIO.Repeat(_, gio, _) :: tail => getAllStrings(gio :: tail, acc) case GIO.Printf(format, _*) :: tail => getAllStrings(tail, acc + format) case _ :: tail => getAllStrings(tail, acc) - def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer = + def compile(bodyIo: GIO[?], bindings: List[GBinding[?]], workgroupSize: (Int, Int, Int) = (256, 1, 1)): ByteBuffer = val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true) val typesInCode = allExprs.map(_.tag).distinct - val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct + + val sharedFromGio = getAllShared(List(bodyIo), Map.empty) + val sharedFromExprs = getAllSharedFromExprs(allExprs, sharedFromGio) + val sharedTypes = sharedFromExprs.values.map(_.tag).toList + + val allTypes = (typesInCode ::: bindings.map(_.tag) ::: sharedTypes).distinct def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag) val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) val allStrings = getAllStrings(List(bodyIo), Set.empty) @@ -108,17 +178,28 @@ private[cyfra] object DSLCompiler: val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx) val (uniformStructDecorations, uniformStructInsns, 
uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext) val blockNames = getBlockNames(uniformContext, uniforms) - val (inputDefs, inputContext) = createInvocationId(uniformStructContext) + val (inputDefs, inputContext) = createInvocationId(uniformStructContext, workgroupSize) + + val sharedSizeConsts = sharedFromExprs.values.map(s => (Int32Tag, s.size)).toList val (constDefs, constCtx) = defineConstants(allExprs, inputContext) - val (varDefs, varCtx) = defineVarNames(constCtx) + val (sharedConstDefs, constCtxWithShared) = sharedSizeConsts.foldLeft((List.empty[Words], constCtx)): + case ((insnsAcc, ctx), const) if ctx.constRefs.contains(const) => (insnsAcc, ctx) + case ((insnsAcc, ctx), const) => + val insn = Instruction(Op.OpConstant, List(ResultRef(ctx.valueTypeMap(const._1.tag)), ResultRef(ctx.nextResultId), IntWord(const._2))) + val newCtx = ctx.copy(constRefs = ctx.constRefs + (const -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (insnsAcc :+ insn, newCtx) + + val (sharedDefs, ctxWithShared) = createSharedVariables(sharedFromExprs, constCtxWithShared) + + val (varDefs, varCtx) = defineVarNames(ctxWithShared) val (main, ctxAfterMain) = compileMain(bodyIo, varCtx) val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain) val nameDecorations = getNameDecorations(ctxWithFnDefs) val code: List[Words] = - SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: + SpirvProgramCompiler.headers(workgroupSize) ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: - constDefs ::: varDefs ::: main ::: fnDefs + constDefs ::: sharedConstDefs ::: sharedDefs ::: varDefs ::: main ::: fnDefs val fullCode = code.map: case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala index 6e859bd3..7c6833c9 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala @@ -29,6 +29,33 @@ private[cyfra] object ExpressionCompiler: case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) case _: Mod[?] 
=> (Op.OpSMod, Op.OpFMod) + private def compileSubgroupOp( + expr: E[?], + value: Value.Scalar, + op: SubgroupOp, + spirvOp: Code, + ctx: Context, + ): (List[Instruction], Context) = + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val groupOpCode = op match + case SubgroupOp.Reduce => GroupOperation.Reduce + case SubgroupOp.InclusiveScan => GroupOperation.InclusiveScan + case SubgroupOp.ExclusiveScan => GroupOperation.ExclusiveScan + val instructions = List( + Instruction( + spirvOp, + List( + ResultRef(ctx.valueTypeMap(expr.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + groupOpCode, + ResultRef(ctx.exprRefs(value.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) @@ -52,12 +79,18 @@ private[cyfra] object ExpressionCompiler: val tpe = cexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) val tfOpcode = (cexpr.fromTag, cexpr) match - case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF - case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF - case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS - case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU - case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast - case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast + case (from, _: ToFloat16[?]) if from.tag =:= Float32Tag.tag => Op.OpFConvert + case (from, _: ToFloat16[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF + case (from, _: ToFloat16[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF + case (from, _: ToFloat32[?]) if from.tag =:= Float16Tag.tag => Op.OpFConvert + case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF + case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF + case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS + case (from, _: ToInt32[?]) if from.tag =:= Float16Tag.tag => Op.OpConvertFToS + case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU + case (from, _: ToUInt32[?]) if from.tag =:= Float16Tag.tag => Op.OpConvertFToU + case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast + case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) @@ -109,12 +142,136 @@ private[cyfra] object ExpressionCompiler: case w @ InvocationId => (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef))) + case w @ LocalInvocationIndex => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.localInvocationIndexRef))) + + case w @ LocalInvocationId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.localInvocationIdRef))) + + case w @ WorkgroupId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workgroupIdRef))) + + case w @ NumWorkgroups => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.numWorkgroupsRef))) + + case w @ 
SubgroupId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.subgroupIdRef))) + + case w @ SubgroupLocalInvocationId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.subgroupLocalInvocationIdRef))) + + case w @ SubgroupSize => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.subgroupSizeRef))) + + case sg @ SubgroupAddI(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformIAdd, ctx) + + case sg @ SubgroupAddF(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFAdd, ctx) + + case sg @ SubgroupAddF16(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFAdd, ctx) + + case sg @ SubgroupMinI(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformSMin, ctx) + + case sg @ SubgroupMinF(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMin, ctx) + + case sg @ SubgroupMinF16(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMin, ctx) + + case sg @ SubgroupMaxI(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformSMax, ctx) + + case sg @ SubgroupMaxF(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMax, ctx) + + case sg @ SubgroupMaxF16(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMax, ctx) + + case sg @ SubgroupBroadcast(v, lane) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformBroadcast, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ResultRef(ctx.exprRefs(lane.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + + case sg @ SubgroupBroadcastFirst(v) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformBroadcastFirst, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + + case sg @ SubgroupShuffle(v, lane) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformShuffle, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ResultRef(ctx.exprRefs(lane.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + + case sg @ SubgroupShuffleXor(v, mask) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformShuffleXor, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ResultRef(ctx.exprRefs(mask.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + case d @ ReadUniform(u) => (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u)))) case c: ConvertExpression[?, ?] 
=> compileConvertExpression(c, ctx) + case cvf @ ConvertVec4F16ToF32(v) => + // Convert Vec4[Float16] to Vec4[Float32] using OpFConvert + val vec4F32TypeRef = ctx.valueTypeMap(cvf.tag.tag) + val instructions = List( + Instruction(Op.OpFConvert, List(ResultRef(vec4F32TypeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(v.treeid)))) + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cvf.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + case b: BinaryOpExpression[?] => compileBinaryOpExpression(b, ctx) @@ -306,6 +463,24 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) (instructions, updatedContext) + case ReadShared(buffer, i) => + val sharedId = buffer.asInstanceOf[GShared.GSharedImpl[?]].sharedId + val sharedBlock = ctx.sharedVarRefs(sharedId) + val instructions = List( + Instruction( + Op.OpAccessChain, + List( + ResultRef(sharedBlock.pointerTypeRef), + ResultRef(ctx.nextResultId), + ResultRef(sharedBlock.varRef), + ResultRef(ctx.exprRefs(i.treeid)), + ), + ), + Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) + (instructions, updatedContext) + case when: WhenExpr[?] => compileWhen(when, ctx) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala index 11adc24c..5f8c6af7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala @@ -1,26 +1,28 @@ package io.computenode.cyfra.spirv.compilers +import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.gio.GIO.{CurrentFoldRepeatAcc, CurrentRepeatIndex, FlatMap, FoldRepeat, Printf, Pure, Repeat, WorkgroupBarrier} +import io.computenode.cyfra.dsl.binding.{WriteBuffer, WriteShared} import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.binding.* -import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex import io.computenode.cyfra.spirv.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} -import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} +import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag} + +import scala.collection.mutable object GIOCompiler: def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) = gio match - case GIO.Pure(v) => + case Pure(v) => val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx) (acc ::: insts, updatedCtx) case WriteBuffer(buffer, index, value) => val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) - val insns = List( Instruction( Op.OpAccessChain, @@ -35,78 +37,50 @@ object GIOCompiler: Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), ) val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) - (acc ::: indexInsts ::: 
valueInsts ::: insns, updatedCtx) + // valueInsts before indexInsts: value compiled first, may define exprs index uses + (acc ::: valueInsts ::: indexInsts ::: insns, updatedCtx) - case GIO.FlatMap(v, n) => + case FlatMap(v, n) => val (vInsts, ctxAfterV) = compileGio(v, ctx, acc) compileGio(n, ctxAfterV, vInsts) - case GIO.Repeat(n, f) => - // Compile 'n' first (so we can use its id in the comparison) - val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) - - // Types and constants - val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) - val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) - val zeroId = ctxWithN.constRefs((Int32Tag, 0)) - val oneId = ctxWithN.constRefs((Int32Tag, 1)) - val nId = ctxWithN.exprRefs(n.tree.treeid) - - // Reserve ids for blocks and results - val baseId = ctxWithN.nextResultId - val preHeaderId = baseId - val headerId = baseId + 1 - val bodyId = baseId + 2 - val continueId = baseId + 3 - val mergeId = baseId + 4 - val phiId = baseId + 5 - val cmpId = baseId + 6 - val addId = baseId + 7 - - // Bind CurrentRepeatIndex to the phi result for body compilation - val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId)) - val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation - - // Preheader: close current block and jump to header through a dedicated block - val preheader = List( - Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), - Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), - Instruction(Op.OpBranch, List(ResultRef(headerId))), + case r @ Repeat(n, f, unroll) => + compileRepeat(n, f, unroll, ctx, acc) + + case fr: FoldRepeat[?] => + compileFoldRepeat(fr, ctx, acc) + + case WorkgroupBarrier => + val scopeId = ctx.constRefs((Int32Tag, Scope.Workgroup.opcode)) + val semanticsId = ctx.constRefs((Int32Tag, MemorySemantics.WorkgroupMemory.opcode | MemorySemantics.AcquireRelease.opcode)) + val barrierInsn = Instruction( + Op.OpControlBarrier, + List(ResultRef(scopeId), ResultRef(scopeId), ResultRef(semanticsId)), ) + (acc ::: List(barrierInsn), ctx) - // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch - val header = List( - Instruction(Op.OpLabel, List(ResultRef(headerId))), - // OpPhi must be first in the block + case WriteShared(buffer, index, value) => + val sharedId = buffer.asInstanceOf[io.computenode.cyfra.dsl.binding.GShared.GSharedImpl[?]].sharedId + val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) + val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) + val sharedBlock = ctxWithIndex.sharedVarRefs(sharedId) + val insns = List( Instruction( - Op.OpPhi, - List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), + Op.OpAccessChain, + List( + ResultRef(sharedBlock.pointerTypeRef), + ResultRef(ctxWithIndex.nextResultId), + ResultRef(sharedBlock.varRef), + ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), + ), ), - // cmp = (counter < n) - Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), - // OpLoopMerge must be the second-to-last instruction, before the terminating branch - Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)), - Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), - ) - - val 
bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) - - val contBlk = List( - Instruction(Op.OpLabel, List(ResultRef(continueId))), - Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), - Instruction(Op.OpBranch, List(ResultRef(headerId))), + Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), ) + val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) + // valueInsts before indexInsts: value compiled first, may define exprs index uses + (acc ::: valueInsts ::: indexInsts ::: insns, updatedCtx) - val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) - - // Use the highest nextResultId to avoid ID collisions - val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId - // Use ctxWithN as base to prevent loop-local values from being referenced outside - val finalCtx = ctxWithN.copy(nextResultId = finalNextId) - - (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) - - case GIO.Printf(format, args*) => + case Printf(format, args*) => val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) => val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc) (instsAcc ::: argInsts, cAfterArg) @@ -123,3 +97,223 @@ object GIOCompiler: ) ::: argResults, ) (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1)) + + private def compileRepeat( + n: io.computenode.cyfra.dsl.Value.Int32, + f: GIO[?], + unroll: Boolean, + ctx: Context, + acc: List[Words], + ): (List[Words], Context) = + // TODO: Loop invariant optimization disabled temporarily - causes hangs in some programs + // The optimization extracts loop-invariant expressions outside the loop, but something + // is going wrong with certain expression patterns (like in EncoderProgram). + // For now, just compile n and proceed with the loop body. 
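+  // Sketch of the structured loop this emits (ids elided), matching the instructions built below:
+  //   preheader: OpBranch %preheader ; OpLabel %preheader ; OpBranch %header
+  //   header:    %i   = OpPhi [0, %preheader] [%i+1, %continue]
+  //              %cmp = OpSLessThan %i %n
+  //              OpLoopMerge %merge %continue (Unroll when requested, else None)
+  //              OpBranchConditional %cmp %body %merge
+  //   body:      compiled f, with CurrentRepeatIndex bound to %i ; OpBranch %continue
+  //   continue:  %i+1 = OpIAdd %i 1 ; OpBranch %header
+  //   merge:     OpLabel %merge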
+ val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) + + val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) + val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) + val zeroId = ctxWithN.constRefs((Int32Tag, 0)) + val oneId = ctxWithN.constRefs((Int32Tag, 1)) + val nId = ctxWithN.exprRefs(n.tree.treeid) + + val baseId = ctxWithN.nextResultId + val preHeaderId = baseId + val headerId = baseId + 1 + val bodyId = baseId + 2 + val continueId = baseId + 3 + val mergeId = baseId + 4 + val phiId = baseId + 5 + val cmpId = baseId + 6 + val addId = baseId + 7 + + val bodyCtx = ctxWithN.copy( + nextResultId = baseId + 8, + exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId), + ) + val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) + + val preheader = List( + Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), + Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val header = List( + Instruction(Op.OpLabel, List(ResultRef(headerId))), + Instruction( + Op.OpPhi, + List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), + ), + Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), + Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), + if unroll then LoopControlMask.Unroll else LoopControlMask.MaskNone)), + Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), + ) + + val bodyBlk = + List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: + bodyInsts ::: + List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) + + val contBlk = List( + Instruction(Op.OpLabel, List(ResultRef(continueId))), + Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) + + val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) + val finalCtx = ctxWithN.copy(nextResultId = finalNextId) + + (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) + + /** Compiles foldRepeat - a loop with an accumulator that can contain barriers. 
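+    * The loop counter and the running accumulator are each carried by an OpPhi in the loop header;
+    * while compiling the body, CurrentRepeatIndex is bound to the counter phi and the accumulator
+    * phantom id (accTreeId) to the accumulator phi, so references to either inside the body resolve
+    * to the phi results. After the loop, the fold result is the accumulator phi value.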
*/ + private def compileFoldRepeat( + fr: FoldRepeat[?], + ctx: Context, + acc: List[Words], + ): (List[Words], Context) = + val n = fr.n + val init = fr.init + val body = fr.body + val accTreeId = fr.accTreeId + + // Compile n and init + val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) + val (initInsts, ctxWithInit) = ExpressionCompiler.compileBlock(init.tree, ctxWithN) + + val intTy = ctxWithInit.valueTypeMap(Int32Tag.tag) + val accTy = ctxWithInit.valueTypeMap(init.tree.tag.tag) + val boolTy = ctxWithInit.valueTypeMap(GBooleanTag.tag) + val zeroId = ctxWithInit.constRefs((Int32Tag, 0)) + val oneId = ctxWithInit.constRefs((Int32Tag, 1)) + val nId = ctxWithInit.exprRefs(n.tree.treeid) + val initId = ctxWithInit.exprRefs(init.tree.treeid) + + val baseId = ctxWithInit.nextResultId + val preHeaderId = baseId + val headerId = baseId + 1 + val bodyId = baseId + 2 + val continueId = baseId + 3 + val mergeId = baseId + 4 + val iterPhiId = baseId + 5 // loop counter phi + val accPhiId = baseId + 6 // accumulator phi + val cmpId = baseId + 7 + val addId = baseId + 8 + + // Setup context for body compilation with both loop counter and accumulator + val bodyCtx = ctxWithInit.copy( + nextResultId = baseId + 9, + exprRefs = ctxWithInit.exprRefs + + (CurrentRepeatIndex.treeid -> iterPhiId) + + (accTreeId -> accPhiId), + ) + + val (bodyInsts, ctxAfterBody) = compileGio(body, bodyCtx) + val bodyResultId = ctxAfterBody.exprRefs(body.underlying.tree.treeid) + + val preheader = List( + Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), + Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val header = List( + Instruction(Op.OpLabel, List(ResultRef(headerId))), + // Phi for loop counter + Instruction( + Op.OpPhi, + List(ResultRef(intTy), ResultRef(iterPhiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), + ), + // Phi for accumulator + Instruction( + Op.OpPhi, + List(ResultRef(accTy), ResultRef(accPhiId), ResultRef(initId), ResultRef(preHeaderId), ResultRef(bodyResultId), ResultRef(continueId)), + ), + Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(iterPhiId), ResultRef(nId))), + Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), + if fr.unroll then LoopControlMask.Unroll else LoopControlMask.MaskNone)), + Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), + ) + + val bodyBlk = + List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: + bodyInsts ::: + List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) + + val contBlk = List( + Instruction(Op.OpLabel, List(ResultRef(continueId))), + Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(iterPhiId), ResultRef(oneId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) + + val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) + // The result of foldRepeat is the final accumulator value (accPhiId after merge) + // We need to map both: + // 1. The accumulator phantom treeid (for expressions that reference the accumulator) + // 2. 
The body result treeid (for the FlatMap chain to work correctly) + val finalCtx = ctxAfterBody.copy( + nextResultId = finalNextId, + exprRefs = ctxAfterBody.exprRefs + + (accTreeId -> accPhiId) + + (body.underlying.tree.treeid -> accPhiId), + ) + + (acc ::: nInsts ::: initInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) + + /** Finds the CurrentFoldRepeatAcc phantom expression in a GIO tree. */ + private def findFoldRepeatAcc(gio: GIO[?]): Option[CurrentFoldRepeatAcc[?]] = + def findInExpr(expr: E[?]): Option[CurrentFoldRepeatAcc[?]] = + expr match + case acc: CurrentFoldRepeatAcc[?] => Some(acc) + case _ => expr.exprDependencies.flatMap(findInExpr).headOption + + def findInGio(g: GIO[?]): Option[CurrentFoldRepeatAcc[?]] = g match + case Pure(v) => findInExpr(v.tree) + case FlatMap(v, n) => findInGio(v).orElse(findInGio(n)) + case Repeat(n, body, _) => findInExpr(n.tree).orElse(findInGio(body)) + case FoldRepeat(n, init, b, _, _) => findInExpr(n.tree).orElse(findInExpr(init.tree)).orElse(findInGio(b)) + case WriteBuffer(_, i, v) => findInExpr(i.tree).orElse(findInExpr(v.tree)) + case WriteShared(_, i, v) => findInExpr(i.tree).orElse(findInExpr(v.tree)) + case Printf(_, args*) => args.flatMap(a => findInExpr(a.tree)).headOption + case WorkgroupBarrier => None + + findInGio(gio) + + private def collectExpressionsMap(gio: GIO[?]): Map[Int, E[?]] = + val result = mutable.Map[Int, E[?]]() + + def collectFromExpr(expr: E[?]): Unit = + if !result.contains(expr.treeid) then + result += (expr.treeid -> expr) + expr.exprDependencies.foreach(collectFromExpr) + + def collectFromGio(g: GIO[?]): Unit = g match + case Pure(v) => collectFromExpr(v.tree) + case FlatMap(v, n) => collectFromGio(v); collectFromGio(n) + case Repeat(n, body, _) => collectFromExpr(n.tree); collectFromGio(body) + case FoldRepeat(n, init, body, _, _) => collectFromExpr(n.tree); collectFromExpr(init.tree); collectFromGio(body) + case WriteBuffer(_, i, v) => collectFromExpr(i.tree); collectFromExpr(v.tree) + case WriteShared(_, i, v) => collectFromExpr(i.tree); collectFromExpr(v.tree) + case Printf(_, args*) => args.foreach(a => collectFromExpr(a.tree)) + case WorkgroupBarrier => () // No expressions to collect + + collectFromGio(gio) + result.toMap + + private def findLoopDependentExprs(exprsMap: Map[Int, E[?]], loopVarId: Int): Set[Int] = + val dependent = mutable.Set[Int](loopVarId) + var changed = true + while changed do + changed = false + exprsMap.values.foreach: expr => + if !dependent.contains(expr.treeid) then + // Check if any dependency's treeid is in dependent set + if expr.exprDependencies.exists(dep => dependent.contains(dep.treeid)) then + dependent += expr.treeid + changed = true + dependent.toSet diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala index e635c4c5..57299b1f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala @@ -68,7 +68,7 @@ private[cyfra] object GSeqCompiler: ::: List( // acc = nextAcc Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(reduceCtx.exprRefs(foldFnExpr.treeid)))), ) - (instructions, ctx.joinNested(reduceCtx)) + (instructions, context.joinNested(reduceCtx)) case (op, dExpr) :: tail => op match @@ -176,7 +176,8 @@ private[cyfra] object GSeqCompiler: ), Instruction(Op.OpBranch, 
List(ResultRef(loopBack))), Instruction(Op.OpLabel, List(ResultRef(loopBack))), - Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), LoopControlMask.MaskNone)), + Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), + if fold.unroll then LoopControlMask.Unroll else LoopControlMask.MaskNone)), Instruction(Op.OpBranch, List(ResultRef(postLoopMergeLabel))), Instruction(Op.OpLabel, List(ResultRef(postLoopMergeLabel))), Instruction(Op.OpLoad, List(ResultRef(boolType), ResultRef(shouldTakeInCheck), ResultRef(shouldTakeVar))), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala index bd4e469c..4e32c1d5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala @@ -22,31 +22,72 @@ private[cyfra] object SpirvProgramCompiler: case _ => false def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) = + val int32TypeRef = ctx.valueTypeMap(Int32Tag.tag) + val vec3Int32TypeRef = ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag) + val int32PtrInputRef = ctx.inputPointerMap(int32TypeRef) + val vec3Int32PtrInputRef = ctx.inputPointerMap(vec3Int32TypeRef) + val zeroConstRef = ctx.constRefs(Int32Tag, 0) val init = List( Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), Instruction(Op.OpLabel, List(ResultRef(ctx.nextResultId))), ) - val initWorkerIndex = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(Int32Tag.tag))), - ResultRef(ctx.nextResultId + 1), - ResultRef(GL_GLOBAL_INVOCATION_ID_REF), - ResultRef(ctx.constRefs(Int32Tag, 0)), - ), - ), - Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))), + var nextId = ctx.nextResultId + 1 + + def loadScalarFromVec3(varRef: Int): (List[Words], Int) = + val ptrId = nextId + val loadId = nextId + 1 + nextId += 2 + val insns = List( + Instruction(Op.OpAccessChain, List(ResultRef(int32PtrInputRef), ResultRef(ptrId), ResultRef(varRef), ResultRef(zeroConstRef))), + Instruction(Op.OpLoad, List(ResultRef(int32TypeRef), ResultRef(loadId), ResultRef(ptrId))), + ) + (insns, loadId) + + def loadVec3(varRef: Int): (List[Words], Int) = + val loadId = nextId + nextId += 1 + val insns = List(Instruction(Op.OpLoad, List(ResultRef(vec3Int32TypeRef), ResultRef(loadId), ResultRef(varRef)))) + (insns, loadId) + + def loadScalar(varRef: Int): (List[Words], Int) = + val loadId = nextId + nextId += 1 + val insns = List(Instruction(Op.OpLoad, List(ResultRef(int32TypeRef), ResultRef(loadId), ResultRef(varRef)))) + (insns, loadId) + + val (globalInvocInsns, globalInvocId) = loadScalarFromVec3(GL_GLOBAL_INVOCATION_ID_REF) + val (localIdInsns, localIdRef) = loadVec3(GL_LOCAL_INVOCATION_ID_REF) + val (localIndexInsns, localIndexRef) = loadScalar(GL_LOCAL_INVOCATION_INDEX_REF) + val (workgroupIdInsns, workgroupIdLoadRef) = loadVec3(GL_WORKGROUP_ID_REF) + val (numWorkgroupsInsns, numWorkgroupsLoadRef) = loadVec3(GL_NUM_WORKGROUPS_REF) + val (subgroupIdInsns, subgroupIdLoadRef) = loadScalar(GL_SUBGROUP_ID_REF) + val (subgroupLocalIdInsns, subgroupLocalIdLoadRef) = 
loadScalar(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF) + val (subgroupSizeInsns, subgroupSizeLoadRef) = loadScalar(GL_SUBGROUP_SIZE_REF) + + val loadInsns = globalInvocInsns ::: localIdInsns ::: localIndexInsns ::: + workgroupIdInsns ::: numWorkgroupsInsns ::: + subgroupIdInsns ::: subgroupLocalIdInsns ::: subgroupSizeInsns + + val bodyCtx = ctx.copy( + nextResultId = nextId, + workerIndexRef = globalInvocId, + localInvocationIdRef = localIdRef, + localInvocationIndexRef = localIndexRef, + workgroupIdRef = workgroupIdLoadRef, + numWorkgroupsRef = numWorkgroupsLoadRef, + subgroupIdRef = subgroupIdLoadRef, + subgroupLocalInvocationIdRef = subgroupLocalIdLoadRef, + subgroupSizeRef = subgroupSizeLoadRef, ) - val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) + val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, bodyCtx) val (vars, nonVarsBody) = bubbleUpVars(body) val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List())) - (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) + (init ::: vars ::: loadInsns ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) def getNameDecorations(ctx: Context): List[Instruction] = val funNames = ctx.functions.map { case (id, fn) => @@ -65,25 +106,58 @@ private[cyfra] object SpirvProgramCompiler: binding: Int, ) - val headers: List[Words] = + case class SharedBlock( + arrayTypeRef: Int, + varRef: Int, + pointerTypeRef: Int, + ) + + def headers(workgroupSize: (Int, Int, Int)): List[Words] = + val (localSizeX, localSizeY, localSizeZ) = workgroupSize Word(Array(0x03, 0x02, 0x23, 0x07)) :: // SPIR-V Word(Array(0x00, 0x00, 0x01, 0x00)) :: // Version: 0.1.0 Word(Array(cyfraVendorId, 0x00, 0x01, 0x00)) :: // Generator: cyfra; 1 WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0 - Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader - Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info" - Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450" - Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf" - Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450 - Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF - Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1 - Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: // OpSource GLSL 450 + Instruction(Op.OpCapability, List(Capability.Shader)) :: + Instruction(Op.OpCapability, List(Capability.GroupNonUniform)) :: + Instruction(Op.OpCapability, List(Capability.GroupNonUniformArithmetic)) :: + Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: + Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: + Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: + 
Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: + Instruction( + Op.OpEntryPoint, + List( + ExecutionModel.GLCompute, + ResultRef(MAIN_FUNC_REF), + Text("main"), + ResultRef(GL_GLOBAL_INVOCATION_ID_REF), + ResultRef(GL_LOCAL_INVOCATION_ID_REF), + ResultRef(GL_LOCAL_INVOCATION_INDEX_REF), + ResultRef(GL_WORKGROUP_ID_REF), + ResultRef(GL_NUM_WORKGROUPS_REF), + ResultRef(GL_SUBGROUP_ID_REF), + ResultRef(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF), + ResultRef(GL_SUBGROUP_SIZE_REF), + ), + ) :: + Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(localSizeX), IntWord(localSizeY), IntWord(localSizeZ))) :: + Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: Nil val workgroupDecorations: List[Words] = - Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId - Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil + List( + Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)), + Instruction(Op.OpDecorate, List(ResultRef(GL_LOCAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.LocalInvocationId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_LOCAL_INVOCATION_INDEX_REF), Decoration.BuiltIn, BuiltIn.LocalInvocationIndex)), + Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_ID_REF), Decoration.BuiltIn, BuiltIn.WorkgroupId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_NUM_WORKGROUPS_REF), Decoration.BuiltIn, BuiltIn.NumWorkgroups)), + Instruction(Op.OpDecorate, List(ResultRef(GL_SUBGROUP_ID_REF), Decoration.BuiltIn, BuiltIn.SubgroupId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.SubgroupLocalInvocationId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_SUBGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.SubgroupSize)), + ) def defineVoids(context: Context): (List[Words], Context) = val voidDef = List[Words]( @@ -93,7 +167,8 @@ private[cyfra] object SpirvProgramCompiler: val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) (voidDef, ctxWithVoid) - def createInvocationId(context: Context): (List[Words], Context) = + def createInvocationId(context: Context, workgroupSize: (Int, Int, Int)): (List[Words], Context) = + val (localSizeX, localSizeY, localSizeZ) = workgroupSize val definitionInstructions = List( Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), @@ -239,7 +314,14 @@ private[cyfra] object SpirvProgramCompiler: } } - val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) + val predefinedConsts = List( + (Int32Tag, 0), + (UInt32Tag, 0), + (Int32Tag, 1), + (Int32Tag, Scope.Workgroup.opcode), + (Int32Tag, Scope.Subgroup.opcode), + (Int32Tag, MemorySemantics.WorkgroupMemory.opcode | MemorySemantics.AcquireRelease.opcode), + ) def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = // Collect field indices from GetField expressions val fieldIndices 
= exprs.collect { case gf: GetField[?, ?] => @@ -269,16 +351,18 @@ private[cyfra] object SpirvProgramCompiler: ) def defineVarNames(ctx: Context): (List[Words], Context) = + val vec3Int32PtrId = ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag)) + val int32PtrInputId = ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Int32]].tag)) ( List( - Instruction( - Op.OpVariable, - List( - ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag))), - ResultRef(GL_GLOBAL_INVOCATION_ID_REF), - StorageClass.Input, - ), - ), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_GLOBAL_INVOCATION_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_LOCAL_INVOCATION_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_LOCAL_INVOCATION_INDEX_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_WORKGROUP_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_NUM_WORKGROUPS_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_SUBGROUP_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_SUBGROUP_SIZE_REF), StorageClass.Input)), ), - ctx.copy(), + ctx, ) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala index 7d52eb5e..283fd465 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala @@ -87,15 +87,20 @@ object Expression: sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]: def fromTag: Tag[F] = summon[Tag[F]] def a: F + case class ToFloat16[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float16] case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32] case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32] case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32] + /** Convert Vec4[Float16] to Vec4[Float32] using OpFConvert. 
*/ + case class ConvertVec4F16ToF32(a: Vec4[Float16]) extends Expression[Vec4[Float32]] + sealed trait Const[T <: Scalar: Tag] extends Expression[T]: def value: Any object Const: def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value) + case class ConstFloat16(value: Float) extends Const[Float16] case class ConstFloat32(value: Float) extends Const[Float32] case class ConstInt32(value: Int) extends Const[Int32] case class ConstUInt32(value: Int) extends Const[UInt32] @@ -115,3 +120,33 @@ object Expression: case object WorkerIndex extends E[Int32] case class Binding[T <: Value: Tag](binding: Int) extends E[T] + + // Workgroup built-ins + case object LocalInvocationIndex extends E[Int32] + case object LocalInvocationId extends E[Vec3[Int32]] + case object WorkgroupId extends E[Vec3[Int32]] + case object NumWorkgroups extends E[Vec3[Int32]] + case object SubgroupId extends E[Int32] + case object SubgroupLocalInvocationId extends E[Int32] + case object SubgroupSize extends E[Int32] + + // Subgroup operations + sealed trait SubgroupOp + object SubgroupOp: + case object Reduce extends SubgroupOp + case object InclusiveScan extends SubgroupOp + case object ExclusiveScan extends SubgroupOp + + case class SubgroupAddI(value: Int32, op: SubgroupOp) extends E[Int32] + case class SubgroupAddF16(value: Float16, op: SubgroupOp) extends E[Float16] + case class SubgroupAddF(value: Float32, op: SubgroupOp) extends E[Float32] + case class SubgroupMinI(value: Int32, op: SubgroupOp) extends E[Int32] + case class SubgroupMinF16(value: Float16, op: SubgroupOp) extends E[Float16] + case class SubgroupMinF(value: Float32, op: SubgroupOp) extends E[Float32] + case class SubgroupMaxI(value: Int32, op: SubgroupOp) extends E[Int32] + case class SubgroupMaxF16(value: Float16, op: SubgroupOp) extends E[Float16] + case class SubgroupMaxF(value: Float32, op: SubgroupOp) extends E[Float32] + case class SubgroupBroadcast[T <: Value.Scalar: Tag](value: T, lane: Int32) extends E[T] + case class SubgroupBroadcastFirst[T <: Value.Scalar: Tag](value: T) extends E[T] + case class SubgroupShuffle[T <: Value.Scalar: Tag](value: T, lane: Int32) extends E[T] + case class SubgroupShuffleXor[T <: Value.Scalar: Tag](value: T, mask: Int32) extends E[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala index 1e8a0e92..de4bd094 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala @@ -24,6 +24,16 @@ object Value: sealed trait Scalar extends Value trait FloatType extends Scalar + + /** 16-bit floating point (half precision) - supported in Vulkan for memory bandwidth savings */ + case class Float16(tree: E[Float16])(using val source: Source) extends FloatType + given FromExpr[Float16] with + def fromExpr(f: E[Float16])(using Source) = Float16(f) + + /** Factory method for creating Float16 constants */ + object Float16: + def apply(value: Float)(using Source): Float16 = Float16(Expression.ConstFloat16(value)) + case class Float32(tree: E[Float32])(using val source: Source) extends FloatType given FromExpr[Float32] with def fromExpr(f: E[Float32])(using Source) = Float32(f) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala index 475b936e..31403fe3 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala 
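Taken together, the ConstFloat16/ToFloat16 expressions above, the Float16 value type, and the conversion extensions added to ScalarAlgebra below give the DSL a half-precision scalar that mirrors Float32. A minimal usage sketch, assuming Float16 is exported at the dsl package level like the other Value types (the names are illustrative):

  import io.computenode.cyfra.dsl.{*, given}

  val bias: Float16 = Float16(0.5f)                      // ConstFloat16 literal
  def toHalf(x: Float32): Float16 = x.asFloat16 + bias   // ToFloat16, then a half-precision add
  def toFull(h: Float16): Float32 = h.asFloat32          // ToFloat32 back to full precision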
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.dsl.algebra -import io.computenode.cyfra.dsl.Expression.ConstFloat32 +import io.computenode.cyfra.dsl.Expression.{ConstFloat16, ConstFloat32} import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.library.Functions.abs @@ -22,6 +22,7 @@ object ScalarAlgebra: trait BasicScalarIntAlgebra[T <: Scalar: {FromExpr, Tag}] extends BasicScalarAlgebra[T] with BitwiseOperable[T] + given BasicScalarAlgebra[Float16] = new BasicScalarAlgebra[Float16] {} given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} @@ -92,16 +93,27 @@ object ScalarAlgebra: given Epsilon = Epsilon(0.00001f) + extension (f16: Float16) + inline def asFloat32(using Source): Float32 = Float32(ToFloat32(f16)) + inline def asInt(using Source): Int32 = f16.asFloat32.asInt + + extension (f32: Float32) + inline def asFloat16(using Source): Float16 = Float16(ToFloat16(f32)) + extension (f32: Float32) + /** Convert Float32 to Float16 constant for DSL usage */ + inline def toF16(using Source): Float16 = Float16(ToFloat16(f32)) inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = abs(f32 - other) < epsilon.eps extension (i32: Int32) + inline def asFloat16(using Source): Float16 = Float16(ToFloat16(i32)) inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) - + extension (u32: UInt32) + inline def asFloat16(using Source): Float16 = Float16(ToFloat16(u32)) inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) inline def signed(using Source): Int32 = Int32(ToInt32(u32)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala index 7908f63f..eee2b0e0 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala @@ -119,6 +119,10 @@ object VectorAlgebra: inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) inline def rgb(using Source): Vec3[T] = xyz + /** Convert Vec4[Float16] to Vec4[Float32] for higher precision operations. 
*/ + extension (v4f16: Vec4[Float16]) + inline def asVec4F32(using Source): Vec4[Float32] = Vec4(ConvertVec4F16ToF32(v4f16)) + given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) => Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y)))) } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala new file mode 100644 index 00000000..3d4c55d5 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala @@ -0,0 +1,39 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import izumi.reflect.Tag + +/** + * Represents a workgroup-local shared memory array. + * + * Shared memory is visible to all invocations within a workgroup and can be used + * for efficient inter-thread communication within a workgroup after synchronization + * with [[GIO.barrier]]. + * + * @tparam T Element type of the shared memory array + */ +trait GShared[T <: Value: {FromExpr, Tag}]: + def tag: Tag[T] = summon[Tag[T]] + def size: Int + + /** Read a value from shared memory at the given index. */ + def read(index: Int32): T = fromExpr(ReadShared(this, index)) + + /** Write a value to shared memory at the given index. */ + def write(index: Int32, value: T): GIO[Empty] = WriteShared(this, index, value) + +object GShared: + private var nextId = 0 + + /** Create a shared memory array with the given size. */ + def apply[T <: Value: {FromExpr, Tag}](size: Int): GShared[T] = + val id = nextId + nextId += 1 + new GSharedImpl[T](id, size) + + private[cyfra] class GSharedImpl[T <: Value: {FromExpr, Tag}](val sharedId: Int, val size: Int) extends GShared[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala new file mode 100644 index 00000000..8da5c592 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala @@ -0,0 +1,11 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Expression +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import izumi.reflect.Tag + +case class ReadShared[T <: Value: {Tag, FromExpr}]( + buffer: GShared[T], + index: Int32, +) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala new file mode 100644 index 00000000..a05cb5fb --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import izumi.reflect.Tag + +case class WriteShared[T <: Value: {Tag, FromExpr}]( + buffer: GShared[T], + index: Int32, + value: T, +) extends GIO[Empty]: + override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala 
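A short usage sketch of the new shared-memory binding, mirroring the "shared memory allows workgroup communication" e2e test further down: each invocation stages a value, the workgroup synchronizes, and every invocation then reads another lane's slot (the output buffer and the 256-slot size are illustrative):

  import io.computenode.cyfra.dsl.{*, given}
  import io.computenode.cyfra.dsl.binding.{GBuffer, GShared}
  import io.computenode.cyfra.dsl.gio.GIO

  val staging = GShared[Int32](256)                 // one slot per invocation in the workgroup

  def reverseWithinWorkgroup(output: GBuffer[Int32]) =
    val local = GIO.localInvocationIndex
    val global = GIO.invocationId
    staging.write(local, global)                    // stage this lane's value
      .flatMap(_ => GIO.barrier)                    // make the writes visible workgroup-wide
      .flatMap(_ => GIO.write(output, global, staging.read((255: Int32) - local)))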
b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala index b4265a1b..03eb23d2 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala @@ -17,6 +17,7 @@ class GSeq[T <: Value: {Tag, FromExpr}]( val name: Source, val currentElemExprTreeId: Int = treeidState.getAndIncrement(), val aggregateElemExprTreeId: Int = treeidState.getAndIncrement(), + val shouldUnroll: Boolean = false, ): def copyWithDynamicTrees[R <: Value: {Tag, FromExpr}]( @@ -24,7 +25,8 @@ class GSeq[T <: Value: {Tag, FromExpr}]( limit: Option[Int] = limit, currentElemExprTreeId: Int = currentElemExprTreeId, aggregateElemExprTreeId: Int = aggregateElemExprTreeId, - ) = GSeq[R](uninitSource, elemOps, limit, name, currentElemExprTreeId, aggregateElemExprTreeId) + shouldUnroll: Boolean = shouldUnroll, + ) = GSeq[R](uninitSource, elemOps, limit, name, currentElemExprTreeId, aggregateElemExprTreeId, shouldUnroll) private val currentElemExpr = CurrentElem[T](currentElemExprTreeId) val source = uninitSource(currentElemExpr) @@ -43,8 +45,15 @@ class GSeq[T <: Value: {Tag, FromExpr}]( def limit(n: Int): GSeq[T] = this.copyWithDynamicTrees(limit = Some(n)) + /** Mark this sequence for loop unrolling in the generated shader. + * This generates [[unroll]] pragma in GLSL, which hints the compiler + * to fully unroll the loop for better performance on small fixed-size loops. + */ + def unroll: GSeq[T] = + this.copyWithDynamicTrees(shouldUnroll = true) + def fold[R <: Value: {Tag, FromExpr}](zero: R, fn: (R, T) => R): R = - summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this)) + summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this, shouldUnroll)) def count: Int32 = fold(0, (acc: Int32, _: T) => acc + 1) @@ -90,7 +99,7 @@ object GSeq: sealed trait GSeqSource[T <: Value: Tag] case class GSeqStream[T <: Value: Tag](init: T, next: Expression[?]) extends GSeqSource[T] - case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T]) extends Expression[R]: + case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T], unroll: Boolean = false) extends Expression[R]: val zeroExpr = zero.tree val fnExpr = fn val streamInitExpr = seq.source.init.tree diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala index 09373068..15d10027 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala @@ -1,7 +1,8 @@ package io.computenode.cyfra.dsl.gio import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.Expression.{CustomTreeId, PhantomExpression, treeidState, *, given} +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32, UInt32, Float16, Float32, Vec3, Vec4} import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer, WriteBuffer} import io.computenode.cyfra.dsl.collections.GSeq @@ -10,6 +11,10 @@ import io.computenode.cyfra.dsl.struct.GStruct.Empty import io.computenode.cyfra.dsl.control.When import izumi.reflect.Tag +/** + * GPU I/O monad for representing side-effectful GPU operations. + * Supports buffer reads/writes, synchronization barriers, and workgroup-level operations. 
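The new unroll flag threads through FoldSeq into the GSeqCompiler change above, which switches the loop's OpLoopMerge operand from LoopControlMask.MaskNone to LoopControlMask.Unroll. A usage sketch for a small fixed-length reduction (the buffer name is illustrative):

  import io.computenode.cyfra.dsl.{*, given}
  import io.computenode.cyfra.dsl.binding.GBuffer
  import io.computenode.cyfra.dsl.collections.GSeq
  import io.computenode.cyfra.dsl.gio.GIO

  // Sum 8 consecutive elements per invocation; .unroll hints the driver to fully
  // unroll this small, fixed-length loop.
  def sum8(input: GBuffer[Float32], base: Int32): Float32 =
    GSeq
      .gen[Int32](base, _ + 1)
      .limit(8)
      .unroll
      .fold(0.0f, (acc: Float32, i: Int32) => acc + GIO.read[Float32](input, i))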
+ */ trait GIO[T <: Value]: def flatMap[U <: Value](f: T => GIO[U]): GIO[U] = FlatMap(this, f(this.underlying)) @@ -26,13 +31,33 @@ object GIO: case class FlatMap[T <: Value, U <: Value](gio: GIO[T], next: GIO[U]) extends GIO[U]: override def underlying: U = next.underlying - // TODO repeat that collects results - case class Repeat(n: Int32, f: GIO[?]) extends GIO[Empty]: + /** Loop that repeats n times without accumulator. + * + * @param n Number of iterations + * @param f Body GIO to execute + * @param unroll Whether to hint the GPU compiler to unroll this loop + */ + case class Repeat(n: Int32, f: GIO[?], unroll: Boolean = false) extends GIO[Empty]: override def underlying: Empty = Empty() + /** Folding repeat with accumulator - enables accumulation across iterations with barriers. + * + * @param n Number of iterations + * @param init Initial accumulator value + * @param body Body GIO that returns new accumulator + * @param accTreeId Treeid of the CurrentFoldRepeatAcc phantom for binding + * @param unroll Whether to hint the GPU compiler to unroll this loop + */ + case class FoldRepeat[A <: Value](n: Int32, init: A, body: GIO[A], accTreeId: Int, unroll: Boolean = false) extends GIO[A]: + override def underlying: A = body.underlying + case class Printf(format: String, args: Value*) extends GIO[Empty]: override def underlying: Empty = Empty() + /** Memory and execution barrier for workgroup synchronization. */ + case object WorkgroupBarrier extends GIO[Empty]: + override def underlying: Empty = Empty() + def pure[T <: Value](value: T): GIO[T] = Pure(value) def value[T <: Value](value: T): GIO[T] = Pure(value) @@ -40,8 +65,48 @@ object GIO: case object CurrentRepeatIndex extends PhantomExpression[Int32] with CustomTreeId: override val treeid: Int = treeidState.getAndIncrement() + /** Phantom expression for the current accumulator value in foldRepeat. */ + case class CurrentFoldRepeatAcc[A <: Value: Tag](init: A, tid: Int) extends PhantomExpression[A] with CustomTreeId: + override val treeid: Int = tid + def repeat(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] = - Repeat(n, f(fromExpr(CurrentRepeatIndex))) + Repeat(n, f(fromExpr(CurrentRepeatIndex)), unroll = false) + + /** Repeat with loop unroll hint. The GPU compiler will attempt to fully unroll + * this loop for better performance. Use for small, fixed-size loops. + */ + def repeatUnroll(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] = + Repeat(n, f(fromExpr(CurrentRepeatIndex)), unroll = true) + + /** Folding repeat - accumulates a value across iterations, supporting barriers. + * + * Unlike `GSeq.fold`, this supports side effects (barriers, writes) within the loop body. + * The body receives the current iteration index and current accumulator value, + * and returns the new accumulator value wrapped in GIO. + * + * @param n Number of iterations + * @param init Initial accumulator value + * @param body Function taking (iterationIndex, currentAcc) and returning new acc in GIO + * @return Final accumulated value + */ + def foldRepeat[A <: Value: {FromExpr, Tag}](n: Int32, init: A)(body: (Int32, A) => GIO[A]): GIO[A] = + val tid = treeidState.getAndIncrement() + val accExpr = CurrentFoldRepeatAcc(init, tid) + FoldRepeat(n, init, body(fromExpr(CurrentRepeatIndex), fromExpr(accExpr)), tid, unroll = false) + + /** Folding repeat with loop unroll hint. The GPU compiler will attempt to fully + * unroll this loop for better performance. Use for small, fixed-size inner loops + * (e.g., head dimension in attention, vector dot products). 
+ * + * @param n Number of iterations (should be a small constant for effective unrolling) + * @param init Initial accumulator value + * @param body Function taking (iterationIndex, currentAcc) and returning new acc in GIO + * @return Final accumulated value + */ + def foldRepeatUnroll[A <: Value: {FromExpr, Tag}](n: Int32, init: A)(body: (Int32, A) => GIO[A]): GIO[A] = + val tid = treeidState.getAndIncrement() + val accExpr = CurrentFoldRepeatAcc(init, tid) + FoldRepeat(n, init, body(fromExpr(CurrentRepeatIndex), fromExpr(accExpr)), tid, unroll = true) def write[T <: Value](buffer: GBuffer[T], index: Int32, value: T): GIO[Empty] = WriteBuffer(buffer, index, value) @@ -57,5 +122,142 @@ object GIO: def read[T <: Value: {FromExpr, Tag}](buffer: GBuffer[T], index: Int32): T = fromExpr(ReadBuffer(buffer, index)) + import scala.annotation.targetName + + // ───────────────────────────────────────────────────────────────────────────── + // Global Invocation + // ───────────────────────────────────────────────────────────────────────────── + + /** Global invocation index (gl_GlobalInvocationID.x). */ def invocationId: Int32 = fromExpr(InvocationId) + + // ───────────────────────────────────────────────────────────────────────────── + // Workgroup Primitives + // ───────────────────────────────────────────────────────────────────────────── + + /** Local invocation index within workgroup (gl_LocalInvocationIndex). */ + def localInvocationIndex: Int32 = + fromExpr(LocalInvocationIndex) + + /** Local invocation ID as 3D vector (gl_LocalInvocationID). */ + def localInvocationId: Vec3[Int32] = + fromExpr(LocalInvocationId) + + /** Workgroup ID as 3D vector (gl_WorkGroupID). */ + def workgroupId: Vec3[Int32] = + fromExpr(WorkgroupId) + + /** Number of workgroups as 3D vector (gl_NumWorkGroups). */ + def numWorkgroups: Vec3[Int32] = + fromExpr(NumWorkgroups) + + /** Synchronization barrier for workgroup memory and execution. */ + def barrier: GIO[Empty] = WorkgroupBarrier + + // ───────────────────────────────────────────────────────────────────────────── + // Subgroup Primitives + // ───────────────────────────────────────────────────────────────────────────── + + /** Subgroup ID within the workgroup. */ + def subgroupId: Int32 = + fromExpr(SubgroupId) + + /** Local invocation ID within the subgroup. */ + def subgroupLocalInvocationId: Int32 = + fromExpr(SubgroupLocalInvocationId) + + /** Size of subgroup (typically 32 for NVIDIA, 64 for AMD). */ + def subgroupSize: Int32 = + fromExpr(SubgroupSize) + + // ───────────────────────────────────────────────────────────────────────────── + // Subgroup Collective Operations + // ───────────────────────────────────────────────────────────────────────────── + + /** Reduces values across the subgroup using addition. */ + def subgroupAdd(value: Int32): Int32 = + fromExpr(SubgroupAddI(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using addition. */ + @targetName("subgroupAddF16") + def subgroupAdd(value: Float16): Float16 = + fromExpr(SubgroupAddF16(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using addition. */ + def subgroupAdd(value: Float32): Float32 = + fromExpr(SubgroupAddF(value, SubgroupOp.Reduce)) + + /** Inclusive prefix sum across the subgroup. */ + def subgroupInclusiveAdd(value: Int32): Int32 = + fromExpr(SubgroupAddI(value, SubgroupOp.InclusiveScan)) + + /** Inclusive prefix sum across the subgroup. 
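foldRepeat is the GIO-level counterpart of GSeq.fold for loops whose bodies need side effects such as barriers or shared-memory writes. A minimal sketch of a tiled partial sum (the tile size of 256 and the names input/tiles are illustrative assumptions, not taken from the patch):

  import io.computenode.cyfra.dsl.{*, given}
  import io.computenode.cyfra.dsl.binding.{GBuffer, GShared}
  import io.computenode.cyfra.dsl.gio.GIO

  val staging = GShared[Float32](256)

  // Accumulate a per-lane sum over `tiles` iterations, staging each tile through
  // workgroup shared memory and synchronizing before the accumulator is updated.
  def tiledPartialSum(input: GBuffer[Float32], tiles: Int32): GIO[Float32] =
    val lane = GIO.localInvocationIndex
    GIO.foldRepeat(tiles, 0.0f: Float32): (tile, acc) =>
      staging
        .write(lane, GIO.read[Float32](input, tile * 256 + lane))
        .flatMap(_ => GIO.barrier)
        .flatMap(_ => GIO.pure(acc + staging.read(lane)))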
*/ + @targetName("subgroupInclusiveAddF16") + def subgroupInclusiveAdd(value: Float16): Float16 = + fromExpr(SubgroupAddF16(value, SubgroupOp.InclusiveScan)) + + /** Inclusive prefix sum across the subgroup. */ + def subgroupInclusiveAdd(value: Float32): Float32 = + fromExpr(SubgroupAddF(value, SubgroupOp.InclusiveScan)) + + /** Exclusive prefix sum across the subgroup. */ + def subgroupExclusiveAdd(value: Int32): Int32 = + fromExpr(SubgroupAddI(value, SubgroupOp.ExclusiveScan)) + + /** Exclusive prefix sum across the subgroup. */ + @targetName("subgroupExclusiveAddF16") + def subgroupExclusiveAdd(value: Float16): Float16 = + fromExpr(SubgroupAddF16(value, SubgroupOp.ExclusiveScan)) + + /** Exclusive prefix sum across the subgroup. */ + def subgroupExclusiveAdd(value: Float32): Float32 = + fromExpr(SubgroupAddF(value, SubgroupOp.ExclusiveScan)) + + /** Reduces values across the subgroup using minimum. */ + def subgroupMin(value: Int32): Int32 = + fromExpr(SubgroupMinI(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using minimum. */ + @targetName("subgroupMinF16") + def subgroupMin(value: Float16): Float16 = + fromExpr(SubgroupMinF16(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using minimum. */ + def subgroupMin(value: Float32): Float32 = + fromExpr(SubgroupMinF(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using maximum. */ + def subgroupMax(value: Int32): Int32 = + fromExpr(SubgroupMaxI(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using maximum. */ + @targetName("subgroupMaxF16") + def subgroupMax(value: Float16): Float16 = + fromExpr(SubgroupMaxF16(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using maximum. */ + def subgroupMax(value: Float32): Float32 = + fromExpr(SubgroupMaxF(value, SubgroupOp.Reduce)) + + /** Broadcasts a value from a specific lane to all lanes in the subgroup. */ + def subgroupBroadcast[T <: Value.Scalar: {FromExpr, Tag}](value: T, lane: Int32): T = + fromExpr(SubgroupBroadcast(value, lane)) + + /** Broadcasts a value from the first active lane to all lanes in the subgroup. */ + def subgroupBroadcastFirst[T <: Value.Scalar: {FromExpr, Tag}](value: T): T = + fromExpr(SubgroupBroadcastFirst(value)) + + /** Shuffles a value from another lane in the subgroup. */ + def subgroupShuffle[T <: Value.Scalar: {FromExpr, Tag}](value: T, lane: Int32): T = + fromExpr(SubgroupShuffle(value, lane)) + + /** Shuffles a value using XOR of lane index with mask. + * This is useful for butterfly/tree reductions where each thread exchanges + * data with thread at (laneId XOR mask). For example: + * - mask=1: lanes 0↔1, 2↔3, 4↔5, ... + * - mask=2: lanes 0↔2, 1↔3, 4↔6, ... + * - mask=4: lanes 0↔4, 1↔5, 2↔6, ... 
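As a concrete instance of the butterfly pattern described above, a full subgroup sum can be built from subgroupShuffleXor alone, halving the exchange distance at every step. A sketch assuming a 32-lane subgroup (a portable version would derive the step count from subgroupSize):

  import io.computenode.cyfra.dsl.{*, given}
  import io.computenode.cyfra.dsl.gio.GIO

  // After the last step every lane of the 32-wide subgroup holds the total sum.
  def butterflySum32(x: Float32): Float32 =
    val s16 = x + GIO.subgroupShuffleXor(x, 16: Int32)
    val s8  = s16 + GIO.subgroupShuffleXor(s16, 8: Int32)
    val s4  = s8 + GIO.subgroupShuffleXor(s8, 4: Int32)
    val s2  = s4 + GIO.subgroupShuffleXor(s4, 2: Int32)
    s2 + GIO.subgroupShuffleXor(s2, 1: Int32)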
+ */ + def subgroupShuffleXor[T <: Value.Scalar: {FromExpr, Tag}](value: T, mask: Int32): T = + fromExpr(SubgroupShuffleXor(value, mask)) \ No newline at end of file diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala index 26b4a970..6abd3dcb 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala @@ -13,26 +13,33 @@ object Functions: case object Sin extends FunctionName def sin(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Sin, List(v))) + def sin(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Sin, List(v))) case object Cos extends FunctionName def cos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Cos, List(v))) + def cos(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Cos, List(v))) def cos[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cos, List(v))) case object Tan extends FunctionName def tan(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Tan, List(v))) + def tan(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Tan, List(v))) case object Acos extends FunctionName def acos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Acos, List(v))) + def acos(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Acos, List(v))) case object Asin extends FunctionName def asin(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Asin, List(v))) + def asin(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Asin, List(v))) case object Atan extends FunctionName def atan(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Atan, List(v))) + def atan(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Atan, List(v))) case object Atan2 extends FunctionName def atan2(y: Float32, x: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Atan2, List(y, x))) + def atan2(y: Float16, x: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Atan2, List(y, x))) case object Len2 extends FunctionName def length[T <: Scalar: Tag](v: Vec2[T])(using Source): Float32 = Float32(ExtFunctionCall(Len2, List(v))) @@ -43,14 +50,18 @@ object Functions: case object Pow extends FunctionName def pow(v: Float32, p: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Pow, List(v, p))) + def pow(v: Float16, p: Float16)(using Source): Float16 = + Float16(ExtFunctionCall(Pow, List(v, p))) def pow[V <: Vec[?]: {Tag, FromExpr}](v: V, p: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Pow, List(v, p))) case object Smoothstep extends FunctionName def smoothstep(edge0: Float32, edge1: Float32, x: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Smoothstep, List(edge0, edge1, x))) + def smoothstep(edge0: Float16, edge1: Float16, x: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Smoothstep, List(edge0, edge1, x))) case object Sqrt extends FunctionName def sqrt(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Sqrt, List(v))) + def sqrt(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Sqrt, List(v))) case object Cross extends FunctionName def cross[T <: Scalar: Tag](v1: Vec3[T], v2: Vec3[T])(using Source): Vec3[T] = Vec3(ExtFunctionCall(Cross, List(v1, v2))) @@ -61,12 +72,14 @@ object Functions: case object Exp extends FunctionName def exp(f: Float32)(using 
Source): Float32 = Float32(ExtFunctionCall(Exp, List(f))) + def exp(f: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Exp, List(f))) def exp[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Exp, List(v))) case object Max extends FunctionName def max(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Max, List(f1, f2))) def max(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(max(f1, f2))((a, b) => max(a, b)) + def max(f1: Float16, f2: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Max, List(f1, f2))) def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Max, List(v1, v2))) def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = @@ -75,6 +88,7 @@ object Functions: case object Min extends FunctionName def min(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Min, List(f1, f2))) def min(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(min(f1, f2))((a, b) => min(a, b)) + def min(f1: Float16, f2: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Min, List(f1, f2))) def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Min, List(v1, v2))) def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = @@ -83,6 +97,7 @@ object Functions: // todo add F/U/S to all functions that need it case object Abs extends FunctionName def abs(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Abs, List(f))) + def abs(f: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Abs, List(f))) def abs[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Abs, List(v))) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala new file mode 100644 index 00000000..282682a0 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala @@ -0,0 +1,280 @@ +package io.computenode.cyfra.e2e.dsl + +import io.computenode.cyfra.core.{GBufferRegion, GProgram} +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.{GBuffer, GShared} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.runtime.VkCyfraRuntime + +class WorkgroupPrimitivesE2eTest extends munit.FunSuite: + + case class TestLayout(output: GBuffer[Int32]) derives Layout + + test("localInvocationIndex returns correct values"): + VkCyfraRuntime.using: + val size = 512 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val localIdx = GIO.localInvocationIndex + GIO.when(idx < size): + GIO.write(layout.output, idx, localIdx) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val expected = (0 until size).map(_ % 
256).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Local invocation indices mismatch") + + test("workgroupId.x returns correct values"): + VkCyfraRuntime.using: + val size = 512 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val wgId = GIO.workgroupId.x + GIO.when(idx < size): + GIO.write(layout.output, idx, wgId) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val expected = (0 until size).map(_ / 256).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Workgroup IDs mismatch") + + test("barrier compiles and executes without error"): + VkCyfraRuntime.using: + val size = 256 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + GIO.write(layout.output, idx, idx) + .flatMap(_ => GIO.barrier) + .flatMap(_ => GIO.pure(layout.output.read(idx))) + .flatMap(value => GIO.write(layout.output, idx, value + 1)) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val expected = (0 until size).map(_ + 1).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Barrier test: expected values incremented by 1") + + test("subgroupSize returns a valid value"): + VkCyfraRuntime.using: + val size = 256 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val sgSize = GIO.subgroupSize + GIO.when(idx < size): + GIO.write(layout.output, idx, sgSize) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + assert(resultBuf.forall(_ > 0), s"Subgroup size should be positive") + assert(resultBuf.forall(_ <= 128), s"Subgroup size should be <= 128") + val uniqueValues = resultBuf.distinct + assert(uniqueValues.length == 1, s"All invocations should report the same subgroup size") + + test("shared memory allows workgroup communication".ignore): + VkCyfraRuntime.using: + val workgroupSize = 256 + val shared = GShared[Int32](workgroupSize) + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](workgroupSize)), + dispatchSize = _ => workgroupSize, + ): layout => + val localIdx = GIO.localInvocationIndex + val globalIdx = GIO.invocationId + shared.write(localIdx, globalIdx) + .flatMap(_ => GIO.barrier) + .flatMap: _ => + val reversedIdx: Int32 = (workgroupSize - 1: Int32) - localIdx + val valueFromReversed = shared.read(reversedIdx) + layout.output.write(globalIdx, valueFromReversed) + + val resultBuf = new Array[Int](workgroupSize) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](workgroupSize)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val 
expected = (0 until workgroupSize).map(i => workgroupSize - 1 - i).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Shared memory communication failed") + + test("subgroupAdd reduces values within subgroup"): + VkCyfraRuntime.using: + val size = 256 + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val sum = GIO.subgroupAdd(1: Int32) + GIO.when(idx < size): + GIO.write(layout.output, idx, sum) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val subgroupSizeActual = resultBuf.head + assert(subgroupSizeActual > 0, s"Subgroup sum should be positive") + assert(resultBuf.forall(_ == subgroupSizeActual), s"All lanes should have same subgroup sum (subgroup size)") + + test("subgroupInclusiveAdd computes prefix sums"): + VkCyfraRuntime.using: + val size = 256 + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val prefixSum = GIO.subgroupInclusiveAdd(1: Int32) + GIO.when(idx < size): + GIO.write(layout.output, idx, prefixSum) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val subgroupSize = resultBuf.sliding(2).find { case Array(a, b) => b < a }.map(_(0)).getOrElse(resultBuf.last) + assert(subgroupSize > 0, s"Should detect subgroup size from prefix sums") + + test("subgroupBroadcast broadcasts value from specified lane"): + VkCyfraRuntime.using: + val size = 256 + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val subgroupLaneId = GIO.subgroupLocalInvocationId + val broadcasted = GIO.subgroupBroadcast(subgroupLaneId, 0: Int32) + GIO.when(idx < size): + GIO.write(layout.output, idx, broadcasted) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + assert(resultBuf.forall(_ == 0), s"All lanes should have received broadcast value 0 from lane 0") + + case class FoldTestLayout(input: GBuffer[Float32], output: GBuffer[Float32]) derives Layout + + test("GSeq.fold with subgroupAdd works together"): + VkCyfraRuntime.using: + val size = 256 + val iterations = 4 + + val program = GProgram.static[Unit, FoldTestLayout]( + layout = _ => FoldTestLayout(GBuffer[Float32](size), GBuffer[Float32](size)), + dispatchSize = _ => size, + ): layout => + import io.computenode.cyfra.dsl.collections.GSeq + val idx = GIO.invocationId + val laneId = GIO.subgroupLocalInvocationId + val warpSize = GIO.subgroupSize + + // Each lane computes a partial sum using fold + val partialSum: Float32 = GSeq + .gen[Int32](laneId, _ + warpSize) + .limit(iterations) + .fold(0.0f, (sum: Float32, i: Int32) => { + when(i < size)(sum + GIO.read[Float32](layout.input, i)).otherwise(sum) + }) + + // Then reduce across 
subgroup + val totalSum: Float32 = GIO.subgroupAdd(partialSum) + + GIO.when(idx < size): + GIO.write(layout.output, idx, totalSum) + + import java.nio.{ByteBuffer, ByteOrder} + val inputBuf = ByteBuffer.allocateDirect(size * 4).order(ByteOrder.nativeOrder()) + inputBuf.asFloatBuffer().put(Array.fill(size)(1.0f)) + inputBuf.rewind() + val resultBuf = new Array[Float](size) + + val region = GBufferRegion + .allocate[FoldTestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = FoldTestLayout( + input = GBuffer[Float32](inputBuf), + output = GBuffer[Float32](size), + ), + onDone = layout => layout.output.readArray(resultBuf), + ) + + // Each invocation should have the sum of the elements it processed + reduced across subgroup + // With iterations=4 and warpSize=32, each lane processes ~4 elements worth of indices + // But with bounds check, only valid indices contribute + assert(resultBuf.forall(_ > 0), s"Total sum should be positive, got ${resultBuf.take(10).mkString(", ")}") diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala index 40430035..9e44d964 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala @@ -163,6 +163,36 @@ object GFunctionExamples: println(s"Saved to examples_output/julia.png") println() + def example4_FibonacciSequence(): Unit = + // Test the Fibonacci-like GSeq from documentation using Vec2[Float32] + // Pattern: GSeq.gen[Vec2[Float32]](init, pair => vec2(pair.y, pair.x + pair.y)) + // Generates: (0,1), (1,1), (1,2), (2,3), (3,5), (5,8), ... + // fib(0), fib(1), fib(2), fib(3), fib(4), fib(5), ... 
+ val fibonacciNth: GFunction[GStruct.Empty, Float32, Float32] = GFunction: _ => + // Generate Fibonacci-like pairs: (a, b) -> (b, a+b) + val fibonacci = GSeq.gen[Vec2[Float32]]((0.0f, 1.0f), pair => (pair.y, pair.x + pair.y)) + // limit(n) gives n pairs, last.x = fib(n-1) + // So limit(11).last.x = fib(10) = 55 + fibonacci.limit(11).lastOr(vec2(0.0f, 0.0f)).x + + val input = Array.fill(256)(0.0f) // dummy input + + println("Example 4: Fibonacci Sequence (GSeq.gen with Vec2)") + println("Testing: GSeq.gen[Vec2[Float32]](vec2(0, 1), pair => vec2(pair.y, pair.x + pair.y))") + println("Computing fib(10) on GPU using limit(11).last.x ...") + + val results: Array[Float] = fibonacciNth.run(input) + + // Sequence with limit(11): (0,1), (1,1), (1,2), (2,3), (3,5), (5,8), (8,13), (13,21), (21,34), (34,55), (55,89) + // last.x = 55 = fib(10) + val expected = 55.0f + println(s"Result: fib(10) = ${results(0).toInt}") + println(s"Expected: ${expected.toInt}") + + val correct = Math.abs(results(0) - expected) < 0.001f + println(s"Result correct: $correct") + println() + case class TransformConfig(scale: Float32, offset: Float32) extends GStruct[TransformConfig] def example8_Uniforms(): Unit = @@ -203,6 +233,7 @@ object GFunctionExamples: example1_HelloGpu() example2_VectorOperations() example3_CustomStructs() + example4_FibonacciSequence() example6_Mandelbrot() example7_JuliaSet() example8_Uniforms() diff --git a/cyfra-llama/compare_incremental.py b/cyfra-llama/compare_incremental.py new file mode 100644 index 00000000..a4848e86 --- /dev/null +++ b/cyfra-llama/compare_incremental.py @@ -0,0 +1,51 @@ +"""Compare llama.cpp predictions for incremental generation.""" +from llama_cpp import Llama +import numpy as np + +# Load model +print("Loading model...") +llm = Llama( + model_path="cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + n_ctx=64, + verbose=False, + logits_all=True, # Enable all logits +) + +def get_top_predictions(llm, tokens, pos): + """Get top predictions for a specific position.""" + # Reset and evaluate + llm.reset() + llm.eval(tokens) + + # Get logits for the requested position + logits = np.array(llm._scores[pos]) + + top_indices = np.argsort(logits)[::-1][:10] + print(f"Top-10 predictions for position {pos}:") + for idx in top_indices: + try: + token_str = llm.detokenize([int(idx)]).decode('utf-8', errors='replace') + token_str = token_str.encode('ascii', errors='replace').decode('ascii') + except: + token_str = f"[{idx}]" + print(f" Token {idx:5d} ({token_str:>10s}): logit={logits[idx]:10.4f}") + + print(f" Stats: min={logits.min():.4f}, max={logits.max():.4f}, mean={logits.mean():.4f}") + return logits + +# Test 1: [BOS, Hello] -> predict next +tokens_1 = [1, 15043] # BOS + Hello +print(f"\n=== Sequence 1: {tokens_1} (BOS + Hello) ===") +logits_1 = get_top_predictions(llm, tokens_1, 1) + +# Test 2: [BOS, Hello, ,] -> predict next +tokens_2 = [1, 15043, 29892] # BOS + Hello + , +print(f"\n=== Sequence 2: {tokens_2} (BOS + Hello + ,) ===") +logits_2 = get_top_predictions(llm, tokens_2, 2) + +# Also compare logits for position 1 in both sequences (should be the same!) 
+print(f"\n=== Position 1 logits in sequence 2 (should match sequence 1) ===") +llm.reset() +llm.eval(tokens_2) +logits_2_pos1 = np.array(llm._scores[1]) +print(f" max diff between seq1 pos1 and seq2 pos1: {np.abs(logits_1 - logits_2_pos1).max():.6f}") diff --git a/cyfra-llama/compare_logits.py b/cyfra-llama/compare_logits.py new file mode 100644 index 00000000..8613bcec --- /dev/null +++ b/cyfra-llama/compare_logits.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +"""Compare logits from llama-cpp-python with our implementation.""" + +import numpy as np +from llama_cpp import Llama + +def main(): + model_path = "cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + + print(f"Loading model from {model_path}...") + llm = Llama( + model_path=model_path, + n_ctx=32, + n_batch=32, + verbose=True, + logits_all=True, # Get logits for all tokens + ) + + # Test tokens: BOS (1) + "Hello" token + prompt = "Hello" + print(f"\nPrompt: '{prompt}'") + + # Tokenize + tokens = llm.tokenize(prompt.encode(), add_bos=True) + print(f"Tokens: {tokens}") + + # Evaluate and get logits + llm.reset() + llm.eval(tokens) + + # Get logits for the last token + logits = llm.scores[len(tokens) - 1] + logits_array = np.array(logits, dtype=np.float32) + + print(f"\nLogits shape: {logits_array.shape}") + print(f"Logits stats: min={logits_array.min():.4f}, max={logits_array.max():.4f}, mean={logits_array.mean():.4f}, std={logits_array.std():.4f}") + print(f"Logits sum: {logits_array.sum():.4f}") + + # Get top 5 predictions + top_indices = np.argsort(logits_array)[-5:][::-1] + print("\nTop 5 predictions:") + for idx in top_indices: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" {idx}: '{token_str}' (score={logits_array[idx]:.2f})") + + # Print first and last few logits for comparison + print(f"\nFirst 10 logits: {logits_array[:10]}") + print(f"Last 10 logits: {logits_array[-10:]}") + + # Also test with just "Hello" (no BOS) + print("\n" + "="*60) + print("Testing single token (15043 = 'Hello')...") + + llm.reset() + llm.eval([15043]) # Just the "Hello" token + + logits2 = llm.scores[0] + logits2_array = np.array(logits2, dtype=np.float32) + + print(f"Logits stats: min={logits2_array.min():.4f}, max={logits2_array.max():.4f}, mean={logits2_array.mean():.4f}, std={logits2_array.std():.4f}") + + # Get top 5 + top_indices2 = np.argsort(logits2_array)[-5:][::-1] + print("\nTop 5 predictions:") + for idx in top_indices2: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" {idx}: '{token_str}' (score={logits2_array[idx]:.2f})") + +if __name__ == "__main__": + main() diff --git a/cyfra-llama/compare_with_llama_cpp.py b/cyfra-llama/compare_with_llama_cpp.py new file mode 100644 index 00000000..273633e7 --- /dev/null +++ b/cyfra-llama/compare_with_llama_cpp.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +""" +Compare GPU logits against llama.cpp reference. +Run this after running LayerByLayerDebugTest to see the actual llama.cpp output. 
+""" + +from llama_cpp import Llama +import numpy as np + +MODEL_PATH = "cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + +def main(): + print("Loading model via llama.cpp...") + llm = Llama( + model_path=MODEL_PATH, + n_ctx=512, + n_batch=512, + verbose=False, + logits_all=True, # Get logits for all positions + ) + + # Test token: 15043 = "Hello" + # We want to get the logits for predicting what comes after "Hello" + tokens = [1, 15043] # BOS + "Hello" + + print(f"\nTokens: {tokens}") + print("Running llama.cpp forward pass...") + + # Run forward pass + llm.reset() + llm.eval(tokens) + + # Get logits for last position (predicting what comes after "Hello") + logits = np.array(llm.scores[len(tokens) - 1]) + + print(f"\n=== llama.cpp Reference Logits for token 15043 (Hello) ===") + print(f"Logits shape: {logits.shape}") + print(f"min={logits.min():.4f}, max={logits.max():.4f}, mean={logits.mean():.4f}, std={logits.std():.4f}") + + # Top-10 tokens + top_indices = np.argsort(logits)[::-1][:10] + print("\nTop-10 predicted tokens:") + for idx in top_indices: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" Token {idx:5d} ({token_str:>10s}): logit={logits[idx]:10.4f}") + + # Argmax + predicted = np.argmax(logits) + predicted_str = llm.detokenize([predicted]).decode('utf-8', errors='replace') + print(f"\nPredicted next token: {predicted} ({predicted_str})") + + # Also test T=2 case: "Hello," + print("\n" + "="*60) + print("Testing T=2: [BOS, Hello, ,]") + tokens2 = [1, 15043, 29892] # BOS + "Hello" + "," + + llm.reset() + llm.eval(tokens2) + + logits2 = np.array(llm.scores[len(tokens2) - 1]) + print(f"\n=== llama.cpp Reference Logits for 'Hello,' (predicting 3rd token) ===") + print(f"min={logits2.min():.4f}, max={logits2.max():.4f}, mean={logits2.mean():.4f}, std={logits2.std():.4f}") + + top_indices2 = np.argsort(logits2)[::-1][:10] + print("\nTop-10 predicted tokens:") + for idx in top_indices2: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" Token {idx:5d} ({token_str:>10s}): logit={logits2[idx]:10.4f}") + +if __name__ == "__main__": + main() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala new file mode 100644 index 00000000..f24958f5 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala @@ -0,0 +1,281 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.inference.LlamaInference +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.pipeline.LlamaPipeline +import io.computenode.cyfra.llama.tokenizer.LlamaTokenizer +import io.computenode.cyfra.llama.util.Logger +import io.computenode.cyfra.runtime.VkCyfraRuntime + +import java.nio.file.{Files, Paths} +import scala.io.StdIn + +/** Llama model runner with F16 and F32 pipeline support. 
+ * + * Usage: + * runner --model path/to/model.gguf --type f16 --interactive + * runner --model path/to/model.gguf --type f32 --prompt "Hello world" + */ +object Runner: + + case class Config( + modelPath: String = "", + modelType: String = "auto", // "f16", "f32", or "auto" + interactive: Boolean = false, + measure: Boolean = false, + batch: Boolean = false, // Buffer output, print at end + prompt: Option[String] = None, + maxTokens: Int = 500, + temperature: Float = 0.7f, + topP: Float = 0.9f, + warmupRuns: Int = 3, + benchmarkRuns: Int = 5, + ) + + def main(args: Array[String]): Unit = + val config = parseArgs(args) + + if config.modelPath.isEmpty then + printUsage() + return + + if !Files.exists(Paths.get(config.modelPath)) then + System.err.println(s"Error: Model not found: ${config.modelPath}") + return + + val resolvedType = if config.modelType == "auto" then + if config.modelPath.toLowerCase.contains("f16") then "f16" else "f32" + else config.modelType + + println(s"Cyfra Llama Runner") + println(s"Model: ${config.modelPath}") + println(s"Type: $resolvedType") + + VkCyfraRuntime.using: + val model = LlamaModel.fromGGUF(Paths.get(config.modelPath)) + val tokenizer = LlamaTokenizer(model.gguf) + val useQuantized = resolvedType == "f32" + val inference = new LlamaInference(model, maxT = 1024, useQuantized = useQuantized) + + val pipeline: LlamaPipeline = resolvedType match + case "f16" => inference.getF16KVCachedPipeline + case "f32" => inference.getF32KVCachedPipeline + case _ => + System.err.println(s"Unknown model type: $resolvedType") + return + + println(s"Ready: ${model.config.hiddenSize}d, ${model.config.numHiddenLayers}L\n") + + if config.measure then + runBenchmark(pipeline, tokenizer, config) + else if config.interactive then + runInteractive(pipeline, tokenizer, config) + else if config.prompt.isDefined then + runOnce(pipeline, tokenizer, config.prompt.get, config) + else + printUsage() + + private def runInteractive(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, config: Config): Unit = + println("Interactive mode. 
Commands: quit, exit") + println("-" * 40) + + var running = true + while running do + print("\nYou: ") + System.out.flush() + val userInput = StdIn.readLine() + + if userInput == null || userInput.trim.toLowerCase == "quit" || userInput.trim.toLowerCase == "exit" then + running = false + else if userInput.trim.nonEmpty then + val prompt = s"<|user|>\n${userInput.trim}\n<|assistant|>\n" + runGeneration(pipeline, tokenizer, prompt, config) + + private def runOnce(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, prompt: String, config: Config): Unit = + println(s"Prompt: $prompt\n") + runGeneration(pipeline, tokenizer, prompt, config) + + private def runBenchmark(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, config: Config): Unit = + val prompt = config.prompt.getOrElse("Once upon a time") + val tokens = tokenizer.encode(prompt) + + println(s"Benchmark: '$prompt' -> ${config.maxTokens} tokens") + println(s"Warmup: ${config.warmupRuns} runs, Benchmark: ${config.benchmarkRuns} runs\n") + + // Greedy argmax sampling + def argmax(logits: Array[Float]): Int = + var maxIdx = 0 + var maxVal = logits(0) + var i = 1 + while i < logits.length do + if logits(i) > maxVal then + maxVal = logits(i) + maxIdx = i + i += 1 + maxIdx + + // Warmup + print("Warming up: ") + for i <- 1 to config.warmupRuns do + pipeline.generate(tokens, config.maxTokens, argmax, _ => (), Set(tokenizer.eosToken), reportStats = false) + print(s"$i ") + System.out.flush() + println("done\n") + + // Benchmark runs + println("Benchmark runs:") + + val (decoded, stats) = (1 to config.benchmarkRuns).map: i => + val generated = pipeline.generate(tokens, config.maxTokens, argmax, _ => (), Set(tokenizer.eosToken), reportStats = false) + val decoded = tokenizer.decode(generated) + val s = pipeline.lastStats.get + println(f" Run $i: ${s.generatedTokens} tokens, generate ${s.decodeTokPerSec}%.1f tok/s") + (decoded, s) + .unzip + + val avgDecode = stats.map(_.decodeTokPerSec).sum / stats.length + val bestDecode = stats.map(_.decodeTokPerSec).max + + println() + println("Last generation:") + println(decoded.last.toString) + println(f"Average: $avgDecode%.1f tok/s") + println(f"Best: $bestDecode%.1f tok/s") + + private def runGeneration(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, prompt: String, config: Config): Unit = + val tokens = tokenizer.encode(prompt) + + if config.batch then + // Batch mode: print at end + val buffer = new StringBuilder() + val generated = pipeline.generate( + promptTokens = tokens, + maxNewTokens = config.maxTokens, + sampleFn = logits => topPSample(logits, config.temperature, config.topP), + onToken = _ => (), + stopTokens = Set(tokenizer.eosToken), + reportStats = false, + ) + val decoded = tokenizer.decode(generated) + println(s"Output: $decoded") + else + // Streaming mode: print tokens as they arrive + print("Output: ") + System.out.flush() + val generated = pipeline.generate( + promptTokens = tokens, + maxNewTokens = config.maxTokens, + sampleFn = logits => topPSample(logits, config.temperature, config.topP), + onToken = token => + val text = tokenizer.decodeToken(token) + if !text.contains("") && !text.contains("<|") then + print(text) + System.out.flush() + , + stopTokens = Set(tokenizer.eosToken), + reportStats = false, + ) + println() + pipeline.lastStats match + case Some(stats) => + println(f"[${generated.length} tokens, generate ${stats.decodeTokPerSec}%.1f tok/s]") + case None => + println(f"[${generated.length} tokens]") + + private def parseArgs(args: Array[String]): Config = + var 
config = Config() + var i = 0 + while i < args.length do + args(i) match + case "--model" | "-m" if i + 1 < args.length => + config = config.copy(modelPath = args(i + 1)) + i += 2 + case "--type" | "-t" if i + 1 < args.length => + config = config.copy(modelType = args(i + 1).toLowerCase) + i += 2 + case "--interactive" | "-i" => + config = config.copy(interactive = true) + i += 1 + case "--measure" => + config = config.copy(measure = true) + i += 1 + case "--batch" | "-b" => + config = config.copy(batch = true) + i += 1 + case "--warmup" if i + 1 < args.length => + config = config.copy(warmupRuns = args(i + 1).toInt) + i += 2 + case "--runs" if i + 1 < args.length => + config = config.copy(benchmarkRuns = args(i + 1).toInt) + i += 2 + case "--prompt" | "-p" if i + 1 < args.length => + config = config.copy(prompt = Some(args(i + 1))) + i += 2 + case "--max-tokens" | "-n" if i + 1 < args.length => + config = config.copy(maxTokens = args(i + 1).toInt) + i += 2 + case "--temperature" if i + 1 < args.length => + config = config.copy(temperature = args(i + 1).toFloat) + i += 2 + case "--top-p" if i + 1 < args.length => + config = config.copy(topP = args(i + 1).toFloat) + i += 2 + case arg if !arg.startsWith("-") && config.modelPath.isEmpty => + config = config.copy(modelPath = arg) + i += 1 + case other => + System.err.println(s"Unknown argument: $other") + i += 1 + config + + private def printUsage(): Unit = + println(""" + |Usage: runner [OPTIONS] [MODEL_PATH] + | + |Modes: + | -i, --interactive Interactive chat mode (streaming output) + | -p, --prompt TEXT Single prompt and exit (streaming output) + | --measure Benchmark mode (no output, multiple runs) + | + |Options: + | -m, --model PATH Path to GGUF model file + | -t, --type TYPE Model type: f16, f32, or auto (default: auto) + | -b, --batch Buffer output, print at end (faster) + | -n, --max-tokens N Maximum tokens to generate (default: 500) + | --temperature FLOAT Sampling temperature (default: 0.7) + | --top-p FLOAT Top-p sampling threshold (default: 0.9) + | --warmup N Warmup runs for benchmark (default: 3) + | --runs N Benchmark runs (default: 5) + | + |Examples: + | runner -m model.gguf -t f16 -i + | runner -m model.gguf -p "Hello world" -b -n 100 + | runner -m model.gguf --measure -n 128 + |""".stripMargin) + + private def topPSample(logits: Array[Float], temperature: Float, topP: Float): Int = + val scaled = logits.map(_ / temperature) + val maxLogit = scaled.max + val expLogits = scaled.map(x => math.exp(x - maxLogit).toFloat) + val sumExp = expLogits.sum + val probs = expLogits.map(_ / sumExp) + val indexed = probs.zipWithIndex.sortBy(-_._1) + + var cumSum = 0.0f + var cutoffIdx = 0 + while cutoffIdx < indexed.length && cumSum < topP do + cumSum += indexed(cutoffIdx)._1 + cutoffIdx += 1 + + val topTokens = indexed.take(cutoffIdx) + val topSum = topTokens.map(_._1).sum + val normalized = topTokens.map(t => (t._1 / topSum, t._2)) + + val r = scala.util.Random.nextFloat() + var acc = 0.0f + var result = normalized.last._2 + for (prob, idx) <- normalized do + acc += prob + if acc >= r && result == normalized.last._2 then + result = idx + result diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala new file mode 100644 index 00000000..ff3908d0 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala @@ -0,0 +1,204 @@ +package io.computenode.cyfra.llama.gguf + +import java.nio.{ByteBuffer, 
ByteOrder} + +/** Dequantization functions for GGUF quantized tensors. + * + * Based on llama.cpp's ggml-quants.c + */ +object Dequantize: + + val QK_K = 256 // Block size for K-quants + + /** Convert half-precision float16 to float32. + * + * IEEE 754 half-precision: 1 sign bit, 5 exponent bits, 10 mantissa bits + */ + def fp16ToFp32(h: Short): Float = + val sign = (h >> 15) & 1 + val exp = (h >> 10) & 0x1F + val mant = h & 0x3FF + + if exp == 0 then + // Denormalized or zero + if mant == 0 then + if sign == 1 then -0.0f else 0.0f + else + // Denormalized number + val f = mant.toFloat / 1024.0f + val result = f * math.pow(2, -14).toFloat + if sign == 1 then -result else result + else if exp == 31 then + // Infinity or NaN + if mant == 0 then + if sign == 1 then Float.NegativeInfinity else Float.PositiveInfinity + else + Float.NaN + else + // Normalized number + val f = 1.0f + mant.toFloat / 1024.0f + val result = f * math.pow(2, exp - 15).toFloat + if sign == 1 then -result else result + + /** Dequantize Q4_K block to float32. + * + * Q4_K format: + * - 256 values per block + * - 2x float16 for d and dmin + * - 12 bytes for scales (6-bit each, packed) + * - 128 bytes for quantized values (4-bit each, packed) + * - Total: 144 bytes per block + */ + def dequantizeQ4K(data: Array[Byte], numElements: Long): Array[Float] = + val numBlocks = (numElements / QK_K).toInt + val result = new Array[Float](numElements.toInt) + val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) + + var resultIdx = 0 + for blockIdx <- 0 until numBlocks do + val blockStart = blockIdx * 144 + + // Read d and dmin (fp16) + val dHalf = buf.getShort(blockStart) + val dminHalf = buf.getShort(blockStart + 2) + val d = fp16ToFp32(dHalf) + val dmin = fp16ToFp32(dminHalf) + + // Read scales (12 bytes, 6-bit values packed) + val scales = new Array[Byte](12) + for i <- 0 until 12 do + scales(i) = buf.get(blockStart + 4 + i) + + // Read quantized values (128 bytes, 4-bit packed) + val qs = new Array[Byte](128) + for i <- 0 until 128 do + qs(i) = buf.get(blockStart + 16 + i) + + // Dequantize 256 values in groups of 64 + var is = 0 + var qsIdx = 0 + for j <- 0 until 4 do // 4 groups of 64 + // Get scale and min for this group (two sub-groups of 32) + val (sc1, m1) = getScaleMinK4(is, scales) + val (sc2, m2) = getScaleMinK4(is + 1, scales) + + val d1 = d * sc1 + val m1Val = dmin * m1 + val d2 = d * sc2 + val m2Val = dmin * m2 + + // First 32 values (low nibble) + for l <- 0 until 32 do + val q = qs(qsIdx + l) & 0x0F + result(resultIdx) = d1 * q - m1Val + resultIdx += 1 + + // Second 32 values (high nibble) + for l <- 0 until 32 do + val q = (qs(qsIdx + l) >> 4) & 0x0F + result(resultIdx) = d2 * q - m2Val + resultIdx += 1 + + qsIdx += 32 + is += 2 + + result + + /** Get scale and min from packed 6-bit values in Q4_K scales array. + * + * Matches llama.cpp's get_scale_min_k4 implementation exactly. + * scales array is 12 bytes, j ranges 0-7. + * + * IMPORTANT: Use & 0xFF to convert signed bytes to unsigned before shifting, + * otherwise Java's signed byte extension causes incorrect results when bit 7 is set. 
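+   *
+   * Worked example (illustration): for j = 4 the low four bits of the scale come from
+   * scales(8) & 0x0F and the high two bits from the top two bits of scales(0), i.e.
+   * d = (scales(8) & 0x0F) | ((scales(0) & 0xC0) >> 2); likewise
+   * m = ((scales(8) & 0xF0) >> 4) | ((scales(4) & 0xC0) >> 2). Both are 6-bit values in 0..63.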
+ */ + private def getScaleMinK4(j: Int, scales: Array[Byte]): (Float, Float) = + if j < 4 then + // Simple 6-bit extraction from lower bytes + val d = (scales(j) & 0x3F).toFloat + val m = (scales(j + 4) & 0x3F).toFloat + (d, m) + else + // Combine bits from different positions - use & 0xFF for unsigned interpretation + val sj4 = scales(j + 4) & 0xFF // scales[j+4] as unsigned + val sjm4 = scales(j - 4) & 0xFF // scales[j-4] as unsigned + val sj = scales(j) & 0xFF // scales[j] as unsigned + val d = ((sj4 & 0x0F) | ((sjm4 >> 6) << 4)).toFloat + val m = (((sj4 >> 4) & 0x0F) | ((sj >> 6) << 4)).toFloat + (d, m) + + /** Dequantize Q6_K block to float32. + * + * Q6_K format (matches llama.cpp exactly): + * - 256 values per block + * - 128 bytes for low 4 bits (ql) + * - 64 bytes for high 2 bits (qh) + * - 16 bytes for scales (int8) + * - 2 bytes for d (fp16) + * - Total: 210 bytes per block + */ + def dequantizeQ6K(data: Array[Byte], numElements: Long): Array[Float] = + val numBlocks = (numElements / QK_K).toInt + val result = new Array[Float](numElements.toInt) + val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) + + for blockIdx <- 0 until numBlocks do + val blockStart = blockIdx * 210 + val blockResultStart = blockIdx * QK_K + + // Read d (fp16) at offset 208 + val d = fp16ToFp32(buf.getShort(blockStart + 208)) + + // Two halves: n=0 (values 0-127), n=1 (values 128-255) + var qlOffset = 0 + var qhOffset = 0 + var scOffset = 0 + var yOffset = 0 + + for n <- 0 until 2 do + // Process 128 values in this half + for l <- 0 until 32 do + val is = l / 16 // Scale index within this 128-value block + + // Read ql values + val ql0 = buf.get(blockStart + qlOffset + l) & 0xFF + val ql32 = buf.get(blockStart + qlOffset + l + 32) & 0xFF + + // Read qh value + val qhVal = buf.get(blockStart + 128 + qhOffset + l) & 0xFF + + // Read scales (int8, so need sign extension) + val sc0 = buf.get(blockStart + 192 + scOffset + is + 0).toInt + val sc2 = buf.get(blockStart + 192 + scOffset + is + 2).toInt + val sc4 = buf.get(blockStart + 192 + scOffset + is + 4).toInt + val sc6 = buf.get(blockStart + 192 + scOffset + is + 6).toInt + + // Compute 4 quantized values + val q1 = ((ql0 & 0x0F) | (((qhVal >> 0) & 3) << 4)) - 32 + val q2 = ((ql32 & 0x0F) | (((qhVal >> 2) & 3) << 4)) - 32 + val q3 = ((ql0 >> 4) | (((qhVal >> 4) & 3) << 4)) - 32 + val q4 = ((ql32 >> 4) | (((qhVal >> 6) & 3) << 4)) - 32 + + // Store 4 dequantized values + result(blockResultStart + yOffset + l + 0) = d * sc0 * q1 + result(blockResultStart + yOffset + l + 32) = d * sc2 * q2 + result(blockResultStart + yOffset + l + 64) = d * sc4 * q3 + result(blockResultStart + yOffset + l + 96) = d * sc6 * q4 + + // Move to next half + qlOffset += 64 + qhOffset += 32 + scOffset += 8 + yOffset += 128 + + result + + /** Dequantize F16 to F32. 
*/ + def dequantizeF16(data: Array[Byte], numElements: Long): Array[Float] = + val result = new Array[Float](numElements.toInt) + val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) + + for i <- 0 until numElements.toInt do + result(i) = fp16ToFp32(buf.getShort(i * 2)) + + result diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala new file mode 100644 index 00000000..2358a809 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala @@ -0,0 +1,477 @@ +package io.computenode.cyfra.llama.gguf + +import java.io.RandomAccessFile +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.channels.FileChannel +import java.nio.file.Path +import scala.collection.mutable + +/** GGUF (GGML Universal File) format reader. + * + * GGUF is llama.cpp's model format. This reader parses the file header, + * metadata key-value pairs, and tensor information. + * + * File structure: + * - Magic: 4 bytes ("GGUF" = 0x46554747) + * - Version: uint32 (currently 3) + * - Tensor count: uint64 + * - KV count: uint64 + * - Key-value pairs (metadata) + * - Tensor info (name, dimensions, type, offset) + * - Padding to alignment (default 32 bytes) + * - Tensor data + */ +object GGUFReader: + val GGUF_MAGIC: Int = 0x46554747 // "GGUF" + val GGUF_VERSION: Int = 3 + val DEFAULT_ALIGNMENT: Int = 32 + + /** Value types in GGUF metadata. */ + enum ValueType(val id: Int): + case UINT8 extends ValueType(0) + case INT8 extends ValueType(1) + case UINT16 extends ValueType(2) + case INT16 extends ValueType(3) + case UINT32 extends ValueType(4) + case INT32 extends ValueType(5) + case FLOAT32 extends ValueType(6) + case BOOL extends ValueType(7) + case STRING extends ValueType(8) + case ARRAY extends ValueType(9) + case UINT64 extends ValueType(10) + case INT64 extends ValueType(11) + case FLOAT64 extends ValueType(12) + + object ValueType: + def fromId(id: Int): ValueType = + ValueType.values.find(_.id == id).getOrElse( + throw new IllegalArgumentException(s"Unknown value type: $id") + ) + + /** Quantization types for tensors. */ + enum QuantType(val id: Int, val blockSize: Int, val bytesPerBlock: Int): + case F32 extends QuantType(0, 1, 4) + case F16 extends QuantType(1, 1, 2) + case Q4_0 extends QuantType(2, 32, 18) + case Q4_1 extends QuantType(3, 32, 20) + case Q5_0 extends QuantType(6, 32, 22) + case Q5_1 extends QuantType(7, 32, 24) + case Q8_0 extends QuantType(8, 32, 34) + case Q8_1 extends QuantType(9, 32, 36) + case Q2_K extends QuantType(10, 256, 84) + case Q3_K extends QuantType(11, 256, 110) + case Q4_K extends QuantType(12, 256, 144) + case Q5_K extends QuantType(13, 256, 176) + case Q6_K extends QuantType(14, 256, 210) + case Q8_K extends QuantType(15, 256, 292) + case IQ2_XXS extends QuantType(16, 256, 66) + case IQ2_XS extends QuantType(17, 256, 74) + case IQ3_XXS extends QuantType(18, 256, 98) + case IQ1_S extends QuantType(19, 256, 50) + case IQ4_NL extends QuantType(20, 32, 18) + case IQ3_S extends QuantType(21, 256, 110) + case IQ2_S extends QuantType(22, 256, 82) + case IQ4_XS extends QuantType(23, 256, 136) + case BF16 extends QuantType(30, 1, 2) + + object QuantType: + def fromId(id: Int): QuantType = + QuantType.values.find(_.id == id).getOrElse( + throw new IllegalArgumentException(s"Unknown quant type: $id") + ) + + /** Metadata value can be various types. 
*/ + sealed trait MetaValue + case class MetaUInt8(value: Byte) extends MetaValue + case class MetaInt8(value: Byte) extends MetaValue + case class MetaUInt16(value: Short) extends MetaValue + case class MetaInt16(value: Short) extends MetaValue + case class MetaUInt32(value: Int) extends MetaValue + case class MetaInt32(value: Int) extends MetaValue + case class MetaFloat32(value: Float) extends MetaValue + case class MetaBool(value: Boolean) extends MetaValue + case class MetaString(value: String) extends MetaValue + case class MetaUInt64(value: Long) extends MetaValue + case class MetaInt64(value: Long) extends MetaValue + case class MetaFloat64(value: Double) extends MetaValue + case class MetaArray(values: Seq[MetaValue]) extends MetaValue + + /** Tensor information from GGUF file. */ + case class TensorInfo( + name: String, + shape: Array[Long], + quantType: QuantType, + offset: Long, + ): + def numElements: Long = shape.product + def numBytes: Long = + val blocks = (numElements + quantType.blockSize - 1) / quantType.blockSize + blocks * quantType.bytesPerBlock + + /** Parsed GGUF file. */ + case class GGUFFile( + version: Int, + metadata: Map[String, MetaValue], + tensors: Seq[TensorInfo], + dataOffset: Long, + channel: FileChannel, + ): + def close(): Unit = channel.close() + + /** Get metadata value as string. */ + def getString(key: String): Option[String] = metadata.get(key).collect { case MetaString(v) => v } + + /** Get metadata value as int. */ + def getInt(key: String): Option[Int] = metadata.get(key).collect { + case MetaUInt32(v) => v + case MetaInt32(v) => v + case MetaUInt8(v) => v.toInt & 0xFF + case MetaInt8(v) => v.toInt + } + + /** Get metadata value as long. */ + def getLong(key: String): Option[Long] = metadata.get(key).collect { + case MetaUInt64(v) => v + case MetaInt64(v) => v + case MetaUInt32(v) => v.toLong & 0xFFFFFFFFL + case MetaInt32(v) => v.toLong + } + + /** Get metadata value as float. */ + def getFloat(key: String): Option[Float] = metadata.get(key).collect { + case MetaFloat32(v) => v + case MetaFloat64(v) => v.toFloat + } + + /** Get metadata value as string array. */ + def getStringArray(key: String): Option[Array[String]] = metadata.get(key).collect { + case MetaArray(vals) => vals.collect { case MetaString(s) => s }.toArray + } + + /** Get metadata value as float array. */ + def getFloatArray(key: String): Option[Array[Float]] = metadata.get(key).collect { + case MetaArray(vals) => vals.collect { case MetaFloat32(f) => f }.toArray + } + + /** Get tensor by name. */ + def getTensor(name: String): Option[TensorInfo] = tensors.find(_.name == name) + + /** Read tensor data as float array. Only works for F32 tensors. */ + def readTensorF32(tensor: TensorInfo): Array[Float] = + require(tensor.quantType == QuantType.F32, s"Tensor ${tensor.name} is ${tensor.quantType}, not F32") + val buffer = ByteBuffer.allocate(tensor.numBytes.toInt).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buffer, dataOffset + tensor.offset) + buffer.flip() + val result = Array.ofDim[Float](tensor.numElements.toInt) + buffer.asFloatBuffer().get(result) + result + + /** Read tensor data as raw bytes. */ + def readTensorBytes(tensor: TensorInfo): Array[Byte] = + val buffer = ByteBuffer.allocate(tensor.numBytes.toInt) + channel.read(buffer, dataOffset + tensor.offset) + buffer.flip() + val result = Array.ofDim[Byte](tensor.numBytes.toInt) + buffer.get(result) + result + + /** Read tensor data directly into a ByteBuffer for GPU upload. 
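+    * The buffer is allocated with allocateDirect, so the file channel reads straight into
+    * off-heap memory without an extra on-heap copy.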
+ * Returns little-endian ordered buffer suitable for GBuffer[UInt32]. + */ + def readTensorToBuffer(tensor: TensorInfo): ByteBuffer = + val buffer = ByteBuffer.allocateDirect(tensor.numBytes.toInt).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buffer, dataOffset + tensor.offset) + buffer.rewind() + buffer + + /** Read Q4_K tensor as UInt32 array for GPU upload. + * Q4_K: 144 bytes = 36 UInt32 per 256-element block. + */ + def readTensorQ4KAsUInt32(tensor: TensorInfo): Array[Int] = + require(tensor.quantType == QuantType.Q4_K, s"Tensor ${tensor.name} is ${tensor.quantType}, not Q4_K") + val bytes = readTensorBytes(tensor) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) + val numUInt32 = tensor.numBytes.toInt / 4 + val result = Array.ofDim[Int](numUInt32) + buf.asIntBuffer().get(result) + result + + /** Read any quantized tensor as UInt32 array for GPU upload. + * Use this for tensors that may have different quantization types. + */ + def readTensorAsUInt32(tensor: TensorInfo): Array[Int] = + val bytes = readTensorBytes(tensor) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) + val numUInt32 = tensor.numBytes.toInt / 4 + val result = Array.ofDim[Int](numUInt32) + buf.asIntBuffer().get(result) + result + + /** Read Q6_K tensor as UInt32 array for GPU upload. + * Q6_K: 210 bytes per 256-element block. + * + * Since 210 is not divisible by 4, we pad each block to 212 bytes (53 uint32) + * for GPU alignment. + */ + def readTensorQ6KAsUInt32(tensor: TensorInfo): Array[Int] = + require(tensor.quantType == QuantType.Q6_K, s"Tensor ${tensor.name} is ${tensor.quantType}, not Q6_K") + val bytes = readTensorBytes(tensor) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) + + // Q6_K blocks are 210 bytes. We pack them as-is, reading at byte level on GPU. + // Just pad the total to a multiple of 4 for uint32 alignment. + val numUInt32 = (tensor.numBytes.toInt + 3) / 4 + val result = Array.ofDim[Int](numUInt32) + + // Copy bytes into uint32 array + var i = 0 + while i < tensor.numBytes.toInt / 4 do + result(i) = buf.getInt(i * 4) + i += 1 + + // Handle remaining bytes (if any) + if tensor.numBytes.toInt % 4 != 0 then + var lastWord = 0 + var j = 0 + while j < tensor.numBytes.toInt % 4 do + lastWord |= (bytes(tensor.numBytes.toInt - tensor.numBytes.toInt % 4 + j) & 0xFF) << (j * 8) + j += 1 + result(i) = lastWord + + result + + /** Read and dequantize tensor to Float32. + * + * Supports F32, F16, Q4_K, and Q6_K quantization types. + */ + def readTensorDequantized(tensor: TensorInfo): Array[Float] = + val bytes = readTensorBytes(tensor) + tensor.quantType match + case QuantType.F32 => + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) + val result = Array.ofDim[Float](tensor.numElements.toInt) + buf.asFloatBuffer().get(result) + result + case QuantType.F16 => + Dequantize.dequantizeF16(bytes, tensor.numElements) + case QuantType.Q4_K => + Dequantize.dequantizeQ4K(bytes, tensor.numElements) + case QuantType.Q6_K => + Dequantize.dequantizeQ6K(bytes, tensor.numElements) + case other => + throw new UnsupportedOperationException(s"Dequantization not implemented for $other") + + /** Read F16 tensor as raw bytes without conversion. + * + * Returns the raw F16 bytes (2 bytes per element) for direct GPU upload. + * This avoids F32 conversion, saving 2x memory. 
+ */ + def readTensorF16Bytes(tensor: TensorInfo): Array[Byte] = + require(tensor.quantType == QuantType.F16, s"Expected F16 tensor, got ${tensor.quantType}") + readTensorBytes(tensor) + + /** Read GGUF file from path. */ + def read(path: Path): GGUFFile = + val raf = new RandomAccessFile(path.toFile, "r") + val channel = raf.getChannel + + // Read header + val headerBuf = ByteBuffer.allocate(24).order(ByteOrder.LITTLE_ENDIAN) + channel.read(headerBuf, 0) + headerBuf.flip() + + val magic = headerBuf.getInt + if magic != GGUF_MAGIC then + throw new IllegalArgumentException(s"Invalid GGUF magic: ${magic.toHexString}, expected ${GGUF_MAGIC.toHexString}") + + val version = headerBuf.getInt + if version != 2 && version != 3 then + throw new IllegalArgumentException(s"Unsupported GGUF version: $version") + + val tensorCount = headerBuf.getLong + val kvCount = headerBuf.getLong + + // Parse key-value pairs + var offset = 24L + val metadata = mutable.Map[String, MetaValue]() + + for _ <- 0L until kvCount do + val (key, value, newOffset) = readKV(channel, offset) + metadata(key) = value + offset = newOffset + + // Parse tensor info + val tensors = mutable.ArrayBuffer[TensorInfo]() + for _ <- 0L until tensorCount do + val (tensor, newOffset) = readTensorInfo(channel, offset) + tensors += tensor + offset = newOffset + + // Compute data offset with alignment + val alignment = metadata.get("general.alignment").collect { case MetaUInt32(v) => v }.getOrElse(DEFAULT_ALIGNMENT) + val padding = offset % alignment + val dataOffset = if padding == 0 then offset else offset + alignment - padding + + GGUFFile(version, metadata.toMap, tensors.toSeq, dataOffset, channel) + + private def readKV(channel: FileChannel, offset: Long): (String, MetaValue, Long) = + var pos = offset + + // Read key (string) + val (key, keyEndPos) = readString(channel, pos) + pos = keyEndPos + + // Read value type + val typeBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(typeBuf, pos) + typeBuf.flip() + val valueTypeId = typeBuf.getInt + pos += 4 + + val valueType = ValueType.fromId(valueTypeId) + val (value, valueEndPos) = readValue(channel, pos, valueType) + + (key, value, valueEndPos) + + private def readValue(channel: FileChannel, offset: Long, valueType: ValueType): (MetaValue, Long) = + valueType match + case ValueType.UINT8 => + val buf = ByteBuffer.allocate(1) + channel.read(buf, offset) + (MetaUInt8(buf.get(0)), offset + 1) + + case ValueType.INT8 => + val buf = ByteBuffer.allocate(1) + channel.read(buf, offset) + (MetaInt8(buf.get(0)), offset + 1) + + case ValueType.UINT16 => + val buf = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaUInt16(buf.getShort(0)), offset + 2) + + case ValueType.INT16 => + val buf = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaInt16(buf.getShort(0)), offset + 2) + + case ValueType.UINT32 => + val buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaUInt32(buf.getInt(0)), offset + 4) + + case ValueType.INT32 => + val buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaInt32(buf.getInt(0)), offset + 4) + + case ValueType.FLOAT32 => + val buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaFloat32(buf.getFloat(0)), offset + 4) + + case ValueType.BOOL => + val buf = ByteBuffer.allocate(1) + channel.read(buf, offset) + (MetaBool(buf.get(0) != 0), offset + 1) + + case 
ValueType.STRING => + val (str, endPos) = readString(channel, offset) + (MetaString(str), endPos) + + case ValueType.UINT64 => + val buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaUInt64(buf.getLong(0)), offset + 8) + + case ValueType.INT64 => + val buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaInt64(buf.getLong(0)), offset + 8) + + case ValueType.FLOAT64 => + val buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaFloat64(buf.getDouble(0)), offset + 8) + + case ValueType.ARRAY => + var pos = offset + // Read array element type + val typeBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(typeBuf, pos) + typeBuf.flip() + val elemTypeId = typeBuf.getInt + pos += 4 + + // Read array length + val lenBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(lenBuf, pos) + lenBuf.flip() + val arrayLen = lenBuf.getLong + pos += 8 + + val elemType = ValueType.fromId(elemTypeId) + val values = mutable.ArrayBuffer[MetaValue]() + + for _ <- 0L until arrayLen do + val (value, endPos) = readValue(channel, pos, elemType) + values += value + pos = endPos + + (MetaArray(values.toSeq), pos) + + private def readString(channel: FileChannel, offset: Long): (String, Long) = + // Read string length (uint64) + val lenBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(lenBuf, offset) + lenBuf.flip() + val strLen = lenBuf.getLong.toInt + + // Read string bytes + val strBuf = ByteBuffer.allocate(strLen) + channel.read(strBuf, offset + 8) + strBuf.flip() + val bytes = Array.ofDim[Byte](strLen) + strBuf.get(bytes) + + (new String(bytes, "UTF-8"), offset + 8 + strLen) + + private def readTensorInfo(channel: FileChannel, offset: Long): (TensorInfo, Long) = + var pos = offset + + // Read tensor name + val (name, nameEndPos) = readString(channel, pos) + pos = nameEndPos + + // Read number of dimensions + val dimBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(dimBuf, pos) + dimBuf.flip() + val nDims = dimBuf.getInt + pos += 4 + + // Read dimensions + val shape = Array.ofDim[Long](nDims) + val shapeBuf = ByteBuffer.allocate(8 * nDims).order(ByteOrder.LITTLE_ENDIAN) + channel.read(shapeBuf, pos) + shapeBuf.flip() + for i <- 0 until nDims do + shape(i) = shapeBuf.getLong + pos += 8 * nDims + + // Read quant type + val qtBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(qtBuf, pos) + qtBuf.flip() + val quantTypeId = qtBuf.getInt + pos += 4 + + // Read tensor data offset + val offsetBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(offsetBuf, pos) + offsetBuf.flip() + val tensorOffset = offsetBuf.getLong + pos += 8 + + val quantType = QuantType.fromId(quantTypeId) + (TensorInfo(name, shape, quantType, tensorOffset), pos) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/CPUInference.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/CPUInference.scala new file mode 100644 index 00000000..8afc7cb7 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/CPUInference.scala @@ -0,0 +1,232 @@ +package io.computenode.cyfra.llama.inference + +import io.computenode.cyfra.llama.model.LlamaConfig + +/** CPU-based Llama inference implementation. + * + * This is extracted from LlamaInference to allow F16 pipeline to use + * CPU inference without circular dependencies. 
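+ *
+ * Typical use (sketch): forwardCPU(tokens, tokenEmbed, layers, outputNorm, output, config)
+ * with dequantized F32 weights returns the vocab-sized logits for the last token position,
+ * handy as a reference when validating GPU pipeline output.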
+ */ +object CPUInference: + + case class LayerWeights( + attnNorm: Array[Float], + wq: Array[Float], + wk: Array[Float], + wv: Array[Float], + wo: Array[Float], + ffnNorm: Array[Float], + ffnGate: Array[Float], + ffnUp: Array[Float], + ffnDown: Array[Float], + ) + + /** Run forward pass on CPU. Returns logits for the last token position. */ + def forwardCPU( + tokens: Array[Int], + tokenEmbed: Array[Float], + layers: Seq[LayerWeights], + outputNorm: Array[Float], + output: Array[Float], + config: LlamaConfig, + ): Array[Float] = + val T = tokens.length + val C = config.hiddenSize + val NH = config.numAttentionHeads + val NKV = config.numKeyValueHeads + val HS = config.headSize + val FFN = config.intermediateSize + val V = config.vocabSize + + // Embedding lookup + var hidden = Array.ofDim[Float](T * C) + for t <- 0 until T do + val tokenId = tokens(t) + System.arraycopy(tokenEmbed, tokenId * C, hidden, t * C, C) + + // Process each layer + for layer <- layers do + val residual = hidden.clone() + + // Attention RMSNorm + hidden = rmsNorm(hidden, layer.attnNorm, T, C, config.rmsNormEps) + + // Q, K, V projections + val q = matmul(hidden, layer.wq, T, C, C) + val k = matmul(hidden, layer.wk, T, C, NKV * HS) + val v = matmul(hidden, layer.wv, T, C, NKV * HS) + + // RoPE + applyRoPE(q, T, NH, HS, config.ropeTheta) + applyRoPE(k, T, NKV, HS, config.ropeTheta) + + // Attention + val attnOut = attention(q, k, v, T, NH, NKV, HS) + + // Output projection + val attnProj = matmul(attnOut, layer.wo, T, C, C) + + // Residual connection + for i <- 0 until T * C do + hidden(i) = residual(i) + attnProj(i) + + val residual2 = hidden.clone() + + // FFN RMSNorm + hidden = rmsNorm(hidden, layer.ffnNorm, T, C, config.rmsNormEps) + + // FFN + val gate = matmul(hidden, layer.ffnGate, T, C, FFN) + val up = matmul(hidden, layer.ffnUp, T, C, FFN) + + // SwiGLU activation + val ffnHidden = Array.ofDim[Float](T * FFN) + for i <- 0 until T * FFN do + val g = gate(i) + val u = up(i) + val silu = g / (1.0f + math.exp(-g).toFloat) + ffnHidden(i) = silu * u + + // FFN down projection + val ffnOut = matmul(ffnHidden, layer.ffnDown, T, FFN, C) + + // Residual connection + for i <- 0 until T * C do + hidden(i) = residual2(i) + ffnOut(i) + + // Final layer norm + hidden = rmsNorm(hidden, outputNorm, T, C, config.rmsNormEps) + + // Output projection - only last token + val lastHidden = hidden.slice((T - 1) * C, T * C) + val logits = Array.ofDim[Float](V) + for i <- 0 until V do + var sum = 0.0f + for j <- 0 until C do + sum += lastHidden(j) * output(i * C + j) + logits(i) = sum + + logits + + /** RMS normalization. */ + private def rmsNorm( + input: Array[Float], + weight: Array[Float], + numRows: Int, + rowSize: Int, + eps: Double, + ): Array[Float] = + val output = Array.ofDim[Float](input.length) + for row <- 0 until numRows do + val offset = row * rowSize + + // Compute RMS + var sumSq = 0.0 + for i <- 0 until rowSize do + val x = input(offset + i) + sumSq += x * x + val rms = math.sqrt(sumSq / rowSize + eps) + val scale = 1.0 / rms + + // Normalize and apply weight + for i <- 0 until rowSize do + output(offset + i) = (input(offset + i) * scale * weight(i)).toFloat + + output + + /** Matrix multiplication: output = input @ weight. 
*/ + private def matmul( + input: Array[Float], + weight: Array[Float], + batchSize: Int, + inFeatures: Int, + outFeatures: Int, + ): Array[Float] = + val output = Array.ofDim[Float](batchSize * outFeatures) + for b <- 0 until batchSize do + for i <- 0 until outFeatures do + var sum = 0.0f + for j <- 0 until inFeatures do + sum += input(b * inFeatures + j) * weight(i * inFeatures + j) + output(b * outFeatures + i) = sum + output + + /** Apply rotary position embeddings (RoPE). */ + private def applyRoPE( + tensor: Array[Float], + seqLen: Int, + numHeads: Int, + headSize: Int, + theta: Double, + ): Unit = + for pos <- 0 until seqLen do + for head <- 0 until numHeads do + val offset = pos * numHeads * headSize + head * headSize + var i = 0 + while i < headSize do + val freq = 1.0 / math.pow(theta, (2 * (i / 2)).toDouble / headSize) + val angle = pos * freq + val cosA = math.cos(angle).toFloat + val sinA = math.sin(angle).toFloat + + val x = tensor(offset + i) + val y = tensor(offset + i + 1) + + tensor(offset + i) = x * cosA - y * sinA + tensor(offset + i + 1) = x * sinA + y * cosA + + i += 2 + + /** Multi-head attention with grouped-query attention (GQA). */ + private def attention( + q: Array[Float], + k: Array[Float], + v: Array[Float], + seqLen: Int, + numHeads: Int, + numKvHeads: Int, + headSize: Int, + ): Array[Float] = + val output = Array.ofDim[Float](seqLen * numHeads * headSize) + val scale = 1.0f / math.sqrt(headSize).toFloat + val gqaRatio = numHeads / numKvHeads + + for pos <- 0 until seqLen do + for head <- 0 until numHeads do + val kvHead = head / gqaRatio + val qOffset = pos * numHeads * headSize + head * headSize + val outOffset = qOffset + + // Compute attention weights for all positions + val scores = Array.ofDim[Float](pos + 1) + for kPos <- 0 to pos do + val kOffset = kPos * numKvHeads * headSize + kvHead * headSize + + // Dot product + var sum = 0.0f + for d <- 0 until headSize do + sum += q(qOffset + d) * k(kOffset + d) + + scores(kPos) = sum * scale + + // Softmax + val maxScore = scores.max + var sumExp = 0.0f + for i <- 0 until scores.length do + scores(i) = math.exp(scores(i) - maxScore).toFloat + sumExp += scores(i) + + for i <- 0 until scores.length do + scores(i) /= sumExp + + // Weighted sum of values + for d <- 0 until headSize do + var sum = 0.0f + for kPos <- 0 to pos do + val vOffset = kPos * numKvHeads * headSize + kvHead * headSize + sum += scores(kPos) * v(vOffset + d) + output(outOffset + d) = sum + + output + +end CPUInference diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala new file mode 100644 index 00000000..01e23aa6 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala @@ -0,0 +1,186 @@ +package io.computenode.cyfra.llama.inference + +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.llama.gguf.GGUFReader +import io.computenode.cyfra.llama.gguf.GGUFReader.{QuantType, TensorInfo} +import io.computenode.cyfra.llama.model.{LlamaConfig, LlamaModel} +import io.computenode.cyfra.llama.pipeline.{LlamaF32Pipeline, LlamaF16Pipeline} +import io.computenode.cyfra.llama.util.Logger +import io.computenode.cyfra.runtime.VkCyfraRuntime + +/** Llama inference engine. + * + * Loads weights from GGUF and runs the forward pass on GPU. 
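+ * Weight loading is lazy, so only the pipeline variant that is actually requested reads
+ * its tensors from the GGUF file.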
+ * + * Supports two pipeline modes: + * - KVCachedPipeline: For quantized (Q4_K/Q6_K) models + * - F16KVCachedPipeline: For F16-native models (like Llama 3.2) + * + * @param model The loaded Llama model + * @param maxT Maximum sequence length for the pipeline + * @param useQuantized If true, load quantized weights for KVCachedPipeline + */ +class LlamaInference(model: LlamaModel, maxT: Int = 1, useQuantized: Boolean = false)(using runtime: VkCyfraRuntime): + val config: LlamaConfig = model.config + + private lazy val allWeightsAreF16: Boolean = + val f32Tensors = model.gguf.tensors.filter(_.quantType == QuantType.F32) + val f16Tensors = model.gguf.tensors.filter(_.quantType == QuantType.F16) + Logger.debug(s"Model: ${f16Tensors.length} F16, ${f32Tensors.length} F32 tensors") + val hasF16Weights = f16Tensors.exists(t => t.name.contains("weight")) + val f32WeightMatrices = f32Tensors.filter(t => + t.name.contains("attn_q") || t.name.contains("attn_k") || t.name.contains("attn_v") || + t.name.contains("attn_output") || t.name.contains("ffn_gate") || t.name.contains("ffn_up") || + t.name.contains("ffn_down") || t.name.contains("token_embd") || t.name == "output.weight" + ) + hasF16Weights && f32WeightMatrices.isEmpty + + /** Check if weights are a mix of Q4_K and Q6_K (common in TinyLlama). */ + private lazy val hasMixedQuantization: Boolean = + val hasQ4K = model.gguf.tensors.exists(_.quantType == QuantType.Q4_K) + val hasQ6K = model.gguf.tensors.exists(_.quantType == QuantType.Q6_K) + hasQ4K && hasQ6K + + // Mixed quantization weights for KVCachedPipeline + private lazy val mixedQuantWeights = if useQuantized && hasMixedQuantization then Some(loadMixedQuantWeights()) else None + + private lazy val f32KVCachedPipeline: LlamaF32Pipeline.F32KVCachedPipeline = + require(mixedQuantWeights.isDefined, "Mixed quant weights not loaded. Set useQuantized=true.") + new LlamaF32Pipeline.F32KVCachedPipeline(mixedQuantWeights.get, config, maxT) + + // F16-Native Pipeline for F16 models (KV-cached only) + private lazy val f16Weights = if allWeightsAreF16 then Some(loadF16Weights()) else None + + // F16-Native KV Cached Pipeline with Vec4 optimizations (4x weight bandwidth!) + private lazy val f16KVCachedPipeline: LlamaF16Pipeline.F16KVCachedPipeline = + require(f16Weights.isDefined, "F16 KV pipeline requires F16 weights.") + new LlamaF16Pipeline.F16KVCachedPipeline(f16Weights.get, config, maxT) + + /** Get the F32 KV-cached pipeline for quantized models. + * + * Uses Q4_K/Q6_K quantized weights with on-GPU dequantization. + */ + def getF32KVCachedPipeline: LlamaF32Pipeline.F32KVCachedPipeline = + require(useQuantized && hasMixedQuantization, "F32 KV cache requires quantized weights.") + f32KVCachedPipeline + + /** Get the F16-native KV cached pipeline for efficient incremental inference. + * + * This pipeline uses KV caching for O(1) per-token inference: + * - Prefill: Process all prompt tokens at once + * - Decode: Process 1 token at a time, attend to full KV cache + * + * Uses Vec4-optimized matmuls for 4x weight memory bandwidth. + * Requires all weights to be F16 quantized. + * Requires dimensions (C, kvSize, FFN) to be divisible by 4. + */ + def getF16KVCachedPipeline: LlamaF16Pipeline.F16KVCachedPipeline = + require(f16Weights.isDefined, "F16 KV pipeline requires F16 weights. 
Check that model uses F16 quantization.") + f16KVCachedPipeline + + private def loadMixedQuantWeights(): LlamaF32Pipeline.MixedQuantModelWeights = + Logger.info(s"Loading mixed-quant weights (${config.numHiddenLayers} layers)...") + val startTime = System.currentTimeMillis() + + val tokenEmbed = model.gguf.readTensorDequantized(model.getTensor(LlamaModel.TensorNames.tokenEmbed).get) + val outputNorm = model.gguf.readTensorDequantized(model.getTensor(LlamaModel.TensorNames.outputNorm).get) + val output = model.gguf.readTensorDequantized(model.getTensor(LlamaModel.TensorNames.output).get) + + def readQuantized(tensor: GGUFReader.TensorInfo): (Array[Int], LlamaF32Pipeline.QuantWeightType) = + tensor.quantType match + case QuantType.Q4_K => (model.gguf.readTensorQ4KAsUInt32(tensor), LlamaF32Pipeline.Q4K) + case QuantType.Q6_K => (model.gguf.readTensorQ6KAsUInt32(tensor), LlamaF32Pipeline.Q6K) + case other => throw new IllegalArgumentException(s"Unsupported quantization type: $other") + + val layers = (0 until config.numHiddenLayers).map: l => + val attnNorm = model.gguf.readTensorDequantized(model.getTensor(LlamaModel.TensorNames.attnNorm(l)).get) + val ffnNorm = model.gguf.readTensorDequantized(model.getTensor(LlamaModel.TensorNames.ffnNorm(l)).get) + val (wq, wqType) = readQuantized(model.getTensor(LlamaModel.TensorNames.attnQ(l)).get) + val (wk, wkType) = readQuantized(model.getTensor(LlamaModel.TensorNames.attnK(l)).get) + val (wv, wvType) = readQuantized(model.getTensor(LlamaModel.TensorNames.attnV(l)).get) + val (wo, woType) = readQuantized(model.getTensor(LlamaModel.TensorNames.attnOutput(l)).get) + val (ffnGate, ffnGateType) = readQuantized(model.getTensor(LlamaModel.TensorNames.ffnGate(l)).get) + val (ffnUp, ffnUpType) = readQuantized(model.getTensor(LlamaModel.TensorNames.ffnUp(l)).get) + val (ffnDown, ffnDownType) = readQuantized(model.getTensor(LlamaModel.TensorNames.ffnDown(l)).get) + LlamaF32Pipeline.MixedQuantLayerWeights( + attnNorm = attnNorm, wq = wq, wqType = wqType, wk = wk, wkType = wkType, + wv = wv, wvType = wvType, wo = wo, woType = woType, ffnNorm = ffnNorm, + ffnGate = ffnGate, ffnGateType = ffnGateType, ffnUp = ffnUp, ffnUpType = ffnUpType, + ffnDown = ffnDown, ffnDownType = ffnDownType, + ) + + val elapsed = System.currentTimeMillis() - startTime + val totalMB = layers.map(l => l.wq.length + l.wk.length + l.wv.length + l.wo.length + + l.ffnGate.length + l.ffnUp.length + l.ffnDown.length).sum * 4 / 1024 / 1024 + Logger.info(s"Mixed-quant weights loaded: ${elapsed}ms, ${totalMB}MB") + + LlamaF32Pipeline.MixedQuantModelWeights(tokenEmbed, layers, outputNorm, output) + + /** Read tensor as F16 bytes, converting F32 to F16 if needed. */ + private def readAsF16Bytes(tensor: TensorInfo): Array[Byte] = + if tensor.quantType == QuantType.F16 then + model.gguf.readTensorF16Bytes(tensor) + else if tensor.quantType == QuantType.F32 then + val f32Array = model.gguf.readTensorDequantized(tensor) + val f16Bytes = new Array[Byte](f32Array.length * 2) + val buf = java.nio.ByteBuffer.wrap(f16Bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN) + for (f32Val, idx) <- f32Array.zipWithIndex do + val f16Bits = floatToFloat16Bits(f32Val) + buf.putShort(idx * 2, f16Bits.toShort) + f16Bytes + else + throw new IllegalArgumentException(s"Cannot convert ${tensor.quantType} to F16") + + /** Convert F32 to F16 bits (IEEE 754 half precision). 
*/ + private def floatToFloat16Bits(value: Float): Int = + val bits = java.lang.Float.floatToRawIntBits(value) + val sign = (bits >> 31) & 0x1 + val exp = (bits >> 23) & 0xFF + val frac = bits & 0x7FFFFF + + if exp == 0xFF then + return (sign << 15) | 0x7C00 | (if frac != 0 then 1 else 0) + + if exp == 0 && frac == 0 then + return sign << 15 + + val f16Exp = math.max(0, math.min(31, exp - 127 + 15)) + val f16Frac = frac >> 13 + (sign << 15) | (f16Exp << 10) | f16Frac + + private def loadF16Weights(): LlamaF16Pipeline.F16ModelWeights = + Logger.info(s"Loading F16 weights (${config.numHiddenLayers} layers)...") + val startTime = System.currentTimeMillis() + + val tokenEmbed = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.tokenEmbed).get) + val outputNorm = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.outputNorm).get) + val output = model.getTensor(LlamaModel.TensorNames.output) match + case Some(tensor) => readAsF16Bytes(tensor) + case None => + Logger.debug("Using tied embeddings (no output.weight)") + tokenEmbed + + val layers = (0 until config.numHiddenLayers).map: l => + LlamaF16Pipeline.F16LayerWeights( + attnNorm = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnNorm(l)).get), + wq = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnQ(l)).get), + wk = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnK(l)).get), + wv = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnV(l)).get), + wo = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnOutput(l)).get), + ffnNorm = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnNorm(l)).get), + ffnGate = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnGate(l)).get), + ffnUp = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnUp(l)).get), + ffnDown = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnDown(l)).get), + ) + + val elapsed = System.currentTimeMillis() - startTime + val totalMB = (tokenEmbed.length + outputNorm.length + output.length + + layers.map(l => l.attnNorm.length + l.wq.length + l.wk.length + l.wv.length + l.wo.length + + l.ffnNorm.length + l.ffnGate.length + l.ffnUp.length + l.ffnDown.length).sum) / 1024 / 1024 + Logger.info(s"F16 weights loaded: ${elapsed}ms, ${totalMB}MB") + + LlamaF16Pipeline.F16ModelWeights(tokenEmbed, layers, outputNorm, output) + +object LlamaInference: + def apply(model: LlamaModel)(using runtime: VkCyfraRuntime): LlamaInference = + new LlamaInference(model) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala new file mode 100644 index 00000000..ff80c9bb --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala @@ -0,0 +1,119 @@ +package io.computenode.cyfra.llama.model + +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct + +/** Llama model configuration. 
+ * + * Based on the Llama 2 / Llama 3 architecture with: + * - RMSNorm instead of LayerNorm + * - SiLU activation in MLP + * - Rotary Position Embeddings (RoPE) + * - Grouped Query Attention (GQA) + * - SwiGLU MLP structure + */ +case class LlamaConfig( + hiddenSize: Int, // Model dimension (d_model) + intermediateSize: Int, // MLP hidden dimension (usually ~2.7x hidden) + numAttentionHeads: Int, // Query heads + numKeyValueHeads: Int, // Key/Value heads (for GQA) + numHiddenLayers: Int, // Number of transformer blocks + vocabSize: Int, // Vocabulary size + maxPositionEmbeddings: Int, // Max context length + rmsNormEps: Float = 1e-6f, // RMSNorm epsilon + ropeTheta: Float = 10000.0f, // RoPE base frequency + bos_token_id: Int = 1, // Beginning of sequence token + eos_token_id: Int = 2, // End of sequence token +): + def headSize: Int = hiddenSize / numAttentionHeads + def kvHeadSize: Int = hiddenSize / numKeyValueHeads + def gqaRatio: Int = numAttentionHeads / numKeyValueHeads // GQA ratio (1 = MHA, >1 = GQA) + + /** Total parameter count estimate (weights only, no embeddings counted separately) */ + def numParameters: Long = + // Embeddings + val embedParams = vocabSize.toLong * hiddenSize + // Per-layer params + val qkvParams = hiddenSize * (hiddenSize + 2 * (hiddenSize * numKeyValueHeads / numAttentionHeads)) + val outputParams = hiddenSize * hiddenSize + val mlpParams = 3 * hiddenSize * intermediateSize // gate, up, down projections + val normParams = 2 * hiddenSize // 2 RMSNorms per layer + val perLayerParams = qkvParams + outputParams + mlpParams + normParams + // Total + embedParams + numHiddenLayers * perLayerParams + hiddenSize + embedParams // final norm + output proj + +object LlamaConfig: + /** TinyLlama 1.1B configuration */ + val TinyLlama_1B: LlamaConfig = LlamaConfig( + hiddenSize = 2048, + intermediateSize = 5632, + numAttentionHeads = 32, + numKeyValueHeads = 4, + numHiddenLayers = 22, + vocabSize = 32000, + maxPositionEmbeddings = 2048, + ) + + /** Llama 2 7B configuration */ + val Llama2_7B: LlamaConfig = LlamaConfig( + hiddenSize = 4096, + intermediateSize = 11008, + numAttentionHeads = 32, + numKeyValueHeads = 32, // MHA (not GQA) + numHiddenLayers = 32, + vocabSize = 32000, + maxPositionEmbeddings = 4096, + ) + + /** Llama 2 13B configuration */ + val Llama2_13B: LlamaConfig = LlamaConfig( + hiddenSize = 5120, + intermediateSize = 13824, + numAttentionHeads = 40, + numKeyValueHeads = 40, + numHiddenLayers = 40, + vocabSize = 32000, + maxPositionEmbeddings = 4096, + ) + + /** Llama 3 8B configuration */ + val Llama3_8B: LlamaConfig = LlamaConfig( + hiddenSize = 4096, + intermediateSize = 14336, + numAttentionHeads = 32, + numKeyValueHeads = 8, // GQA with ratio 4 + numHiddenLayers = 32, + vocabSize = 128256, + maxPositionEmbeddings = 8192, + ropeTheta = 500000.0f, + ) + +/** GPU-side parameters for Llama operations */ +case class LlamaParams( + B: Int32, // Batch size + T: Int32, // Sequence length (current position for generation) + C: Int32, // Hidden size (channels) + NH: Int32, // Number of attention heads + NKV: Int32, // Number of key-value heads + HS: Int32, // Head size + eps: Float32, // RMSNorm epsilon +) extends GStruct[LlamaParams] + +/** GPU-side parameters for RoPE */ +case class RoPEParams( + headSize: Int32, + maxSeqLen: Int32, + theta: Float32, + position: Int32, // Current position in sequence +) extends GStruct[RoPEParams] + +/** GPU-side parameters for flash attention */ +case class FlashAttnParams( + B: Int32, // Batch size + T: Int32, // Current 
sequence length + maxT: Int32, // Maximum sequence length (for KV cache) + NH: Int32, // Number of query heads + NKV: Int32, // Number of KV heads + HS: Int32, // Head size + scale: Float32, // 1/sqrt(head_size) +) extends GStruct[FlashAttnParams] diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala new file mode 100644 index 00000000..6c34cc7c --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala @@ -0,0 +1,89 @@ +package io.computenode.cyfra.llama.model + +import io.computenode.cyfra.llama.gguf.GGUFReader +import io.computenode.cyfra.llama.gguf.GGUFReader.* +import io.computenode.cyfra.llama.util.Logger +import java.nio.file.Path + +/** Llama model loaded from GGUF file. + * + * Contains model configuration and weight tensors. + */ +case class LlamaModel( + config: LlamaConfig, + gguf: GGUFFile, +): + /** Get a weight tensor by name. */ + def getTensor(name: String): Option[TensorInfo] = gguf.getTensor(name) + + /** Read a weight tensor as Float32 array (only for F32 tensors). */ + def readWeightF32(name: String): Array[Float] = + gguf.getTensor(name) match + case Some(tensor) => gguf.readTensorF32(tensor) + case None => throw new IllegalArgumentException(s"Tensor not found: $name") + + /** Read raw tensor bytes (for quantized tensors). */ + def readWeightBytes(name: String): Array[Byte] = + gguf.getTensor(name) match + case Some(tensor) => gguf.readTensorBytes(tensor) + case None => throw new IllegalArgumentException(s"Tensor not found: $name") + + /** Close the underlying file. */ + def close(): Unit = gguf.close() + + /** List all tensor names in the model. */ + def tensorNames: Seq[String] = gguf.tensors.map(_.name) + + /** Get model architecture name. */ + def architecture: String = gguf.getString("general.architecture").getOrElse("unknown") + + /** Get model name. */ + def name: String = gguf.getString("general.name").getOrElse("unknown") + + /** Log model info at INFO level. */ + def logInfo(): Unit = + Logger.info(s"Model: $name, arch=$architecture, ${gguf.tensors.size} tensors") + Logger.info(s"Config: ${config.hiddenSize}d, ${config.numHiddenLayers}L, ${config.numAttentionHeads}H, vocab=${config.vocabSize}") + +object LlamaModel: + /** Load Llama model from GGUF file. + * + * Extracts model configuration from GGUF metadata. + */ + def fromGGUF(path: Path): LlamaModel = + val gguf = GGUFReader.read(path) + + // Extract architecture-specific metadata prefix + val arch = gguf.getString("general.architecture").getOrElse("llama") + + // Extract model configuration from metadata + val config = LlamaConfig( + hiddenSize = gguf.getInt(s"$arch.embedding_length").getOrElse(4096), + intermediateSize = gguf.getInt(s"$arch.feed_forward_length").getOrElse(11008), + numAttentionHeads = gguf.getInt(s"$arch.attention.head_count").getOrElse(32), + numKeyValueHeads = gguf.getInt(s"$arch.attention.head_count_kv").getOrElse(32), + numHiddenLayers = gguf.getInt(s"$arch.block_count").getOrElse(32), + vocabSize = gguf.getInt(s"$arch.vocab_size").getOrElse(32000), + maxPositionEmbeddings = gguf.getInt(s"$arch.context_length").getOrElse(2048), + rmsNormEps = gguf.getFloat(s"$arch.attention.layer_norm_rms_epsilon").getOrElse(1e-6f), + ropeTheta = gguf.getFloat(s"$arch.rope.freq_base").getOrElse(10000.0f), + ) + + LlamaModel(config, gguf) + + /** Common Llama tensor name patterns. 
*/ + object TensorNames: + def tokenEmbed: String = "token_embd.weight" + def outputNorm: String = "output_norm.weight" + def output: String = "output.weight" + + def attnNorm(layer: Int): String = s"blk.$layer.attn_norm.weight" + def attnQ(layer: Int): String = s"blk.$layer.attn_q.weight" + def attnK(layer: Int): String = s"blk.$layer.attn_k.weight" + def attnV(layer: Int): String = s"blk.$layer.attn_v.weight" + def attnOutput(layer: Int): String = s"blk.$layer.attn_output.weight" + + def ffnNorm(layer: Int): String = s"blk.$layer.ffn_norm.weight" + def ffnGate(layer: Int): String = s"blk.$layer.ffn_gate.weight" + def ffnUp(layer: Int): String = s"blk.$layer.ffn_up.weight" + def ffnDown(layer: Int): String = s"blk.$layer.ffn_down.weight" diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala new file mode 100644 index 00000000..212c1db4 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala @@ -0,0 +1,865 @@ +package io.computenode.cyfra.llama.pipeline + +import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GCodec, GExecution, GProgram} +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.dsl.given_GStructConstructor_T +import io.computenode.cyfra.llama.model.LlamaConfig +import io.computenode.cyfra.llama.pipeline.PipelineUtils.* +import io.computenode.cyfra.llama.programs.* +import io.computenode.cyfra.llama.programs.f16.* +import io.computenode.cyfra.llama.util.Logger + +import java.nio.{ByteBuffer, ByteOrder} + +/** F16-Native Llama GPU Pipeline - all compute in half precision. 
+ * + * This pipeline keeps everything in F16 for maximum memory efficiency: + * - F16 weights loaded directly from GGUF (no conversion) + * - F16 compute throughout (matmul, attention, FFN) + * - Only final logits in F32 (for softmax stability) + * - 2x memory savings vs F32 pipeline + * + * Follows the same optimized pattern as LlamaF32Pipeline: + * - Single GExecution covering ALL operations + * - All layer parameters concatenated in single buffers + * - Layer offsets computed at compile time per-program + * - ByteBuffer-based I/O for efficiency + * - Only one runUnsafe at the end + */ +object LlamaF16Pipeline: + + // ============= F16 Pipeline Params ============= + + case class F16PipelineParams( + config: LlamaConfig, + B: Int, + T: Int, + startPos: Int = 0, + ): + def C: Int = config.hiddenSize + def NH: Int = config.numAttentionHeads + def NKV: Int = config.numKeyValueHeads + def headSize: Int = config.headSize + def FFN: Int = config.intermediateSize + def V: Int = config.vocabSize + def L: Int = config.numHiddenLayers + def kvSize: Int = NKV * headSize + + // ============= F16 Model Weights ============= + + /** F16 weights structure - keeps F16 bytes, no F32 conversion */ + case class F16LayerWeights( + attnNorm: Array[Byte], // F16 bytes + wq: Array[Byte], // F16 bytes + wk: Array[Byte], // F16 bytes + wv: Array[Byte], // F16 bytes + wo: Array[Byte], // F16 bytes + ffnNorm: Array[Byte], // F16 bytes + ffnGate: Array[Byte], // F16 bytes + ffnUp: Array[Byte], // F16 bytes + ffnDown: Array[Byte], // F16 bytes + ) + + case class F16ModelWeights( + tokenEmbed: Array[Byte], // F16 bytes + layers: Seq[F16LayerWeights], + outputNorm: Array[Byte], // F16 bytes + output: Array[Byte], // F16 bytes (or same as tokenEmbed for tied) + ) + + // ============= F16 KV Cache Pipeline Layout (Vec4 Optimized) ============= + + /** F16 KV Cache Pipeline layout with Vec4 weights for 4x memory bandwidth. + * + * Layout structure follows LlamaF32Pipeline.KVCachePipelineLayout: + * - kCache/vCache: (L, maxSeqLen, NKV, headSize) - persistent across forward calls + * - Other buffers: sized for current T (can be 1 for decode) + */ + case class F16KVCachePipelineLayout( + // Input + tokens: GBuffer[Int32], + + // Scalar weights (embedding lookup, norms need scalar indexing) + tokenEmbed: GBuffer[Float16], // (V, C) - scalar for embedding lookup + attnNorm: GBuffer[Float16], // (L, C) - scalar for RMSNorm + ffnNorm: GBuffer[Float16], // (L, C) - scalar for RMSNorm + outputNorm: GBuffer[Float16], // (C) - scalar for RMSNorm + + // Vec4 weights (matmul - 4x bandwidth!) 
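+ // Weight rows are packed four halfs per Vec4 element (e.g. C inputs become C/4 Vec4s),
+ // so a single load fetches four weights; the bytes are identical to the scalar F16 layout.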
+ wq: GBuffer[Vec4[Float16]], // (L, C, C/4) + wk: GBuffer[Vec4[Float16]], // (L, C, kvSize/4) + wv: GBuffer[Vec4[Float16]], // (L, C, kvSize/4) + wo: GBuffer[Vec4[Float16]], // (L, C, C/4) + ffnGate: GBuffer[Vec4[Float16]], // (L, FFN, C/4) + ffnUp: GBuffer[Vec4[Float16]], // (L, FFN, C/4) + ffnDown: GBuffer[Vec4[Float16]], // (L, C, FFN/4) + outputWeight: GBuffer[Vec4[Float16]], // (V, C/4) + + // KV Cache - F16 + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + + // Activations - scalar F16 (shared across programs) + hidden: GBuffer[Float16], + residual: GBuffer[Float16], + attnNormOut: GBuffer[Float16], + q: GBuffer[Float16], + k: GBuffer[Float16], + v: GBuffer[Float16], + qRoped: GBuffer[Float16], + kRoped: GBuffer[Float16], + attnOut: GBuffer[Float16], + ffnNormOut: GBuffer[Float16], + gate: GBuffer[Float16], + up: GBuffer[Float16], + ffnHidden: GBuffer[Float16], + ffnOut: GBuffer[Float16], + logits: GBuffer[Float32], + + attnParams: GUniform[AttentionParams], + ) derives Layout + + /** F16 Combined layout for generation with Vec4 optimized weights. */ + case class F16GenerationLayout( + // === Scalar weights === + tokenEmbed: GBuffer[Float16], + attnNorm: GBuffer[Float16], + ffnNorm: GBuffer[Float16], + outputNorm: GBuffer[Float16], + + // === Vec4 weights (4x bandwidth!) === + wq: GBuffer[Vec4[Float16]], + wk: GBuffer[Vec4[Float16]], + wv: GBuffer[Vec4[Float16]], + wo: GBuffer[Vec4[Float16]], + ffnGate: GBuffer[Vec4[Float16]], + ffnUp: GBuffer[Vec4[Float16]], + ffnDown: GBuffer[Vec4[Float16]], + outputWeight: GBuffer[Vec4[Float16]], + + // === KV Cache === + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + + // === Prefill I/O === + prefillTokens: GBuffer[Int32], + prefillHidden: GBuffer[Float16], + prefillResidual: GBuffer[Float16], + prefillAttnNormOut: GBuffer[Float16], + prefillQ: GBuffer[Float16], + prefillK: GBuffer[Float16], + prefillV: GBuffer[Float16], + prefillQRoped: GBuffer[Float16], + prefillKRoped: GBuffer[Float16], + prefillAttnOut: GBuffer[Float16], + prefillFfnNormOut: GBuffer[Float16], + prefillGate: GBuffer[Float16], + prefillUp: GBuffer[Float16], + prefillFfnHidden: GBuffer[Float16], + prefillFfnOut: GBuffer[Float16], + prefillLogits: GBuffer[Float32], + prefillAttnParams: GUniform[AttentionParams], + + // === Decode I/O === + decodeToken: GBuffer[Int32], + decodeHidden: GBuffer[Float16], + decodeResidual: GBuffer[Float16], + decodeAttnNormOut: GBuffer[Float16], + decodeQ: GBuffer[Float16], + decodeK: GBuffer[Float16], + decodeV: GBuffer[Float16], + decodeQRoped: GBuffer[Float16], + decodeKRoped: GBuffer[Float16], + decodeAttnOut: GBuffer[Float16], + decodeFfnNormOut: GBuffer[Float16], + decodeGate: GBuffer[Float16], + decodeUp: GBuffer[Float16], + decodeFfnHidden: GBuffer[Float16], + decodeFfnOut: GBuffer[Float16], + decodeLogits: GBuffer[Float32], + decodeAttnParams: GUniform[AttentionParams], + ) derives Layout: + def toPrefillLayout: F16KVCachePipelineLayout = F16KVCachePipelineLayout( + tokens = prefillTokens, + tokenEmbed = tokenEmbed, attnNorm = attnNorm, ffnNorm = ffnNorm, outputNorm = outputNorm, + wq = wq, wk = wk, wv = wv, wo = wo, + ffnGate = ffnGate, ffnUp = ffnUp, ffnDown = ffnDown, outputWeight = outputWeight, + kCache = kCache, vCache = vCache, + hidden = prefillHidden, residual = prefillResidual, attnNormOut = prefillAttnNormOut, + q = prefillQ, k = prefillK, v = prefillV, qRoped = prefillQRoped, kRoped = prefillKRoped, + attnOut = prefillAttnOut, ffnNormOut = prefillFfnNormOut, + gate = prefillGate, up = prefillUp, ffnHidden = 
prefillFfnHidden, ffnOut = prefillFfnOut, + logits = prefillLogits, attnParams = prefillAttnParams, + ) + + def toDecodeLayout: F16KVCachePipelineLayout = F16KVCachePipelineLayout( + tokens = decodeToken, + tokenEmbed = tokenEmbed, attnNorm = attnNorm, ffnNorm = ffnNorm, outputNorm = outputNorm, + wq = wq, wk = wk, wv = wv, wo = wo, + ffnGate = ffnGate, ffnUp = ffnUp, ffnDown = ffnDown, outputWeight = outputWeight, + kCache = kCache, vCache = vCache, + hidden = decodeHidden, residual = decodeResidual, attnNormOut = decodeAttnNormOut, + q = decodeQ, k = decodeK, v = decodeV, qRoped = decodeQRoped, kRoped = decodeKRoped, + attnOut = decodeAttnOut, ffnNormOut = decodeFfnNormOut, + gate = decodeGate, up = decodeUp, ffnHidden = decodeFfnHidden, ffnOut = decodeFfnOut, + logits = decodeLogits, attnParams = decodeAttnParams, + ) + + + // ============= F16 KV Cached Pipeline Build ============= + + /** Build F16 KV-cached pipeline with Vec4 optimized matmuls. + * + * Uses F16MatmulVecHybridProgram for 4x weight memory bandwidth. + */ + def buildF16KVCachedPipeline( + config: LlamaConfig, + B: Int, + T: Int, + maxSeqLen: Int, + ): GExecution[F16PipelineParams, F16KVCachePipelineLayout, F16KVCachePipelineLayout] = + val C = config.hiddenSize + val NH = config.numAttentionHeads + val NKV = config.numKeyValueHeads + val headSize = config.headSize + val FFN = config.intermediateSize + val V = config.vocabSize + val L = config.numHiddenLayers + val kvSize = NKV * headSize + val eps = config.rmsNormEps.toFloat + val theta = config.ropeTheta.toFloat + val startPos = maxSeqLen - T + + // Verify dimensions are divisible by 4 for Vec4 optimization + require(C % 4 == 0, s"hiddenSize ($C) must be divisible by 4") + require(kvSize % 4 == 0, s"kvSize ($kvSize) must be divisible by 4") + require(FFN % 4 == 0, s"intermediateSize ($FFN) must be divisible by 4") + + // Embedding + val embSizes = F16EmbeddingProgram.Sizes(B * T, C, V) + var pipeline = GExecution[F16PipelineParams, F16KVCachePipelineLayout]() + .addProgram(F16EmbeddingProgram.forward(embSizes))( + _ => embSizes, + l => F16EmbeddingProgram.ProgramLayout(l.tokens, l.tokenEmbed, l.hidden), + ) + + // Process each layer with Vec4 optimized matmuls + for layer <- 0 until L do + val normOffset = layer * C + val ffnNormOffset = layer * C + + // Vec4 weight offsets (in Vec4 units = elements / 4) + val wqOffsetVec4 = layer * C * (C / 4) + val wkOffsetVec4 = layer * C * (kvSize / 4) + val wvOffsetVec4 = layer * C * (kvSize / 4) + val woOffsetVec4 = layer * C * (C / 4) + val ffnGateOffsetVec4 = layer * FFN * (C / 4) + val ffnUpOffsetVec4 = layer * FFN * (C / 4) + val ffnDownOffsetVec4 = layer * C * (FFN / 4) + + val kvCacheLayerOffset = layer * maxSeqLen * kvSize + + val copySizes = F16CopyProgram.Sizes(B * T * C) + val attnNormSizes = F16RMSNormProgram.Sizes(B * T, C, eps, normOffset, L * C) + + // Vec4 matmul sizes + val qSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, C, wqOffsetVec4, L * C * (C / 4)) + val kSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, kvSize, wkOffsetVec4, L * C * (kvSize / 4)) + val vSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, kvSize, wvOffsetVec4, L * C * (kvSize / 4)) + val woSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, C, woOffsetVec4, L * C * (C / 4)) + + val ropeQSizes = F16RoPEProgram.Sizes(B, T, NH, headSize, theta) + val ropeKSizes = F16RoPEProgram.Sizes(B, T, NKV, headSize, theta) + val resSizes = F16ResidualAddProgram.Sizes(B * T * C) + val ffnNormSizes = F16RMSNormProgram.Sizes(B * T, C, eps, ffnNormOffset, L * 
C) + + // FFN Vec4 matmul sizes + val gateSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, FFN, ffnGateOffsetVec4, L * FFN * (C / 4)) + val upSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, FFN, ffnUpOffsetVec4, L * FFN * (C / 4)) + val swiGluSizes = F16SwiGLUProgram.Sizes(B * T * FFN) + val downSizes = F16MatmulVecHybridProgram.Sizes(B * T, FFN, C, ffnDownOffsetVec4, L * C * (FFN / 4)) + + // Save residual + pipeline = pipeline.addProgram(F16CopyProgram.forward(copySizes))( + _ => copySizes, + l => F16CopyProgram.ProgramLayout(l.hidden, l.residual), + ) + + // Attention norm + pipeline = pipeline.addProgram(F16RMSNormProgram.forward(attnNormSizes))( + _ => attnNormSizes, + l => F16RMSNormProgram.ProgramLayout(l.hidden, l.attnNorm, l.attnNormOut), + ) + + // Q, K, V projections with Vec4 weights + pipeline = pipeline + .addProgram(F16MatmulVecHybridProgram.forward(qSizes))( + _ => qSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.wq, l.attnNormOut, l.q), + ) + .addProgram(F16MatmulVecHybridProgram.forward(kSizes))( + _ => kSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.wk, l.attnNormOut, l.k), + ) + .addProgram(F16MatmulVecHybridProgram.forward(vSizes))( + _ => vSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.wv, l.attnNormOut, l.v), + ) + + // RoPE + pipeline = pipeline + .addProgram(F16RoPEProgram.forward(ropeQSizes))( + _ => ropeQSizes, + l => F16RoPEProgram.ProgramLayout(l.q, l.qRoped, l.attnParams), + ) + .addProgram(F16RoPEProgram.forward(ropeKSizes))( + _ => ropeKSizes, + l => F16RoPEProgram.ProgramLayout(l.k, l.kRoped, l.attnParams), + ) + + // KV Cache Write + val kvWriteKSizes = F16KVCacheWriteK.Sizes(B, T, NKV, headSize, maxSeqLen, layer, startPos, kvCacheLayerOffset, L) + val kvWriteVSizes = F16KVCacheWriteV.Sizes(B, T, NKV, headSize, maxSeqLen, layer, startPos, kvCacheLayerOffset, L) + + pipeline = pipeline + .addProgram(F16KVCacheWriteK.forward(kvWriteKSizes))( + _ => kvWriteKSizes, + l => F16KVCacheWriteK.ProgramLayout(l.kRoped, l.kCache, l.attnParams), + ) + .addProgram(F16KVCacheWriteV.forward(kvWriteVSizes))( + _ => kvWriteVSizes, + l => F16KVCacheWriteV.ProgramLayout(l.v, l.vCache, l.attnParams), + ) + + // Attention + val attnSizes = F16KVCachedAttention.Sizes(B, T, NH, NKV, headSize, startPos, kvCacheLayerOffset, kvCacheLayerOffset, L, maxSeqLen) + pipeline = pipeline.addProgram(F16KVCachedAttention.forward(attnSizes))( + _ => attnSizes, + l => F16KVCachedAttention.ProgramLayout(l.qRoped, l.kCache, l.vCache, l.attnOut, l.attnParams), + ) + + // Output projection with Vec4 weights + pipeline = pipeline + .addProgram(F16MatmulVecHybridProgram.forward(woSizes))( + _ => woSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.wo, l.attnOut, l.hidden), + ) + .addProgram(F16ResidualAddProgram.forward(resSizes))( + _ => resSizes, + l => F16ResidualAddProgram.ProgramLayout(l.residual, l.hidden, l.attnNormOut), + ) + + // FFN with Vec4 weights + pipeline = pipeline + .addProgram(F16CopyProgram.forward(copySizes))( + _ => copySizes, + l => F16CopyProgram.ProgramLayout(l.attnNormOut, l.residual), + ) + .addProgram(F16RMSNormProgram.forward(ffnNormSizes))( + _ => ffnNormSizes, + l => F16RMSNormProgram.ProgramLayout(l.attnNormOut, l.ffnNorm, l.ffnNormOut), + ) + .addProgram(F16MatmulVecHybridProgram.forward(gateSizes))( + _ => gateSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.ffnGate, l.ffnNormOut, l.gate), + ) + .addProgram(F16MatmulVecHybridProgram.forward(upSizes))( + _ => upSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.ffnUp, 
l.ffnNormOut, l.up), + ) + .addProgram(F16SwiGLUProgram.forward(swiGluSizes))( + _ => swiGluSizes, + l => F16SwiGLUProgram.ProgramLayout(l.gate, l.up, l.ffnHidden), + ) + .addProgram(F16MatmulVecHybridProgram.forward(downSizes))( + _ => downSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.ffnDown, l.ffnHidden, l.ffnOut), + ) + .addProgram(F16ResidualAddProgram.forward(resSizes))( + _ => resSizes, + l => F16ResidualAddProgram.ProgramLayout(l.residual, l.ffnOut, l.hidden), + ) + end for + + // Final norm and output projection with Vec4 weights + val finalNormSizes = F16RMSNormProgram.Sizes(B * T, C, eps, 0, C) + val logitsSizes = F16OutputVec4Program.Sizes(B * T, C, V) + + pipeline + .addProgram(F16RMSNormProgram.forward(finalNormSizes))( + _ => finalNormSizes, + l => F16RMSNormProgram.ProgramLayout(l.hidden, l.outputNorm, l.attnNormOut), + ) + .addProgram(F16OutputVec4Program.forward(logitsSizes))( + _ => logitsSizes, + l => F16OutputVec4Program.ProgramLayout(l.attnNormOut, l.outputWeight, l.logits), + ) + end buildF16KVCachedPipeline + + // ============= F16 KV Cached Pipeline Class ============= + + /** F16 KV-Cached Pipeline for fast incremental inference. + * + * Like LlamaF32Pipeline.F32KVCachedPipeline but all F16: + * - Prefill: Process all prompt tokens at once, fill KV cache from 0 to T-1 + * - Decode: Process 1 token at a time, append to KV cache at seqLen, attend to full cache + * + * This achieves O(1) complexity per generated token (vs O(T) without cache). + * + * Performance measurement: + * - Excludes buffer allocation and weight upload time + * - Measures GPU execution time (after read forces sync) + * - Reports separate prefill and generate tok/s + */ + class F16KVCachedPipeline( + weights: F16ModelWeights, + val config: LlamaConfig, + maxSeqLen: Int = F16KVCachedAttention.MAX_SEQ_LEN, + B: Int = 1, + )(using runtime: CyfraRuntime) extends LlamaPipeline: + require(maxSeqLen <= F16KVCachedAttention.MAX_SEQ_LEN, + s"maxSeqLen=$maxSeqLen exceeds F16KVCachedAttention.MAX_SEQ_LEN=${F16KVCachedAttention.MAX_SEQ_LEN}") + + private val C = config.hiddenSize + private val V = config.vocabSize + private val L = config.numHiddenLayers + private val NH = config.numAttentionHeads + private val NKV = config.numKeyValueHeads + private val headSize = config.headSize + private val FFN = config.intermediateSize + private val kvSize = NKV * headSize + + Logger.info(s"Uploading F16 weights: ${L} layers, ${V}×${C} vocab, maxSeqLen=$maxSeqLen") + + private val tokenEmbedBuf = allocateF16Buffer(V * C) + copyF16BytesToBuffer(weights.tokenEmbed, tokenEmbedBuf) + tokenEmbedBuf.rewind() + + private val attnNormBuf = allocateF16Buffer(L * C) + private val wqBuf = allocateF16Buffer(L * C * C) + private val wkBuf = allocateF16Buffer(L * C * kvSize) + private val wvBuf = allocateF16Buffer(L * C * kvSize) + private val woBuf = allocateF16Buffer(L * C * C) + private val ffnNormBuf = allocateF16Buffer(L * C) + private val ffnGateBuf = allocateF16Buffer(L * FFN * C) + private val ffnUpBuf = allocateF16Buffer(L * FFN * C) + private val ffnDownBuf = allocateF16Buffer(L * C * FFN) + + for (layer, layerIdx) <- weights.layers.zipWithIndex do + copyF16BytesToBuffer(layer.attnNorm, attnNormBuf, layerIdx * C * 2) + copyF16BytesToBuffer(layer.wq, wqBuf, layerIdx * C * C * 2) + copyF16BytesToBuffer(layer.wk, wkBuf, layerIdx * C * kvSize * 2) + copyF16BytesToBuffer(layer.wv, wvBuf, layerIdx * C * kvSize * 2) + copyF16BytesToBuffer(layer.wo, woBuf, layerIdx * C * C * 2) + copyF16BytesToBuffer(layer.ffnNorm, 
ffnNormBuf, layerIdx * C * 2) + copyF16BytesToBuffer(layer.ffnGate, ffnGateBuf, layerIdx * FFN * C * 2) + copyF16BytesToBuffer(layer.ffnUp, ffnUpBuf, layerIdx * FFN * C * 2) + copyF16BytesToBuffer(layer.ffnDown, ffnDownBuf, layerIdx * C * FFN * 2) + + attnNormBuf.rewind(); wqBuf.rewind(); wkBuf.rewind(); wvBuf.rewind(); woBuf.rewind() + ffnNormBuf.rewind(); ffnGateBuf.rewind(); ffnUpBuf.rewind(); ffnDownBuf.rewind() + + private val outputNormBuf = allocateF16Buffer(C) + copyF16BytesToBuffer(weights.outputNorm, outputNormBuf) + outputNormBuf.rewind() + + private val outputWeightBuf = allocateF16Buffer(V * C) + copyF16BytesToBuffer(weights.output, outputWeightBuf) + outputWeightBuf.rewind() + + Logger.info("F16 weights uploaded to GPU") + + // Pre-allocate decode buffers (reused across generate() calls) + private val decodeTokenBuf = allocateIntBuffer(B * 1) + private val decodeLogitsBuf = allocateF32Buffer(B * 1 * V) + private val decodeLogitsArr = new Array[Float](V) + private val attnParamsBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + + // Pipeline cache to avoid recompilation + private val pipelineCache = scala.collection.mutable.Map[(Int, Int), GExecution[F16PipelineParams, F16KVCachePipelineLayout, F16KVCachePipelineLayout]]() + + private def getOrBuildPipeline(T: Int, seqLen: Int): GExecution[F16PipelineParams, F16KVCachePipelineLayout, F16KVCachePipelineLayout] = + pipelineCache.getOrElseUpdate((T, seqLen), buildF16KVCachedPipeline(config, B, T, maxSeqLen)) + + // Current sequence length (updated after each forward) + private var currentSeqLen: Int = 0 + + /** Current position in sequence (for RoPE and masking). */ + def seqLen: Int = currentSeqLen + + /** Last generation statistics (null if generate() not called yet). */ + private var _lastStats: GenerationStats = null + def lastStats: Option[GenerationStats] = Option(_lastStats) + + /** Generate tokens with KV cache - OPTIMIZED version. + * + * Optimizations: + * - Pre-allocated buffers (no allocation during generation loop) + * - Single runUnsafe to keep KV cache on GPU + * - Early stopping via callback + * - Performance timing excludes setup, measures only GPU execution + * + * Architecture note: We use a single runUnsafe with F16GenerationLayout because: + * 1. KV cache must persist on GPU across decode steps + * 2. Weights should only be uploaded once + * 3. 
GBufferRegion deallocates on runUnsafe completion + * + * @param promptTokens Input prompt tokens + * @param maxNewTokens Maximum tokens to generate + * @param sampleFn Sampling function (logits => token) + * @param onToken Callback for each generated token + * @param stopTokens Set of tokens that stop generation + * @param reportStats If true, prints performance stats after generation + * @return Array of generated tokens (not including prompt) + */ + def generate( + promptTokens: Array[Int], + maxNewTokens: Int, + sampleFn: Array[Float] => Int, + onToken: Int => Unit = _ => (), + stopTokens: Set[Int] = Set.empty, + reportStats: Boolean = false, + ): Array[Int] = + require(promptTokens.length + maxNewTokens <= maxSeqLen, + s"Total sequence ${promptTokens.length + maxNewTokens} exceeds maxSeqLen=$maxSeqLen") + + currentSeqLen = 0 + val generatedTokens = scala.collection.mutable.ArrayBuffer[Int]() + val prefillT = promptTokens.length + + // Rewind all weight buffers + tokenEmbedBuf.rewind() + attnNormBuf.rewind() + wqBuf.rewind() + wkBuf.rewind() + wvBuf.rewind() + woBuf.rewind() + ffnNormBuf.rewind() + ffnGateBuf.rewind() + ffnUpBuf.rewind() + ffnDownBuf.rewind() + outputNormBuf.rewind() + outputWeightBuf.rewind() + + // Prepare prefill token buffer (this is prompt-size-dependent, allocated each call) + val prefillTokensBuf = allocateIntBuffer(B * prefillT) + prefillTokensBuf.asIntBuffer().put(promptTokens) + prefillTokensBuf.rewind() + val prefillLogitsBuf = allocateF32Buffer(B * prefillT * V) + val prefillLogitsArr = new Array[Float](prefillT * V) + + // Build pipelines (cached) + val prefillPipeline = getOrBuildPipeline(prefillT, prefillT) + val decodePipeline = getOrBuildPipeline(1, maxSeqLen) + val prefillParams = F16PipelineParams(config, B, prefillT, 0) + + // Create attention params buffers + val prefillAttnBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + prefillAttnBuf.putInt(prefillT) // seqLen + prefillAttnBuf.putInt(0) // startPos + prefillAttnBuf.flip() + + val decodeAttnBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + decodeAttnBuf.putInt(1) + decodeAttnBuf.putInt(0) + decodeAttnBuf.flip() + + // Timing accumulators + var prefillStartNs = 0L + var prefillEndNs = 0L + var decodeTimeNs = 0L + var shouldStop = false + + // Build execution region - single runUnsafe keeps KV cache alive + // Prefill phase + var region = GBufferRegion + .allocate[F16GenerationLayout] + .map: layout => + // Start timing just before GPU execution + prefillStartNs = System.nanoTime() + prefillPipeline.execute(prefillParams, layout.toPrefillLayout) + layout + .map: layout => + // Read logits (forces GPU sync) - timing checkpoint + layout.prefillLogits.read(prefillLogitsBuf) + prefillEndNs = System.nanoTime() + prefillLogitsBuf.rewind() + copyFromF32Buffer(prefillLogitsBuf, prefillLogitsArr) + + // Sample first token (NOT timed - this is CPU sampling) + val lastPosLogits = prefillLogitsArr.slice((prefillT - 1) * V, prefillT * V) + val firstToken = sampleFn(lastPosLogits) + generatedTokens += firstToken + onToken(firstToken) + currentSeqLen = prefillT + + // Check stop condition + if stopTokens.contains(firstToken) then + shouldStop = true + else + // Write first decode token + decodeTokenBuf.clear() + decodeTokenBuf.asIntBuffer().put(Array(firstToken)) + decodeTokenBuf.rewind() + layout.decodeToken.write(decodeTokenBuf) + + layout + + // Decode loop - build dynamically with actual early stopping + // We use a loop to add decode steps, checking shouldStop after each 
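+ // All maxNewTokens - 1 decode stages are appended to the region up front, since sampled
+ // token values are only known at execution time; once shouldStop is set, the remaining
+ // stages simply pass the layout through unchanged.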
+ var step = 0 + while step < maxNewTokens - 1 do + val stepIdx = step + val seqLen = prefillT + stepIdx + 1 + val startPos = seqLen - 1 + val decodeParams = F16PipelineParams(config, B, 1, startPos) + + region = region + .map: layout => + if shouldStop then + layout + else + // Update attention params + attnParamsBuf.clear() + attnParamsBuf.putInt(seqLen) + attnParamsBuf.putInt(startPos) + attnParamsBuf.flip() + layout.decodeAttnParams.asInstanceOf[io.computenode.cyfra.dsl.binding.GBinding[AttentionParams]].write(attnParamsBuf, 0) + + // Start decode timing + val stepStartNs = System.nanoTime() + decodePipeline.execute(decodeParams, layout.toDecodeLayout) + + // Read logits (forces GPU sync) + layout.decodeLogits.read(decodeLogitsBuf) + decodeTimeNs += (System.nanoTime() - stepStartNs) + decodeLogitsBuf.rewind() + copyFromF32Buffer(decodeLogitsBuf, decodeLogitsArr) + + // Sample next token (NOT timed) + val nextToken = sampleFn(decodeLogitsArr) + generatedTokens += nextToken + onToken(nextToken) + currentSeqLen += 1 + + // Check stop condition + if stopTokens.contains(nextToken) then + shouldStop = true + else + // Write next token for next iteration + decodeTokenBuf.clear() + decodeTokenBuf.asIntBuffer().put(Array(nextToken)) + decodeTokenBuf.rewind() + layout.decodeToken.write(decodeTokenBuf) + + layout + + step += 1 + end while + + // Execute everything - weights uploaded ONCE, KV cache stays on GPU + region.runUnsafe( + init = F16GenerationLayout( + // Scalar weights (embedding, norms) + tokenEmbed = GBuffer[Float16](tokenEmbedBuf), + attnNorm = GBuffer[Float16](attnNormBuf), + ffnNorm = GBuffer[Float16](ffnNormBuf), + outputNorm = GBuffer[Float16](outputNormBuf), + // Vec4 weights (matmul - same bytes, Vec4 SPIR-V type!) + wq = GBuffer[Vec4[Float16]](wqBuf), + wk = GBuffer[Vec4[Float16]](wkBuf), + wv = GBuffer[Vec4[Float16]](wvBuf), + wo = GBuffer[Vec4[Float16]](woBuf), + ffnGate = GBuffer[Vec4[Float16]](ffnGateBuf), + ffnUp = GBuffer[Vec4[Float16]](ffnUpBuf), + ffnDown = GBuffer[Vec4[Float16]](ffnDownBuf), + outputWeight = GBuffer[Vec4[Float16]](outputWeightBuf), + // KV cache (GPU-only, persists across prefill/decode) + kCache = GBuffer[Float16](L * maxSeqLen * kvSize), + vCache = GBuffer[Float16](L * maxSeqLen * kvSize), + // Prefill buffers (T = promptLen) + prefillTokens = GBuffer[Int32](prefillTokensBuf), + prefillHidden = GBuffer[Float16](B * prefillT * C), + prefillResidual = GBuffer[Float16](B * prefillT * C), + prefillAttnNormOut = GBuffer[Float16](B * prefillT * C), + prefillQ = GBuffer[Float16](B * prefillT * C), + prefillK = GBuffer[Float16](B * prefillT * kvSize), + prefillV = GBuffer[Float16](B * prefillT * kvSize), + prefillQRoped = GBuffer[Float16](B * prefillT * C), + prefillKRoped = GBuffer[Float16](B * prefillT * kvSize), + prefillAttnOut = GBuffer[Float16](B * prefillT * C), + prefillFfnNormOut = GBuffer[Float16](B * prefillT * C), + prefillGate = GBuffer[Float16](B * prefillT * FFN), + prefillUp = GBuffer[Float16](B * prefillT * FFN), + prefillFfnHidden = GBuffer[Float16](B * prefillT * FFN), + prefillFfnOut = GBuffer[Float16](B * prefillT * C), + prefillLogits = GBuffer[Float32](prefillLogitsBuf), + prefillAttnParams = GUniform[AttentionParams](prefillAttnBuf), + // Decode buffers (T = 1) - pre-allocated, reused + decodeToken = GBuffer[Int32](decodeTokenBuf), + decodeHidden = GBuffer[Float16](B * 1 * C), + decodeResidual = GBuffer[Float16](B * 1 * C), + decodeAttnNormOut = GBuffer[Float16](B * 1 * C), + decodeQ = GBuffer[Float16](B * 1 * C), + decodeK = 
GBuffer[Float16](B * 1 * kvSize), + decodeV = GBuffer[Float16](B * 1 * kvSize), + decodeQRoped = GBuffer[Float16](B * 1 * C), + decodeKRoped = GBuffer[Float16](B * 1 * kvSize), + decodeAttnOut = GBuffer[Float16](B * 1 * C), + decodeFfnNormOut = GBuffer[Float16](B * 1 * C), + decodeGate = GBuffer[Float16](B * 1 * FFN), + decodeUp = GBuffer[Float16](B * 1 * FFN), + decodeFfnHidden = GBuffer[Float16](B * 1 * FFN), + decodeFfnOut = GBuffer[Float16](B * 1 * C), + decodeLogits = GBuffer[Float32](decodeLogitsBuf), + decodeAttnParams = GUniform[AttentionParams](decodeAttnBuf), + ), + onDone = _ => (), + ) + + val prefillTimeMs = (prefillEndNs - prefillStartNs) / 1_000_000.0 + val decodeTimeMs = decodeTimeNs / 1_000_000.0 + val totalTimeMs = prefillTimeMs + decodeTimeMs + + _lastStats = GenerationStats( + promptTokens = prefillT, + generatedTokens = generatedTokens.length, + prefillTimeMs = prefillTimeMs, + decodeTimeMs = decodeTimeMs, + totalTimeMs = totalTimeMs, + ) + + if reportStats then + Logger.info(_lastStats.toString) + + generatedTokens.toArray + end generate + + /** Legacy API: Prefill (for backward compatibility). + * + * NOTE: This creates a new GPU allocation each call - use `generate()` for efficient inference. + */ + def prefill(tokens: Array[Int]): Array[Float] = + require(tokens.length <= maxSeqLen, s"Prompt length ${tokens.length} exceeds maxSeqLen=$maxSeqLen") + currentSeqLen = 0 + val logits = forwardWithKVCache(tokens, startPos = 0) + currentSeqLen = tokens.length + val lastPosLogits = new Array[Float](V) + System.arraycopy(logits, (tokens.length - 1) * V, lastPosLogits, 0, V) + lastPosLogits + + /** Legacy API: Decode single token (for backward compatibility). + * + * NOTE: This creates a new GPU allocation each call - use `generate()` for efficient inference. 
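+ *
+ * Example: given some `pipeline` and the next token id `tokenId`, `pipeline.decode(tokenId)` returns the logits for the next position and advances `seqLen` by one.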
+ */ + def decode(token: Int): Array[Float] = + require(currentSeqLen < maxSeqLen, s"Sequence length reached maxSeqLen=$maxSeqLen") + val logits = forwardWithKVCache(Array(token), startPos = currentSeqLen) + currentSeqLen += 1 + logits + + // Legacy KV cache buffers for prefill/decode API + private lazy val legacyKCacheBuf = allocateF16Buffer(L * maxSeqLen * kvSize) + private lazy val legacyVCacheBuf = allocateF16Buffer(L * maxSeqLen * kvSize) + + // Forward pass with KV cache (legacy API - has GPU-CPU roundtrip) + private def forwardWithKVCache(tokens: Array[Int], startPos: Int): Array[Float] = + val T = tokens.length + val seqLen = startPos + T // Total sequence length after this forward + + // Rewind weight buffers + tokenEmbedBuf.rewind() + attnNormBuf.rewind() + wqBuf.rewind() + wkBuf.rewind() + wvBuf.rewind() + woBuf.rewind() + ffnNormBuf.rewind() + ffnGateBuf.rewind() + ffnUpBuf.rewind() + ffnDownBuf.rewind() + outputNormBuf.rewind() + outputWeightBuf.rewind() + legacyKCacheBuf.rewind() + legacyVCacheBuf.rewind() + + // Token buffer + val tokensBuf = allocateIntBuffer(B * T) + tokensBuf.asIntBuffer().put(tokens) + tokensBuf.rewind() + + // Get cached pipeline for this T and seqLen + val pipeline = getOrBuildPipeline(T, seqLen) + + val pParams = F16PipelineParams(config, B, T, startPos) + + val logits = new Array[Float](B * T * V) + val logitsBuf = allocateF32Buffer(B * T * V) + + // Create AttentionParams ByteBuffer (seqLen + startPos) + val attnStartPos = seqLen - T + val attnParamsBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + attnParamsBuf.putInt(seqLen) + attnParamsBuf.putInt(attnStartPos) + attnParamsBuf.flip() + + val region = GBufferRegion + .allocate[F16KVCachePipelineLayout] + .map(layout => pipeline.execute(pParams, layout)) + + region.runUnsafe( + init = F16KVCachePipelineLayout( + tokens = GBuffer[Int32](tokensBuf), + // Scalar weights + tokenEmbed = GBuffer[Float16](tokenEmbedBuf), + attnNorm = GBuffer[Float16](attnNormBuf), + ffnNorm = GBuffer[Float16](ffnNormBuf), + outputNorm = GBuffer[Float16](outputNormBuf), + // Vec4 weights + wq = GBuffer[Vec4[Float16]](wqBuf), + wk = GBuffer[Vec4[Float16]](wkBuf), + wv = GBuffer[Vec4[Float16]](wvBuf), + wo = GBuffer[Vec4[Float16]](woBuf), + ffnGate = GBuffer[Vec4[Float16]](ffnGateBuf), + ffnUp = GBuffer[Vec4[Float16]](ffnUpBuf), + ffnDown = GBuffer[Vec4[Float16]](ffnDownBuf), + outputWeight = GBuffer[Vec4[Float16]](outputWeightBuf), + // KV cache + kCache = GBuffer[Float16](legacyKCacheBuf), + vCache = GBuffer[Float16](legacyVCacheBuf), + // Activations + hidden = GBuffer[Float16](B * T * C), + residual = GBuffer[Float16](B * T * C), + attnNormOut = GBuffer[Float16](B * T * C), + q = GBuffer[Float16](B * T * C), + k = GBuffer[Float16](B * T * kvSize), + v = GBuffer[Float16](B * T * kvSize), + qRoped = GBuffer[Float16](B * T * C), + kRoped = GBuffer[Float16](B * T * kvSize), + attnOut = GBuffer[Float16](B * T * C), + ffnNormOut = GBuffer[Float16](B * T * C), + gate = GBuffer[Float16](B * T * FFN), + up = GBuffer[Float16](B * T * FFN), + ffnHidden = GBuffer[Float16](B * T * FFN), + ffnOut = GBuffer[Float16](B * T * C), + logits = GBuffer[Float32](logitsBuf), + attnParams = GUniform[AttentionParams](attnParamsBuf), + ), + onDone = layout => + // Read back KV cache (for next iteration) - THIS IS THE ROUNDTRIP + layout.kCache.read(legacyKCacheBuf) + layout.vCache.read(legacyVCacheBuf) + layout.logits.read(logitsBuf) + copyFromF32Buffer(logitsBuf, logits), + ) + + logits + end forwardWithKVCache + end 
F16KVCachedPipeline + +end LlamaF16Pipeline diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF32Pipeline.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF32Pipeline.scala new file mode 100644 index 00000000..a7894e6f --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF32Pipeline.scala @@ -0,0 +1,848 @@ +package io.computenode.cyfra.llama.pipeline + +import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GCodec, GExecution, GProgram} +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.GShared +import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.dsl.given_GStructConstructor_T +import io.computenode.cyfra.llama.model.LlamaConfig +import io.computenode.cyfra.llama.pipeline.PipelineUtils.* +import io.computenode.cyfra.llama.programs.AttentionParams +import io.computenode.cyfra.llama.programs.f32.* +import io.computenode.cyfra.llama.util.Logger + +import java.nio.{ByteBuffer, ByteOrder} + +/** F32/Quantized Llama GPU Pipeline. + * + * Supports quantized weights (Q4_K, Q6_K) with on-GPU dequantization. + * Uses F32 activations for compute precision. + * + * Pattern: + * - Single GExecution covering ALL operations + * - All layer parameters concatenated in single buffers + * - Layer offsets computed at compile time per-program + * - ByteBuffer-based I/O for efficiency + */ +object LlamaF32Pipeline: + + // ============= Params ============= + + /** Unified params struct for Llama programs. */ + case class LlamaParams( + B: Int32, // batch size + T: Int32, // sequence length + C: Int32, // hidden size + NH: Int32, // number of query heads + NKV: Int32, // number of key-value heads + headSize: Int32, // head size + FFN: Int32, // intermediate (FFN) size + V: Int32, // vocab size + L: Int32, // total layers + eps: Float32, // RMSNorm epsilon + ropeTheta: Float32, // RoPE theta + startPos: Int32, // RoPE starting position + ) extends GStruct[LlamaParams] + + object LlamaParams: + def create(config: LlamaConfig, T: Int, B: Int = 1, startPos: Int = 0): LlamaParams = + LlamaParams( + B = B, + T = T, + C = config.hiddenSize, + NH = config.numAttentionHeads, + NKV = config.numKeyValueHeads, + headSize = config.headSize, + FFN = config.intermediateSize, + V = config.vocabSize, + L = config.numHiddenLayers, + eps = config.rmsNormEps.toFloat, + ropeTheta = config.ropeTheta.toFloat, + startPos = startPos, + ) + + case class PipelineParams( + config: LlamaConfig, + B: Int, + T: Int, + startPos: Int = 0, + ) + + // ============= Weight Types ============= + + /** Quantization type indicator for weights. */ + sealed trait QuantWeightType + case object Q4K extends QuantWeightType + case object Q6K extends QuantWeightType + + /** Mixed quantized layer weights - supports both Q4_K and Q6_K. + * TinyLlama uses Q4_K for most weights but Q6_K for V tensors. 
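+ * Quantized tensors are stored as packed 32-bit words (Array[Int]) and dequantized on the GPU
+ * inside the matmul programs; norm weights stay F32.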
+ */ + case class MixedQuantLayerWeights( + attnNorm: Array[Float], // F32 - norms stay F32 + wq: Array[Int], // Quantized (Q4_K or Q6_K) + wqType: QuantWeightType, + wk: Array[Int], // Quantized + wkType: QuantWeightType, + wv: Array[Int], // Quantized (often Q6_K in TinyLlama) + wvType: QuantWeightType, + wo: Array[Int], // Quantized + woType: QuantWeightType, + ffnNorm: Array[Float], // F32 - norms stay F32 + ffnGate: Array[Int], // Quantized + ffnGateType: QuantWeightType, + ffnUp: Array[Int], // Quantized + ffnUpType: QuantWeightType, + ffnDown: Array[Int], // Quantized + ffnDownType: QuantWeightType, + ) + + case class MixedQuantModelWeights( + tokenEmbed: Array[Float], // F32 - embeddings stay F32 + layers: Seq[MixedQuantLayerWeights], + outputNorm: Array[Float], // F32 + output: Array[Float], // F32 - output head stays F32 + ) + + + /** Calculate Q4_K size in uint32 for a weight matrix. */ + def q4kSizeUint32(outFeatures: Int, inFeatures: Int): Int = + val numQBlocks = inFeatures / Q4KMatmulVecProgram.QK_K + outFeatures * numQBlocks * Q4KMatmulVecProgram.UINT32_PER_BLOCK + + /** Calculate Q6_K size in bytes for a weight matrix. */ + def q6kSizeBytes(outFeatures: Int, inFeatures: Int): Int = + val numQBlocks = inFeatures / Q6KMatmulVecProgram.QK_K + outFeatures * numQBlocks * Q6KMatmulVecProgram.BLOCK_BYTES + + // ============= KV-Cached Pipeline Layout ============= + + /** Pipeline layout with KV cache for incremental generation. + * + * The KV cache stores K,V vectors for all previous positions, enabling O(1) decode. + * Layout structure follows llama.cpp: + * - kCache/vCache: (L, maxSeqLen, NKV, headSize) - persistent across forward calls + * - Other buffers: sized for current T (can be 1 for decode) + */ + case class KVCachePipelineLayout( + // Input + tokens: GBuffer[Int32], + + // Parameters - F32 (embeddings, norms) + tokenEmbed: GBuffer[Float32], + attnNorm: GBuffer[Float32], + ffnNorm: GBuffer[Float32], + outputNorm: GBuffer[Float32], + outputWeight: GBuffer[Float32], + + // Parameters - Quantized (Q4_K/Q6_K) + wq: GBuffer[UInt32], + wk: GBuffer[UInt32], + wo: GBuffer[UInt32], + ffnGate: GBuffer[UInt32], + ffnUp: GBuffer[UInt32], + ffnDown: GBuffer[UInt32], + wv: GBuffer[UInt32], // May be Q6_K + + // KV Cache - persistent across calls (L * maxSeqLen * kvSize) + kCache: GBuffer[Float32], // (L, maxSeqLen, NKV, headSize) + vCache: GBuffer[Float32], // (L, maxSeqLen, NKV, headSize) + + // Activations - sized for current T + hidden: GBuffer[Float32], + residual: GBuffer[Float32], + attnNormOut: GBuffer[Float32], + q: GBuffer[Float32], + k: GBuffer[Float32], + v: GBuffer[Float32], + qRoped: GBuffer[Float32], + kRoped: GBuffer[Float32], + attnOut: GBuffer[Float32], + ffnNormOut: GBuffer[Float32], + gate: GBuffer[Float32], + up: GBuffer[Float32], + ffnHidden: GBuffer[Float32], + ffnOut: GBuffer[Float32], + logits: GBuffer[Float32], + + // Params uniforms + params: GUniform[LlamaParams], + attnParams: GUniform[AttentionParams], // Runtime seqLen for attention + ) derives Layout + + /** Combined layout for efficient generation - contains BOTH prefill and decode buffers. 
+ * + * This allows a single runUnsafe where: + * - Weights are uploaded ONCE + * - KV cache stays on GPU + * - Prefill uses prefill-sized activations + * - Decode uses decode-sized activations + */ + case class GenerationLayout( + // === Shared: Weights (uploaded once, read-only) === + tokenEmbed: GBuffer[Float32], + attnNorm: GBuffer[Float32], + ffnNorm: GBuffer[Float32], + outputNorm: GBuffer[Float32], + outputWeight: GBuffer[Float32], + wq: GBuffer[UInt32], + wk: GBuffer[UInt32], + wo: GBuffer[UInt32], + ffnGate: GBuffer[UInt32], + ffnUp: GBuffer[UInt32], + ffnDown: GBuffer[UInt32], + wv: GBuffer[UInt32], + + // === Shared: KV Cache (persists across prefill and decode) === + kCache: GBuffer[Float32], + vCache: GBuffer[Float32], + + // === Prefill I/O and activations (T = promptLen) === + prefillTokens: GBuffer[Int32], + prefillHidden: GBuffer[Float32], + prefillResidual: GBuffer[Float32], + prefillAttnNormOut: GBuffer[Float32], + prefillQ: GBuffer[Float32], + prefillK: GBuffer[Float32], + prefillV: GBuffer[Float32], + prefillQRoped: GBuffer[Float32], + prefillKRoped: GBuffer[Float32], + prefillAttnOut: GBuffer[Float32], + prefillFfnNormOut: GBuffer[Float32], + prefillGate: GBuffer[Float32], + prefillUp: GBuffer[Float32], + prefillFfnHidden: GBuffer[Float32], + prefillFfnOut: GBuffer[Float32], + prefillLogits: GBuffer[Float32], + prefillParams: GUniform[LlamaParams], + prefillAttnParams: GUniform[AttentionParams], + + // === Decode I/O and activations (T = 1) === + decodeToken: GBuffer[Int32], + decodeHidden: GBuffer[Float32], + decodeResidual: GBuffer[Float32], + decodeAttnNormOut: GBuffer[Float32], + decodeQ: GBuffer[Float32], + decodeK: GBuffer[Float32], + decodeV: GBuffer[Float32], + decodeQRoped: GBuffer[Float32], + decodeKRoped: GBuffer[Float32], + decodeAttnOut: GBuffer[Float32], + decodeFfnNormOut: GBuffer[Float32], + decodeGate: GBuffer[Float32], + decodeUp: GBuffer[Float32], + decodeFfnHidden: GBuffer[Float32], + decodeFfnOut: GBuffer[Float32], + decodeLogits: GBuffer[Float32], + decodeParams: GUniform[LlamaParams], + decodeAttnParams: GUniform[AttentionParams], + ) derives Layout: + /** Map to KVCachePipelineLayout for prefill */ + def toPrefillLayout: KVCachePipelineLayout = KVCachePipelineLayout( + tokens = prefillTokens, + tokenEmbed = tokenEmbed, attnNorm = attnNorm, ffnNorm = ffnNorm, + outputNorm = outputNorm, outputWeight = outputWeight, + wq = wq, wk = wk, wo = wo, ffnGate = ffnGate, ffnUp = ffnUp, ffnDown = ffnDown, wv = wv, + kCache = kCache, vCache = vCache, + hidden = prefillHidden, residual = prefillResidual, attnNormOut = prefillAttnNormOut, + q = prefillQ, k = prefillK, v = prefillV, qRoped = prefillQRoped, kRoped = prefillKRoped, + attnOut = prefillAttnOut, ffnNormOut = prefillFfnNormOut, + gate = prefillGate, up = prefillUp, ffnHidden = prefillFfnHidden, ffnOut = prefillFfnOut, + logits = prefillLogits, params = prefillParams, attnParams = prefillAttnParams, + ) + + /** Map to KVCachePipelineLayout for decode */ + def toDecodeLayout: KVCachePipelineLayout = KVCachePipelineLayout( + tokens = decodeToken, + tokenEmbed = tokenEmbed, attnNorm = attnNorm, ffnNorm = ffnNorm, + outputNorm = outputNorm, outputWeight = outputWeight, + wq = wq, wk = wk, wo = wo, ffnGate = ffnGate, ffnUp = ffnUp, ffnDown = ffnDown, wv = wv, + kCache = kCache, vCache = vCache, + hidden = decodeHidden, residual = decodeResidual, attnNormOut = decodeAttnNormOut, + q = decodeQ, k = decodeK, v = decodeV, qRoped = decodeQRoped, kRoped = decodeKRoped, + attnOut = decodeAttnOut, ffnNormOut = 
decodeFfnNormOut, + gate = decodeGate, up = decodeUp, ffnHidden = decodeFfnHidden, ffnOut = decodeFfnOut, + logits = decodeLogits, params = decodeParams, attnParams = decodeAttnParams, + ) + + // ============= KV-Cached Pipeline Class ============= + + /** KV-Cached Pipeline for fast incremental inference. + * + * Like llama.cpp, maintains persistent KV cache buffers: + * - Prefill: Process all prompt tokens at once, fill KV cache from 0 to T-1 + * - Decode: Process 1 token at a time, append to KV cache at seqLen, attend to full cache + * + * This achieves O(1) complexity per generated token (vs O(T) without cache). + */ + class F32KVCachedPipeline( + weights: MixedQuantModelWeights, + val config: LlamaConfig, + maxSeqLen: Int = KVCachedAttention.MAX_SEQ_LEN, + B: Int = 1, + )(using runtime: CyfraRuntime) extends LlamaPipeline: + require(maxSeqLen <= KVCachedAttention.MAX_SEQ_LEN, + s"maxSeqLen=$maxSeqLen exceeds KVCachedAttention.MAX_SEQ_LEN=${KVCachedAttention.MAX_SEQ_LEN}") + + private val C = config.hiddenSize + private val NH = config.numAttentionHeads + private val NKV = config.numKeyValueHeads + private val headSize = config.headSize + private val FFN = config.intermediateSize + private val V = config.vocabSize + private val L = config.numHiddenLayers + private val kvSize = NKV * headSize + + // Q4_K sizes per tensor (in uint32) + private val wqUint32PerLayer = q4kSizeUint32(C, C) + private val wkUint32PerLayer = q4kSizeUint32(kvSize, C) + private val woUint32PerLayer = q4kSizeUint32(C, C) + private val ffnGateUint32PerLayer = q4kSizeUint32(FFN, C) + private val ffnUpUint32PerLayer = q4kSizeUint32(FFN, C) + + // Per-layer weight type info from loaded weights + private val vTypes: Seq[QuantWeightType] = weights.layers.map(_.wvType) + private val downTypes: Seq[QuantWeightType] = weights.layers.map(_.ffnDownType) + + // Per-layer sizes for V and down (varies by quant type) + private val wvQ4kUint32 = q4kSizeUint32(kvSize, C) + private val wvQ6kBytes = q6kSizeBytes(kvSize, C) + private val wvQ6kUint32 = (wvQ6kBytes + 3) / 4 + + private val downQ4kUint32 = q4kSizeUint32(C, FFN) + private val downQ6kBytes = q6kSizeBytes(C, FFN) + private val downQ6kUint32 = (downQ6kBytes + 3) / 4 + + // Calculate cumulative offsets for mixed-type buffers + private val vOffsets: Seq[Int] = weights.layers.scanLeft(0) { (offset, layer) => + offset + (if layer.wvType == Q6K then wvQ6kUint32 else wvQ4kUint32) + }.dropRight(1) + private val totalVUint32 = vOffsets.lastOption.getOrElse(0) + + (if vTypes.lastOption.contains(Q6K) then wvQ6kUint32 else wvQ4kUint32) + + private val downOffsets: Seq[Int] = weights.layers.scanLeft(0) { (offset, layer) => + offset + (if layer.ffnDownType == Q6K then downQ6kUint32 else downQ4kUint32) + }.dropRight(1) + private val totalDownUint32 = downOffsets.lastOption.getOrElse(0) + + (if downTypes.lastOption.contains(Q6K) then downQ6kUint32 else downQ4kUint32) + + Logger.info(s"Uploading F32/quantized weights: ${L} layers, ${V}×${C} vocab, maxSeqLen=$maxSeqLen") + + private val tokenEmbedBuf = allocateF32Buffer(V * C); copyToF32Buffer(weights.tokenEmbed, tokenEmbedBuf) + private val attnNormBuf = allocateF32Buffer(L * C) + private val ffnNormBuf = allocateF32Buffer(L * C) + private val outputNormBuf = allocateF32Buffer(C); copyToF32Buffer(weights.outputNorm, outputNormBuf) + private val outputWeightBuf = allocateF32Buffer(V * C); copyToF32Buffer(weights.output, outputWeightBuf) + + // Fill norm buffers + for (layer, layerIdx) <- weights.layers.zipWithIndex do + 
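+ // ByteBuffer.position is a byte offset, hence the * 4 (4 bytes per Float32 / uint32 element).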
attnNormBuf.position(layerIdx * C * 4).asFloatBuffer().put(layer.attnNorm) + ffnNormBuf.position(layerIdx * C * 4).asFloatBuffer().put(layer.ffnNorm) + attnNormBuf.rewind(); ffnNormBuf.rewind() + + // Pre-allocate quantized weight buffers + private val wqBuf = allocateIntBuffer(L * wqUint32PerLayer) + private val wkBuf = allocateIntBuffer(L * wkUint32PerLayer) + private val woBuf = allocateIntBuffer(L * woUint32PerLayer) + private val ffnGateBuf = allocateIntBuffer(L * ffnGateUint32PerLayer) + private val ffnUpBuf = allocateIntBuffer(L * ffnUpUint32PerLayer) + private val wvBuf = allocateIntBuffer(totalVUint32) + private val ffnDownBuf = allocateIntBuffer(totalDownUint32) + + // Fill weight buffers + for (layer, layerIdx) <- weights.layers.zipWithIndex do + wqBuf.position(layerIdx * wqUint32PerLayer * 4).asIntBuffer().put(layer.wq) + wkBuf.position(layerIdx * wkUint32PerLayer * 4).asIntBuffer().put(layer.wk) + woBuf.position(layerIdx * woUint32PerLayer * 4).asIntBuffer().put(layer.wo) + ffnGateBuf.position(layerIdx * ffnGateUint32PerLayer * 4).asIntBuffer().put(layer.ffnGate) + ffnUpBuf.position(layerIdx * ffnUpUint32PerLayer * 4).asIntBuffer().put(layer.ffnUp) + wvBuf.position(vOffsets(layerIdx) * 4).asIntBuffer().put(layer.wv) + ffnDownBuf.position(downOffsets(layerIdx) * 4).asIntBuffer().put(layer.ffnDown) + + wqBuf.rewind(); wkBuf.rewind(); woBuf.rewind() + ffnGateBuf.rewind(); ffnUpBuf.rewind() + wvBuf.rewind(); ffnDownBuf.rewind() + + // Pipeline cache to avoid recompilation + private val pipelineCache = scala.collection.mutable.Map[(Int, Int), GExecution[PipelineParams, KVCachePipelineLayout, KVCachePipelineLayout]]() + + private def getOrBuildPipeline(T: Int, seqLen: Int): GExecution[PipelineParams, KVCachePipelineLayout, KVCachePipelineLayout] = + pipelineCache.getOrElseUpdate((T, seqLen), buildKVCachedPipeline(T, seqLen)) + + // Current sequence length (updated after each forward) + private var currentSeqLen: Int = 0 + + /** Current position in sequence (for RoPE and masking). */ + def seqLen: Int = currentSeqLen + + /** Generate tokens with KV cache - EFFICIENT version using single runUnsafe. + * + * Uses GenerationLayout pattern with separate prefill/decode buffers. + * Weights are uploaded ONCE, KV cache stays on GPU, only small I/O per token. 
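+ *
+ * Usage sketch (greedy sampling; `pipeline` and `promptIds` are illustrative names):
+ * {{{
+ * val generated = pipeline.generate(
+ *   promptTokens = promptIds,
+ *   maxNewTokens = 64,
+ *   sampleFn = logits => logits.indices.maxBy(logits(_)),
+ * )
+ * }}}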
+ */ + def generate( + promptTokens: Array[Int], + maxNewTokens: Int, + sampleFn: Array[Float] => Int, + onToken: Int => Unit = _ => (), + stopTokens: Set[Int] = Set.empty, + reportStats: Boolean = false, + ): Array[Int] = + require(promptTokens.length + maxNewTokens <= maxSeqLen, + s"Total sequence ${promptTokens.length + maxNewTokens} exceeds maxSeqLen=$maxSeqLen") + + currentSeqLen = 0 + val generatedTokens = scala.collection.mutable.ArrayBuffer[Int]() + val prefillT = promptTokens.length + + // Pre-allocate I/O buffers + val prefillTokensBuf = allocateIntBuffer(B * prefillT) + prefillTokensBuf.asIntBuffer().put(promptTokens) + prefillTokensBuf.rewind() + val prefillLogitsBuf = allocateF32Buffer(B * prefillT * V) + + val decodeTokenBuf = allocateIntBuffer(B * 1) + val decodeLogitsBuf = allocateF32Buffer(B * 1 * V) + + // Build pipelines + val prefillPipeline = getOrBuildPipeline(prefillT, prefillT) + val prefillParams = PipelineParams(config, B, prefillT, 0) + + // Build region with GenerationLayout - prefill first, then decode steps + var region = GBufferRegion + .allocate[GenerationLayout] + .map: layout => + prefillPipeline.execute(prefillParams, layout.toPrefillLayout) + layout + .map: layout => + layout.prefillLogits.read(prefillLogitsBuf) + prefillLogitsBuf.rewind() + + val logitsArr = new Array[Float](prefillT * V) + copyFromF32Buffer(prefillLogitsBuf, logitsArr) + val lastPosLogits = logitsArr.slice((prefillT - 1) * V, prefillT * V) + val firstToken = sampleFn(lastPosLogits) + generatedTokens += firstToken + onToken(firstToken) + currentSeqLen = prefillT + + if !stopTokens.contains(firstToken) && firstToken != 2 then + decodeTokenBuf.clear() + decodeTokenBuf.asIntBuffer().put(Array(firstToken)) + decodeTokenBuf.rewind() + layout.decodeToken.write(decodeTokenBuf) + + layout + + var shouldStop = false + val decodePipeline = getOrBuildPipeline(1, maxSeqLen) + val attnParamsBuf = java.nio.ByteBuffer.allocateDirect(8).order(java.nio.ByteOrder.nativeOrder()) + + region = (0 until maxNewTokens - 1).foldLeft(region): (regionAcc, step) => + val seqLen = prefillT + step + 1 + val startPos = seqLen - 1 + val decodeParams = PipelineParams(config, B, 1, startPos) + + regionAcc + .map: layout => + if !shouldStop then + attnParamsBuf.clear() + attnParamsBuf.putInt(seqLen) + attnParamsBuf.putInt(startPos) + attnParamsBuf.flip() + val uniform: io.computenode.cyfra.dsl.binding.GBinding[?] 
= layout.decodeAttnParams + uniform.write(attnParamsBuf, 0) + decodePipeline.execute(decodeParams, layout.toDecodeLayout) + layout + .map: layout => + if shouldStop then + layout + else + layout.decodeLogits.read(decodeLogitsBuf) + decodeLogitsBuf.rewind() + + val logitsArr = new Array[Float](V) + copyFromF32Buffer(decodeLogitsBuf, logitsArr) + val nextToken = sampleFn(logitsArr) + + generatedTokens += nextToken + onToken(nextToken) + currentSeqLen += 1 + + if stopTokens.contains(nextToken) || nextToken == 2 then + shouldStop = true + else + decodeTokenBuf.clear() + decodeTokenBuf.asIntBuffer().put(Array(nextToken)) + decodeTokenBuf.rewind() + layout.decodeToken.write(decodeTokenBuf) + + layout + + val gPrefillParams = LlamaParams.create(config, prefillT, B, 0) + val gDecodeParams = LlamaParams.create(config, 1, B, 0) + + def createAttnParamsBuffer(seqLen: Int, startPos: Int): ByteBuffer = + val buf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + buf.putInt(seqLen) + buf.putInt(startPos) + buf.flip() + buf + + val prefillAttnBuf = createAttnParamsBuffer(prefillT, 0) + val decodeAttnBuf = createAttnParamsBuffer(1, 0) + + region.runUnsafe( + init = GenerationLayout( + tokenEmbed = GBuffer[Float32](tokenEmbedBuf), + attnNorm = GBuffer[Float32](attnNormBuf), + ffnNorm = GBuffer[Float32](ffnNormBuf), + outputNorm = GBuffer[Float32](outputNormBuf), + outputWeight = GBuffer[Float32](outputWeightBuf), + wq = GBuffer[UInt32](wqBuf), + wk = GBuffer[UInt32](wkBuf), + wo = GBuffer[UInt32](woBuf), + ffnGate = GBuffer[UInt32](ffnGateBuf), + ffnUp = GBuffer[UInt32](ffnUpBuf), + ffnDown = GBuffer[UInt32](ffnDownBuf), + wv = GBuffer[UInt32](wvBuf), + kCache = GBuffer[Float32](L * maxSeqLen * kvSize), + vCache = GBuffer[Float32](L * maxSeqLen * kvSize), + prefillTokens = GBuffer[Int32](prefillTokensBuf), + prefillHidden = GBuffer[Float32](B * prefillT * C), + prefillResidual = GBuffer[Float32](B * prefillT * C), + prefillAttnNormOut = GBuffer[Float32](B * prefillT * C), + prefillQ = GBuffer[Float32](B * prefillT * C), + prefillK = GBuffer[Float32](B * prefillT * kvSize), + prefillV = GBuffer[Float32](B * prefillT * kvSize), + prefillQRoped = GBuffer[Float32](B * prefillT * C), + prefillKRoped = GBuffer[Float32](B * prefillT * kvSize), + prefillAttnOut = GBuffer[Float32](B * prefillT * C), + prefillFfnNormOut = GBuffer[Float32](B * prefillT * C), + prefillGate = GBuffer[Float32](B * prefillT * FFN), + prefillUp = GBuffer[Float32](B * prefillT * FFN), + prefillFfnHidden = GBuffer[Float32](B * prefillT * FFN), + prefillFfnOut = GBuffer[Float32](B * prefillT * C), + prefillLogits = GBuffer[Float32](prefillLogitsBuf), + prefillParams = GUniform(gPrefillParams), + prefillAttnParams = GUniform[AttentionParams](prefillAttnBuf), + decodeToken = GBuffer[Int32](decodeTokenBuf), + decodeHidden = GBuffer[Float32](B * 1 * C), + decodeResidual = GBuffer[Float32](B * 1 * C), + decodeAttnNormOut = GBuffer[Float32](B * 1 * C), + decodeQ = GBuffer[Float32](B * 1 * C), + decodeK = GBuffer[Float32](B * 1 * kvSize), + decodeV = GBuffer[Float32](B * 1 * kvSize), + decodeQRoped = GBuffer[Float32](B * 1 * C), + decodeKRoped = GBuffer[Float32](B * 1 * kvSize), + decodeAttnOut = GBuffer[Float32](B * 1 * C), + decodeFfnNormOut = GBuffer[Float32](B * 1 * C), + decodeGate = GBuffer[Float32](B * 1 * FFN), + decodeUp = GBuffer[Float32](B * 1 * FFN), + decodeFfnHidden = GBuffer[Float32](B * 1 * FFN), + decodeFfnOut = GBuffer[Float32](B * 1 * C), + decodeLogits = GBuffer[Float32](decodeLogitsBuf), + decodeParams = 
GUniform(gDecodeParams), + decodeAttnParams = GUniform[AttentionParams](decodeAttnBuf), + ), + onDone = _ => (), + ) + + // Basic stats (F32 pipeline doesn't have fine-grained timing) + if reportStats then + _lastStats = GenerationStats( + promptTokens = promptTokens.length, + generatedTokens = generatedTokens.length, + prefillTimeMs = 0, + decodeTimeMs = 0, + totalTimeMs = 0, + ) + Logger.info(_lastStats.toString) + + generatedTokens.toArray + + /** Legacy API: Prefill (for backward compatibility with tests). */ + def prefill(tokens: Array[Int]): Array[Float] = + require(tokens.length <= maxSeqLen, s"Prompt length ${tokens.length} exceeds maxSeqLen=$maxSeqLen") + currentSeqLen = 0 + val logits = forwardWithKVCache(tokens, startPos = 0) + currentSeqLen = tokens.length + val lastPosLogits = new Array[Float](V) + System.arraycopy(logits, (tokens.length - 1) * V, lastPosLogits, 0, V) + lastPosLogits + + /** Legacy API: Decode single token (for backward compatibility with tests). */ + def decode(token: Int): Array[Float] = + require(currentSeqLen < maxSeqLen, s"Sequence length reached maxSeqLen=$maxSeqLen") + val logits = forwardWithKVCache(Array(token), startPos = currentSeqLen) + currentSeqLen += 1 + logits + + // Build KV-cached pipeline for given T and seqLen + private def buildKVCachedPipeline(T: Int, seqLen: Int): GExecution[PipelineParams, KVCachePipelineLayout, KVCachePipelineLayout] = + val eps = config.rmsNormEps.toFloat + val theta = config.ropeTheta.toFloat + val startPos = seqLen - T + + val embSizes = EmbeddingProgram.Sizes(B * T, C, V) + var pipeline = GExecution[PipelineParams, KVCachePipelineLayout]() + .addProgram(EmbeddingProgram.forward(embSizes))( + _ => embSizes, + l => EmbeddingProgram.ProgramLayout(l.tokens, l.tokenEmbed, l.hidden), + ) + + for layer <- 0 until L do + val normOffset = layer * C + val wqOffset = layer * wqUint32PerLayer + val wkOffset = layer * wkUint32PerLayer + val woOffset = layer * woUint32PerLayer + val ffnGateOffset = layer * ffnGateUint32PerLayer + val ffnUpOffset = layer * ffnUpUint32PerLayer + + val vOffsetUint32 = vOffsets(layer) + val vIsQ6K = vTypes(layer) == Q6K + val downOffsetUint32 = downOffsets(layer) + val downIsQ6K = downTypes(layer) == Q6K + + val kvCacheLayerOffset = layer * maxSeqLen * kvSize + + val copySizes = CopyProgram.Sizes(B * T * C) + val attnNormSizes = RMSNormProgram.Sizes(B * T, C, eps, normOffset, L * C) + val qSizes = Q4KMatmulLayered.Sizes(B * T, C, C, wqOffset, L * wqUint32PerLayer) + val kSizes = Q4KMatmulLayered.Sizes(B * T, C, kvSize, wkOffset, L * wkUint32PerLayer) + val ropeQSizes = RoPEProgram.Sizes(B, T, NH, headSize, theta, startPos) + val ropeKSizes = RoPEProgram.Sizes(B, T, NKV, headSize, theta, startPos) + val woSizes = Q4KMatmulLayered.Sizes(B * T, C, C, woOffset, L * woUint32PerLayer) + val resSizes = ResidualAddProgram.Sizes(B * T * C) + val ffnNormSizes = RMSNormProgram.Sizes(B * T, C, eps, normOffset, L * C) + val gateSizes = Q4KMatmulLayered.Sizes(B * T, C, FFN, ffnGateOffset, L * ffnGateUint32PerLayer) + val upSizes = Q4KMatmulLayered.Sizes(B * T, C, FFN, ffnUpOffset, L * ffnUpUint32PerLayer) + val swiGluSizes = SwiGLUProgram.Sizes(B * T * FFN) + + pipeline = pipeline.addProgram(CopyProgram.forward(copySizes))( + _ => copySizes, + l => CopyProgram.ProgramLayout(l.hidden, l.residual), + ) + + pipeline = pipeline.addProgram(RMSNormProgram.forward(attnNormSizes))( + _ => attnNormSizes, + l => RMSNormProgram.ProgramLayout(l.hidden, l.attnNorm, l.attnNormOut), + ) + + pipeline = pipeline + 
.addProgram(Q4KMatmulLayered.forward(qSizes))( + _ => qSizes, + l => Q4KMatmulLayered.ProgramLayout(l.wq, l.attnNormOut, l.q), + ) + .addProgram(Q4KMatmulLayered.forward(kSizes))( + _ => kSizes, + l => Q4KMatmulLayered.ProgramLayout(l.wk, l.attnNormOut, l.k), + ) + + if vIsQ6K then + val vQ6kOffset = vOffsetUint32 * 4 + val vSizes = Q6KMatmulLayered.Sizes(B * T, C, kvSize, vQ6kOffset, totalVUint32 * 4) + pipeline = pipeline.addProgram(Q6KMatmulLayered.forward(vSizes))( + _ => vSizes, + l => Q6KMatmulLayered.ProgramLayout(l.wv, l.attnNormOut, l.v), + ) + else + val vSizes = Q4KMatmulLayered.Sizes(B * T, C, kvSize, vOffsetUint32, totalVUint32) + pipeline = pipeline.addProgram(Q4KMatmulLayered.forward(vSizes))( + _ => vSizes, + l => Q4KMatmulLayered.ProgramLayout(l.wv, l.attnNormOut, l.v), + ) + + pipeline = pipeline + .addProgram(RoPEProgram.forward(ropeQSizes))( + _ => ropeQSizes, + l => RoPEProgram.ProgramLayout(l.q, l.qRoped, l.attnParams), + ) + .addProgram(RoPEProgram.forward(ropeKSizes))( + _ => ropeKSizes, + l => RoPEProgram.ProgramLayout(l.k, l.kRoped, l.attnParams), + ) + + val kvWriteKSizes = KVCacheWriteK.Sizes(B, T, NKV, headSize, maxSeqLen, layer, startPos, kvCacheLayerOffset, L) + val kvWriteVSizes = KVCacheWriteV.Sizes(B, T, NKV, headSize, maxSeqLen, layer, startPos, kvCacheLayerOffset, L) + + pipeline = pipeline + .addProgram(KVCacheWriteK.forward(kvWriteKSizes))( + _ => kvWriteKSizes, + l => KVCacheWriteK.ProgramLayout(l.kRoped, l.kCache, l.attnParams), + ) + .addProgram(KVCacheWriteV.forward(kvWriteVSizes))( + _ => kvWriteVSizes, + l => KVCacheWriteV.ProgramLayout(l.v, l.vCache, l.attnParams), + ) + + val attnSizes = KVCachedAttention.Sizes(B, T, NH, NKV, headSize, startPos, kvCacheLayerOffset, kvCacheLayerOffset, L, maxSeqLen) + pipeline = pipeline.addProgram(KVCachedAttention.forward(attnSizes))( + _ => attnSizes, + l => KVCachedAttention.ProgramLayout(l.qRoped, l.kCache, l.vCache, l.attnOut, l.attnParams), + ) + + pipeline = pipeline + .addProgram(Q4KMatmulLayered.forward(woSizes))( + _ => woSizes, + l => Q4KMatmulLayered.ProgramLayout(l.wo, l.attnOut, l.hidden), + ) + .addProgram(ResidualAddProgram.forward(resSizes))( + _ => resSizes, + l => ResidualAddProgram.ProgramLayout(l.residual, l.hidden, l.attnNormOut), + ) + + pipeline = pipeline + .addProgram(CopyProgram.forward(copySizes))( + _ => copySizes, + l => CopyProgram.ProgramLayout(l.attnNormOut, l.residual), + ) + .addProgram(RMSNormProgram.forward(ffnNormSizes))( + _ => ffnNormSizes, + l => RMSNormProgram.ProgramLayout(l.attnNormOut, l.ffnNorm, l.ffnNormOut), + ) + .addProgram(Q4KMatmulLayered.forward(gateSizes))( + _ => gateSizes, + l => Q4KMatmulLayered.ProgramLayout(l.ffnGate, l.ffnNormOut, l.gate), + ) + .addProgram(Q4KMatmulLayered.forward(upSizes))( + _ => upSizes, + l => Q4KMatmulLayered.ProgramLayout(l.ffnUp, l.ffnNormOut, l.up), + ) + .addProgram(SwiGLUProgram.forward(swiGluSizes))( + _ => swiGluSizes, + l => SwiGLUProgram.ProgramLayout(l.gate, l.up, l.ffnHidden), + ) + + if downIsQ6K then + val downQ6kOffset = downOffsetUint32 * 4 + val downSizes = Q6KMatmulLayered.Sizes(B * T, FFN, C, downQ6kOffset, totalDownUint32 * 4) + pipeline = pipeline.addProgram(Q6KMatmulLayered.forward(downSizes))( + _ => downSizes, + l => Q6KMatmulLayered.ProgramLayout(l.ffnDown, l.ffnHidden, l.ffnOut), + ) + else + val downSizes = Q4KMatmulLayered.Sizes(B * T, FFN, C, downOffsetUint32, totalDownUint32) + pipeline = pipeline.addProgram(Q4KMatmulLayered.forward(downSizes))( + _ => downSizes, + l => 
Q4KMatmulLayered.ProgramLayout(l.ffnDown, l.ffnHidden, l.ffnOut), + ) + + pipeline = pipeline.addProgram(ResidualAddProgram.forward(resSizes))( + _ => resSizes, + l => ResidualAddProgram.ProgramLayout(l.residual, l.ffnOut, l.hidden), + ) + end for + + val finalNormSizes = RMSNormProgram.Sizes(B * T, C, eps, 0, C) + val logitsSizes = TiledMatmulVecProgram.Sizes(B * T, C, V, 0, V * C) + + pipeline + .addProgram(RMSNormProgram.forward(finalNormSizes))( + _ => finalNormSizes, + l => RMSNormProgram.ProgramLayout(l.hidden, l.outputNorm, l.attnNormOut), + ) + .addProgram(TiledMatmulVecProgram.forward(logitsSizes))( + _ => logitsSizes, + l => TiledMatmulVecProgram.ProgramLayout(l.outputWeight, l.attnNormOut, l.logits), + ) + end buildKVCachedPipeline + + // Legacy KV cache buffers for prefill/decode API + private lazy val legacyKCacheBuf = allocateF32Buffer(L * maxSeqLen * kvSize) + private lazy val legacyVCacheBuf = allocateF32Buffer(L * maxSeqLen * kvSize) + + private def forwardWithKVCache(tokens: Array[Int], startPos: Int): Array[Float] = + val T = tokens.length + val seqLen = startPos + T + + tokenEmbedBuf.rewind() + attnNormBuf.rewind() + ffnNormBuf.rewind() + outputNormBuf.rewind() + outputWeightBuf.rewind() + wqBuf.rewind() + wkBuf.rewind() + woBuf.rewind() + ffnGateBuf.rewind() + ffnUpBuf.rewind() + wvBuf.rewind() + ffnDownBuf.rewind() + legacyKCacheBuf.rewind() + legacyVCacheBuf.rewind() + + val tokensBuf = allocateIntBuffer(B * T) + tokensBuf.asIntBuffer().put(tokens) + tokensBuf.rewind() + + val pipeline = getOrBuildPipeline(T, seqLen) + + val gParams = LlamaParams.create(config, T, B, startPos) + val pParams = PipelineParams(config, B, T, startPos) + + val logits = new Array[Float](B * T * V) + val logitsBuf = allocateF32Buffer(B * T * V) + + val attnStartPos = seqLen - T + val attnParamsBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + attnParamsBuf.putInt(seqLen) + attnParamsBuf.putInt(attnStartPos) + attnParamsBuf.flip() + + val region = GBufferRegion + .allocate[KVCachePipelineLayout] + .map(layout => pipeline.execute(pParams, layout)) + + region.runUnsafe( + init = KVCachePipelineLayout( + tokens = GBuffer[Int32](tokensBuf), + tokenEmbed = GBuffer[Float32](tokenEmbedBuf), + attnNorm = GBuffer[Float32](attnNormBuf), + ffnNorm = GBuffer[Float32](ffnNormBuf), + outputNorm = GBuffer[Float32](outputNormBuf), + outputWeight = GBuffer[Float32](outputWeightBuf), + wq = GBuffer[UInt32](wqBuf), + wk = GBuffer[UInt32](wkBuf), + wo = GBuffer[UInt32](woBuf), + ffnGate = GBuffer[UInt32](ffnGateBuf), + ffnUp = GBuffer[UInt32](ffnUpBuf), + ffnDown = GBuffer[UInt32](ffnDownBuf), + wv = GBuffer[UInt32](wvBuf), + kCache = GBuffer[Float32](legacyKCacheBuf), + vCache = GBuffer[Float32](legacyVCacheBuf), + hidden = GBuffer[Float32](B * T * C), + residual = GBuffer[Float32](B * T * C), + attnNormOut = GBuffer[Float32](B * T * C), + q = GBuffer[Float32](B * T * C), + k = GBuffer[Float32](B * T * kvSize), + v = GBuffer[Float32](B * T * kvSize), + qRoped = GBuffer[Float32](B * T * C), + kRoped = GBuffer[Float32](B * T * kvSize), + attnOut = GBuffer[Float32](B * T * C), + ffnNormOut = GBuffer[Float32](B * T * C), + gate = GBuffer[Float32](B * T * FFN), + up = GBuffer[Float32](B * T * FFN), + ffnHidden = GBuffer[Float32](B * T * FFN), + ffnOut = GBuffer[Float32](B * T * C), + logits = GBuffer[Float32](logitsBuf), + params = GUniform(gParams), + attnParams = GUniform[AttentionParams](attnParamsBuf), + ), + onDone = layout => + layout.kCache.read(legacyKCacheBuf) + 
layout.vCache.read(legacyVCacheBuf) + layout.logits.read(logitsBuf) + copyFromF32Buffer(logitsBuf, logits), + ) + + logits + end forwardWithKVCache + /** Last generation statistics. */ + private var _lastStats: GenerationStats = null + override def lastStats: Option[GenerationStats] = Option(_lastStats) + + end F32KVCachedPipeline + +end LlamaF32Pipeline diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala new file mode 100644 index 00000000..05cc0bfe --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala @@ -0,0 +1,102 @@ +package io.computenode.cyfra.llama.pipeline + +import io.computenode.cyfra.llama.model.LlamaConfig + +import java.nio.{ByteBuffer, ByteOrder} + +/** Common interface for Llama GPU pipelines. + * + * Defines the standard API for KV-cached inference pipelines: + * - generate: Efficient generation with prefill + decode in single GPU allocation + * - prefill: Process prompt tokens (legacy API) + * - decode: Generate one token (legacy API) + * + * Implementations: + * - LlamaF16Pipeline.F16KVCachedPipeline: F16 precision with Vec4 optimizations + * - LlamaF32Pipeline.F32KVCachedPipeline: F32/quantized precision (Q4_K/Q6_K) + */ +trait LlamaPipeline: + + /** Model configuration. */ + def config: LlamaConfig + + /** Current sequence length (position in KV cache). */ + def seqLen: Int + + /** Last generation statistics. */ + def lastStats: Option[GenerationStats] + + /** Generate tokens with KV cache. + * + * Optimized generation that keeps KV cache on GPU: + * - Prefill: Process all prompt tokens at once + * - Decode: Generate tokens one at a time, attending to full cache + * + * @param promptTokens Input prompt tokens + * @param maxNewTokens Maximum tokens to generate + * @param sampleFn Sampling function (logits => token) + * @param onToken Callback for each generated token + * @param stopTokens Set of tokens that stop generation + * @param reportStats If true, logs performance stats after generation + * @return Array of generated tokens (not including prompt) + */ + def generate( + promptTokens: Array[Int], + maxNewTokens: Int, + sampleFn: Array[Float] => Int, + onToken: Int => Unit, + stopTokens: Set[Int], + reportStats: Boolean, + ): Array[Int] + + /** Process prompt tokens and return logits for last position. + * + * @note Legacy API - creates new GPU allocation. Use generate() for efficient inference. + */ + def prefill(tokens: Array[Int]): Array[Float] + + /** Generate next token logits. + * + * @note Legacy API - creates new GPU allocation. Use generate() for efficient inference. + */ + def decode(token: Int): Array[Float] + +/** Performance metrics from generation. */ +case class GenerationStats( + promptTokens: Int, + generatedTokens: Int, + prefillTimeMs: Double, + decodeTimeMs: Double, + totalTimeMs: Double, +): + def prefillTokPerSec: Double = if prefillTimeMs > 0 then promptTokens * 1000.0 / prefillTimeMs else 0 + def decodeTokPerSec: Double = if decodeTimeMs > 0 then generatedTokens * 1000.0 / decodeTimeMs else 0 + def totalTokPerSec: Double = if totalTimeMs > 0 then (promptTokens + generatedTokens) * 1000.0 / totalTimeMs else 0 + + override def toString: String = + f"Gen: ${promptTokens}p+${generatedTokens}g, prefill=${prefillTokPerSec}%.0f tok/s, generate=${decodeTokPerSec}%.1f tok/s" + +/** Buffer utilities for pipeline implementations. 
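+ *
+ * A minimal sketch of typical host-side usage (the array size here is illustrative, not from this patch):
+ * {{{
+ *   val logits = new Array[Float](32000)
+ *   val buf = PipelineUtils.allocateF32Buffer(logits.length)
+ *   PipelineUtils.copyToF32Buffer(logits, buf)   // host array -> direct buffer
+ *   PipelineUtils.copyFromF32Buffer(buf, logits) // direct buffer -> host array
+ * }}}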
*/ +object PipelineUtils: + + def allocateF32Buffer(floatCount: Int): ByteBuffer = + ByteBuffer.allocateDirect(floatCount * 4).order(ByteOrder.nativeOrder()) + + def allocateF16Buffer(f16Count: Int): ByteBuffer = + ByteBuffer.allocateDirect(f16Count * 2).order(ByteOrder.nativeOrder()) + + def allocateIntBuffer(intCount: Int): ByteBuffer = + ByteBuffer.allocateDirect(intCount * 4).order(ByteOrder.nativeOrder()) + + def copyToF32Buffer(arr: Array[Float], buf: ByteBuffer): Unit = + buf.clear(); buf.asFloatBuffer().put(arr); buf.rewind() + + def copyFromF32Buffer(buf: ByteBuffer, arr: Array[Float]): Unit = + buf.rewind(); buf.asFloatBuffer().get(arr) + + def copyIntToBuffer(arr: Array[Int], buf: ByteBuffer): Unit = + buf.clear(); buf.asIntBuffer().put(arr); buf.rewind() + + def copyF16BytesToBuffer(bytes: Array[Byte], buf: ByteBuffer, offset: Int = 0): Unit = + buf.position(offset) + buf.put(bytes) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16CopyProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16CopyProgram.scala new file mode 100644 index 00000000..9d8281e0 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16CopyProgram.scala @@ -0,0 +1,30 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 buffer copy. Used to save state for residual connections. */ +object F16CopyProgram: + case class Sizes(size: Int) + + case class ProgramLayout( + input: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float16](s.size), + output = GBuffer[Float16](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + GIO.when(idx < sizes.size): + val value = GIO.read[Float16](layout.input, idx) + GIO.write[Float16](layout.output, idx, value) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala new file mode 100644 index 00000000..2427f19f --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala @@ -0,0 +1,43 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 token embedding lookup. + * + * Maps token IDs to embedding vectors by index lookup. 
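+ *
+ * Equivalent host-side lookup, as a reference sketch (the helper name and `C` are illustrative):
+ * {{{
+ *   // output(t * C + d) == embeddings(tokens(t) * C + d)
+ *   def embedRef(tokens: Array[Int], embeddings: Array[Float], C: Int): Array[Float] =
+ *     tokens.flatMap(id => embeddings.slice(id * C, (id + 1) * C))
+ * }}}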
+ */ +object F16EmbeddingProgram: + case class Sizes(seqLen: Int, hiddenSize: Int, vocabSize: Int): + def totalOutputs: Int = seqLen * hiddenSize + + case class ProgramLayout( + tokens: GBuffer[Int32], + embeddings: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + tokens = GBuffer[Int32](s.seqLen), + embeddings = GBuffer[Float16](s.vocabSize * s.hiddenSize), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch(((s.totalOutputs + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val hiddenSizeVal: Int32 = sizes.hiddenSize + val totalVal: Int32 = sizes.seqLen * sizes.hiddenSize + + GIO.when(idx < totalVal): + val tokenPos = idx / hiddenSizeVal + val dim = idx.mod(hiddenSizeVal) + val tokenId = GIO.read[Int32](layout.tokens, tokenPos) + val embIdx = tokenId * hiddenSizeVal + dim + val value = GIO.read[Float16](layout.embeddings, embIdx) + GIO.write[Float16](layout.output, idx, value) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedGateUpSwiGLUProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedGateUpSwiGLUProgram.scala new file mode 100644 index 00000000..4a1aff67 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedGateUpSwiGLUProgram.scala @@ -0,0 +1,129 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** Fused F16 Gate + Up + SwiGLU in single dispatch. + * + * Computes the FFN gated activation in one pass: + * gate = input @ Wgate + * up = input @ Wup + * output = silu(gate) * up + * + * Where silu(x) = x * sigmoid(x) + * + * Reduces 3 dispatches → 1 by: + * - Computing gate and up projections together + * - Applying SwiGLU activation immediately after + * + * Each output element requires computing both gate and up for that position. 
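+ *
+ * Scalar reference of the fused activation, as a minimal sketch (helper name is illustrative):
+ * {{{
+ *   // silu(x) = x * sigmoid(x), so silu(gate) * up == gate / (1 + exp(-gate)) * up
+ *   def swiGluRef(gate: Float, up: Float): Float =
+ *     (gate / (1.0f + math.exp(-gate).toFloat)) * up
+ * }}}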
+ */ +object F16FusedGateUpSwiGLUProgram: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP + + case class Sizes( + batchSize: Int, + inFeatures: Int, // C (hidden size) + outFeatures: Int, // FFN (intermediate size) + gateOffsetVec4: Int, + upOffsetVec4: Int, + totalGateVec4: Int, + totalUpVec4: Int, + ): + require(inFeatures % 4 == 0, s"inFeatures ($inFeatures) must be divisible by 4") + def inFeaturesDiv4: Int = inFeatures / 4 + def totalOutputs: Int = batchSize * outFeatures + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numVecIterations: Int = (inFeaturesDiv4 + WARP_SIZE - 1) / WARP_SIZE + + case class ProgramLayout( + wgate: GBuffer[Vec4[Float16]], + wup: GBuffer[Vec4[Float16]], + input: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val outFeatures = sizes.outFeatures + val gateOffsetVec4 = sizes.gateOffsetVec4 + val upOffsetVec4 = sizes.upOffsetVec4 + val numVecIterations = sizes.numVecIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + wgate = GBuffer[Vec4[Float16]](s.totalGateVec4), + wup = GBuffer[Vec4[Float16]](s.totalUpVec4), + input = GBuffer[Float16](s.batchSize * s.inFeatures), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + + val inFeaturesVal: Int32 = inFeatures + val inFeaturesDiv4Val: Int32 = inFeaturesDiv4 + val outFeaturesVal: Int32 = outFeatures + val gateOffsetVec4Val: Int32 = gateOffsetVec4 + val upOffsetVec4Val: Int32 = upOffsetVec4 + val totalOutputsVal: Int32 = sizes.totalOutputs + + val outputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + + GIO.when(outputIdx < totalOutputsVal): + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + // Compute gate dot product + val gateLocalSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val gVec = GIO.read[Vec4[Float16]](layout.wgate, gateOffsetVec4Val + outIdx * inFeaturesDiv4Val + k) + val inputBase = batch * inFeaturesVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + gVec.x.asFloat32 * x0 + gVec.y.asFloat32 * x1 + gVec.z.asFloat32 * x2 + gVec.w.asFloat32 * x3 + .otherwise(sum) + ) + + // Compute up dot product + val upLocalSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val uVec = GIO.read[Vec4[Float16]](layout.wup, upOffsetVec4Val + outIdx * inFeaturesDiv4Val + k) + val inputBase = batch * inFeaturesVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 
+ sum + uVec.x.asFloat32 * x0 + uVec.y.asFloat32 * x1 + uVec.z.asFloat32 * x2 + uVec.w.asFloat32 * x3 + .otherwise(sum) + ) + + // Reduce across warp + val gate = GIO.subgroupAdd(gateLocalSum) + val up = GIO.subgroupAdd(upLocalSum) + + // SwiGLU: silu(gate) * up = gate * sigmoid(gate) * up + val sigmoidGate = 1.0f / (1.0f + exp(-gate)) + val result = gate * sigmoidGate * up + + GIO.write[Float16](layout.output, outputIdx, result.asFloat16) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala new file mode 100644 index 00000000..2a3ff4f5 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala @@ -0,0 +1,105 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Fused F16 KV Cache Write for both K and V in single dispatch. + * + * Reduces dispatch count by writing both K and V vectors to cache simultaneously. + * Each invocation copies one element from either K or V input to the cache. + */ +object F16FusedKVCacheWriteProgram: + + case class Sizes( + B: Int, + T: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + layer: Int, + posOffset: Int, + kCacheLayerOffset: Int, + vCacheLayerOffset: Int, + L: Int, + ): + def totalKElements: Int = B * T * NKV * headSize + def totalVElements: Int = B * T * NKV * headSize + def totalElements: Int = totalKElements + totalVElements + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + k: GBuffer[Float16], + v: GBuffer[Float16], + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NKV = sizes.NKV + val headSize = sizes.headSize + val totalKElements = sizes.totalKElements + val totalElements = sizes.totalElements + val kCacheLayerOffset = sizes.kCacheLayerOffset + val vCacheLayerOffset = sizes.vCacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + k = GBuffer[Float16](s.totalKElements), + v = GBuffer[Float16](s.totalVElements), + kCache = GBuffer[Float16](s.fullCacheSize), + vCache = GBuffer[Float16](s.fullCacheSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch(((s.totalElements + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val posOffsetVal: Int32 = layout.params.read.startPos + + val Tval: Int32 = T + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val totalKElementsVal: Int32 = totalKElements + val totalElementsVal: Int32 = totalElements + val kCacheLayerOffsetVal: Int32 = kCacheLayerOffset + val vCacheLayerOffsetVal: Int32 = vCacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + GIO.when(idx < totalElementsVal): + // Determine if this is K or V + val isK = idx < totalKElementsVal + val localIdx: Int32 = when(isK)(idx).otherwise(idx - totalKElementsVal) + + // Decompose 
index + val elementsPerBatch = Tval * NKVval * headSizeVal + val b = localIdx / elementsPerBatch + val remaining1 = localIdx.mod(elementsPerBatch) + val elementsPerPos = NKVval * headSizeVal + val t = remaining1 / elementsPerPos + val remaining2 = remaining1.mod(elementsPerPos) + val h = remaining2 / headSizeVal + val d = remaining2.mod(headSizeVal) + + val cachePos = posOffsetVal + t + val cacheOffset: Int32 = cachePos * kvSizePerPosVal + h * headSizeVal + d + + for + _ <- GIO.when(isK): + val kVal = GIO.read[Float16](layout.k, localIdx) + val cacheIdx = kCacheLayerOffsetVal + cacheOffset + GIO.write[Float16](layout.kCache, cacheIdx, kVal) + _ <- GIO.when(!isK): + val vVal = GIO.read[Float16](layout.v, localIdx) + val cacheIdx = vCacheLayerOffsetVal + cacheOffset + GIO.write[Float16](layout.vCache, cacheIdx, vVal) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala new file mode 100644 index 00000000..5a9ce4e1 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala @@ -0,0 +1,149 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** Fused F16 Q/K/V projection in single dispatch. + * + * Computes all three attention projections at once: + * - Q = input @ Wq (outFeatures = C) + * - K = input @ Wk (outFeatures = kvSize) + * - V = input @ Wv (outFeatures = kvSize) + * + * Reduces 3 dispatches → 1, sharing the same input read. + * Each warp computes one output element from Q, K, or V. 
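+ *
+ * How a warp's global output index is routed to Q, K, or V, as a host-side sketch (helper name is illustrative):
+ * {{{
+ *   // Q: [0, totalQ), K: [totalQ, totalQ + totalK), V: [totalQ + totalK, total)
+ *   def route(globalIdx: Int, totalQ: Int, totalK: Int): (String, Int) =
+ *     if globalIdx < totalQ then ("Q", globalIdx)
+ *     else if globalIdx < totalQ + totalK then ("K", globalIdx - totalQ)
+ *     else ("V", globalIdx - totalQ - totalK)
+ * }}}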
+ */ +object F16FusedQKVMatmulProgram: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP + + case class Sizes( + batchSize: Int, + inFeatures: Int, // C (hidden size) + qOutFeatures: Int, // C (for Q) + kvOutFeatures: Int, // kvSize = NKV * headSize (for K and V) + wqOffsetVec4: Int, + wkOffsetVec4: Int, + wvOffsetVec4: Int, + totalWqVec4: Int, + totalWkVec4: Int, + totalWvVec4: Int, + ): + require(inFeatures % 4 == 0, s"inFeatures ($inFeatures) must be divisible by 4") + def inFeaturesDiv4: Int = inFeatures / 4 + def totalQOutputs: Int = batchSize * qOutFeatures + def totalKOutputs: Int = batchSize * kvOutFeatures + def totalVOutputs: Int = batchSize * kvOutFeatures + def totalOutputs: Int = totalQOutputs + totalKOutputs + totalVOutputs + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numVecIterations: Int = (inFeaturesDiv4 + WARP_SIZE - 1) / WARP_SIZE + + case class ProgramLayout( + wq: GBuffer[Vec4[Float16]], + wk: GBuffer[Vec4[Float16]], + wv: GBuffer[Vec4[Float16]], + input: GBuffer[Float16], + q: GBuffer[Float16], + k: GBuffer[Float16], + v: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val qOutFeatures = sizes.qOutFeatures + val kvOutFeatures = sizes.kvOutFeatures + val wqOffsetVec4 = sizes.wqOffsetVec4 + val wkOffsetVec4 = sizes.wkOffsetVec4 + val wvOffsetVec4 = sizes.wvOffsetVec4 + val numVecIterations = sizes.numVecIterations + val totalQOutputs = sizes.totalQOutputs + val totalKOutputs = sizes.totalKOutputs + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + wq = GBuffer[Vec4[Float16]](s.totalWqVec4), + wk = GBuffer[Vec4[Float16]](s.totalWkVec4), + wv = GBuffer[Vec4[Float16]](s.totalWvVec4), + input = GBuffer[Float16](s.batchSize * s.inFeatures), + q = GBuffer[Float16](s.totalQOutputs), + k = GBuffer[Float16](s.totalKOutputs), + v = GBuffer[Float16](s.totalVOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + + val inFeaturesVal: Int32 = inFeatures + val inFeaturesDiv4Val: Int32 = inFeaturesDiv4 + val qOutFeaturesVal: Int32 = qOutFeatures + val kvOutFeaturesVal: Int32 = kvOutFeatures + val wqOffsetVec4Val: Int32 = wqOffsetVec4 + val wkOffsetVec4Val: Int32 = wkOffsetVec4 + val wvOffsetVec4Val: Int32 = wvOffsetVec4 + val totalQOutputsVal: Int32 = totalQOutputs + val totalKOutputsVal: Int32 = totalKOutputs + val totalOutputsVal: Int32 = sizes.totalOutputs + + val globalOutputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + + // Determine which output (Q, K, or V) this warp handles + // Q: indices [0, totalQOutputs) + // K: indices [totalQOutputs, totalQOutputs + totalKOutputs) + // V: indices [totalQOutputs + totalKOutputs, total) + + val isQ = globalOutputIdx < totalQOutputsVal + val isK = !isQ && (globalOutputIdx < totalQOutputsVal + totalKOutputsVal) + val isV = !isQ && !isK + + // Helper function to compute matmul for a given buffer and offset + def computeMatmul( + weightBuffer: GBuffer[Vec4[Float16]], + weightOffset: Int32, + outFeatures: Int32, + localIdx: Int32, + ): Float32 = + val batch = localIdx / outFeatures + val outIdx = localIdx.mod(outFeatures) + val localSum = GSeq + 
.gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val wVec = GIO.read[Vec4[Float16]](weightBuffer, weightOffset + outIdx * inFeaturesDiv4Val + k) + val inputBase = batch * inFeaturesVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + wVec.x.asFloat32 * x0 + wVec.y.asFloat32 * x1 + wVec.z.asFloat32 * x2 + wVec.w.asFloat32 * x3 + .otherwise(sum) + ) + GIO.subgroupAdd(localSum) + + // Process Q outputs + for + _ <- GIO.when(isQ && globalOutputIdx < totalOutputsVal): + val localIdx = globalOutputIdx + val result = computeMatmul(layout.wq, wqOffsetVec4Val, qOutFeaturesVal, localIdx) + GIO.write[Float16](layout.q, localIdx, result.asFloat16) + // Process K outputs + _ <- GIO.when(isK && globalOutputIdx < totalOutputsVal): + val localIdx = globalOutputIdx - totalQOutputsVal + val result = computeMatmul(layout.wk, wkOffsetVec4Val, kvOutFeaturesVal, localIdx) + GIO.write[Float16](layout.k, localIdx, result.asFloat16) + // Process V outputs + _ <- GIO.when(isV && globalOutputIdx < totalOutputsVal): + val localIdx = globalOutputIdx - totalQOutputsVal - totalKOutputsVal + val result = computeMatmul(layout.wv, wvOffsetVec4Val, kvOutFeaturesVal, localIdx) + GIO.write[Float16](layout.v, localIdx, result.asFloat16) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala new file mode 100644 index 00000000..2e068281 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala @@ -0,0 +1,124 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Fused F16 Rotary Position Embedding for both Q and K in single dispatch. + * + * Reduces dispatch count by processing both Q and K tensors simultaneously. + * Each invocation handles one pair from either Q or K. 
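+ *
+ * Reference rotation of a single (x0, x1) pair, mirroring the kernel math as a sketch (helper name is illustrative):
+ * {{{
+ *   def ropePair(x0: Float, x1: Float, pos: Int, d: Int, headSize: Int, theta: Float): (Float, Float) =
+ *     val freq = pos * math.pow(theta, -2.0 * d / headSize).toFloat
+ *     val (c, s) = (math.cos(freq).toFloat, math.sin(freq).toFloat)
+ *     (x0 * c - x1 * s, x0 * s + x1 * c)
+ * }}}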
+ */ +object F16FusedRoPEProgram: + val BLOCK_SIZE = 256 + + case class Sizes( + B: Int, + T: Int, + numHeadsQ: Int, + numHeadsK: Int, + headSize: Int, + theta: Float, + ): + def totalQPairs: Int = B * T * numHeadsQ * (headSize / 2) + def totalKPairs: Int = B * T * numHeadsK * (headSize / 2) + def totalPairs: Int = totalQPairs + totalKPairs + + case class ProgramLayout( + qIn: GBuffer[Float16], + kIn: GBuffer[Float16], + qOut: GBuffer[Float16], + kOut: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val numHeadsQ = sizes.numHeadsQ + val numHeadsK = sizes.numHeadsK + val headSize = sizes.headSize + val theta = sizes.theta + val totalQPairs = sizes.totalQPairs + val totalKPairs = sizes.totalKPairs + val totalPairs = sizes.totalPairs + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + qIn = GBuffer[Float16](s.B * s.T * s.numHeadsQ * s.headSize), + kIn = GBuffer[Float16](s.B * s.T * s.numHeadsK * s.headSize), + qOut = GBuffer[Float16](s.B * s.T * s.numHeadsQ * s.headSize), + kOut = GBuffer[Float16](s.B * s.T * s.numHeadsK * s.headSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch(((s.totalPairs + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val totalPairsVal: Int32 = totalPairs + val totalQPairsVal: Int32 = totalQPairs + val Tval: Int32 = T + val numHeadsQVal: Int32 = numHeadsQ + val numHeadsKVal: Int32 = numHeadsK + val halfHead: Int32 = headSize / 2 + val headSizeVal: Int32 = headSize + val thetaVal: Float32 = theta + val startPosVal: Int32 = layout.params.read.startPos + + GIO.when(idx < totalPairsVal): + // Determine if this is Q or K based on index + val isQ = idx < totalQPairsVal + + // Calculate local index within Q or K + val localIdx: Int32 = when(isQ)(idx).otherwise(idx - totalQPairsVal) + val numHeads: Int32 = when(isQ)(numHeadsQVal).otherwise(numHeadsKVal) + + // Decompose index + val perHead = halfHead + val perPos = numHeads * halfHead + val perBatch = Tval * perPos + + val b = localIdx / perBatch + val rem1 = localIdx.mod(perBatch) + val t = rem1 / perPos + val rem2 = rem1.mod(perPos) + val h = rem2 / perHead + val d = rem2.mod(perHead) + + // Compute RoPE rotation + val pos = startPosVal + t + val headSizeFloat: Float32 = headSize.toFloat + val freqExponent: Float32 = -2.0f * d.asFloat / headSizeFloat + val freq: Float32 = pos.asFloat * pow(thetaVal, freqExponent) + val cosFreq = cos(freq).asFloat16 + val sinFreq = sin(freq).asFloat16 + + // Calculate full indices for reading/writing + val fullIdx: Int32 = b * Tval * numHeads * headSizeVal + t * numHeads * headSizeVal + h * headSizeVal + val idx0 = fullIdx + d * 2 + val idx1 = idx0 + 1 + + // Read, rotate, write - branch on Q vs K + for + _ <- GIO.when(isQ): + val x0 = GIO.read[Float16](layout.qIn, idx0) + val x1 = GIO.read[Float16](layout.qIn, idx1) + val y0 = x0 * cosFreq - x1 * sinFreq + val y1 = x0 * sinFreq + x1 * cosFreq + for + _ <- GIO.write[Float16](layout.qOut, idx0, y0) + _ <- GIO.write[Float16](layout.qOut, idx1, y1) + yield GStruct.Empty() + _ <- GIO.when(!isQ): + val x0 = GIO.read[Float16](layout.kIn, idx0) + val x1 = GIO.read[Float16](layout.kIn, idx1) + val y0 = x0 * cosFreq - x1 * sinFreq + val y1 = x0 * sinFreq + x1 * cosFreq + for + _ <- GIO.write[Float16](layout.kOut, idx0, y0) + _ <- GIO.write[Float16](layout.kOut, idx1, y1) + yield GStruct.Empty() + 
yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteK.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteK.scala new file mode 100644 index 00000000..d4803739 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteK.scala @@ -0,0 +1,87 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Writes K vectors to KV cache at specified positions (F16 version). + * + * Each invocation copies one element from the input K tensor to the KV cache + * at the position specified by the runtime `startPos` parameter. + * + * @note The cache is organized as (L × maxSeqLen × NKV × headSize) where L is total layers. + * This program writes to a single layer's slice using `cacheLayerOffset`. + */ +object F16KVCacheWriteK: + + /** Compile-time size parameters for the KV cache write K program. */ + case class Sizes( + B: Int, + T: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + layer: Int, + posOffset: Int, + cacheLayerOffset: Int, + L: Int, + ): + def totalElements: Int = B * T * NKV * headSize + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + k: GBuffer[Float16], + kCache: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + /** Creates a GPU program that writes K vectors to the KV cache. */ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NKV = sizes.NKV + val headSize = sizes.headSize + val totalElements = sizes.totalElements + val cacheLayerOffset = sizes.cacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + k = GBuffer[Float16](s.B * s.T * s.NKV * s.headSize), + kCache = GBuffer[Float16](s.fullCacheSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.totalElements, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val posOffsetVal: Int32 = layout.params.read.startPos + + val Tval: Int32 = T + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val totalElementsVal: Int32 = totalElements + val cacheLayerOffsetVal: Int32 = cacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + GIO.when(idx < totalElementsVal): + val elementsPerBatch = Tval * NKVval * headSizeVal + val b = idx / elementsPerBatch + val remaining1 = idx.mod(elementsPerBatch) + val elementsPerPos = NKVval * headSizeVal + val t = remaining1 / elementsPerPos + val remaining2 = remaining1.mod(elementsPerPos) + val h = remaining2 / headSizeVal + val d = remaining2.mod(headSizeVal) + + val inputIdx = idx + val kVal = GIO.read[Float16](layout.k, inputIdx) + + val cachePos = posOffsetVal + t + val cacheIdx = cacheLayerOffsetVal + cachePos * kvSizePerPosVal + h * headSizeVal + d + GIO.write[Float16](layout.kCache, cacheIdx, kVal) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteV.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteV.scala new file mode 100644 
index 00000000..bf840b8a --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCacheWriteV.scala @@ -0,0 +1,83 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Writes V vectors to KV cache at specified positions (F16 version). + * + * Identical structure to F16KVCacheWriteK but operates on V vectors. + */ +object F16KVCacheWriteV: + + /** Compile-time size parameters for the KV cache write V program. */ + case class Sizes( + B: Int, + T: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + layer: Int, + posOffset: Int, + cacheLayerOffset: Int, + L: Int, + ): + def totalElements: Int = B * T * NKV * headSize + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + v: GBuffer[Float16], + vCache: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + /** Creates a GPU program that writes V vectors to the KV cache. */ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NKV = sizes.NKV + val headSize = sizes.headSize + val totalElements = sizes.totalElements + val cacheLayerOffset = sizes.cacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + v = GBuffer[Float16](s.B * s.T * s.NKV * s.headSize), + vCache = GBuffer[Float16](s.fullCacheSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.totalElements, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val posOffsetVal: Int32 = layout.params.read.startPos + + val Tval: Int32 = T + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val totalElementsVal: Int32 = totalElements + val cacheLayerOffsetVal: Int32 = cacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + GIO.when(idx < totalElementsVal): + val elementsPerBatch = Tval * NKVval * headSizeVal + val b = idx / elementsPerBatch + val remaining1 = idx.mod(elementsPerBatch) + val elementsPerPos = NKVval * headSizeVal + val t = remaining1 / elementsPerPos + val remaining2 = remaining1.mod(elementsPerPos) + val h = remaining2 / headSizeVal + val d = remaining2.mod(headSizeVal) + + val inputIdx = idx + val vVal = GIO.read[Float16](layout.v, inputIdx) + + val cachePos = posOffsetVal + t + val cacheIdx = cacheLayerOffsetVal + cachePos * kvSizePerPosVal + h * headSizeVal + d + GIO.write[Float16](layout.vCache, cacheIdx, vVal) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCachedAttention.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCachedAttention.scala new file mode 100644 index 00000000..5460e516 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16KVCachedAttention.scala @@ -0,0 +1,174 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.GShared +import 
io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.llama.programs.AttentionParams + +/** KV-cached attention for incremental inference (F16 version). + * + * Computes attention by reading Q from current tokens and K/V from the full cache. + * Uses F32 for intermediate score computations for numerical stability, F16 for storage. + * + * One workgroup handles one (batch, query_position, head) tuple. + * Supports grouped-query attention (GQA) where multiple Q heads share K/V heads. + */ +object F16KVCachedAttention: + val WARP_SIZE = 32 + val MAX_SEQ_LEN = 2048 + + /** Compile-time size parameters for the attention program. */ + case class Sizes( + B: Int, + T: Int, + NH: Int, + NKV: Int, + headSize: Int, + startPos: Int, + kCacheLayerOffset: Int, + vCacheLayerOffset: Int, + L: Int, + maxSeqLen: Int, + ): + def gqaRatio: Int = NH / NKV + def numScoreIterations: Int = (maxSeqLen + WARP_SIZE - 1) / WARP_SIZE + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + q: GBuffer[Float16], + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + output: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + /** Creates a GPU program for KV-cached attention computation. */ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val scoresShared = GShared[Float32](MAX_SEQ_LEN) + + val B = sizes.B + val T = sizes.T + val NH = sizes.NH + val NKV = sizes.NKV + val headSize = sizes.headSize + val gqaRatio = sizes.gqaRatio + val scale = 1.0f / math.sqrt(headSize).toFloat + val numScoreIterations = sizes.numScoreIterations + val kCacheLayerOffset = sizes.kCacheLayerOffset + val vCacheLayerOffset = sizes.vCacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + val maxSeqLen = sizes.maxSeqLen + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + q = GBuffer[Float16](s.B * s.T * s.NH * s.headSize), + kCache = GBuffer[Float16](s.fullCacheSize), + vCache = GBuffer[Float16](s.fullCacheSize), + output = GBuffer[Float16](s.B * s.T * s.NH * s.headSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.B * s.T * s.NH, 1, 1)), + workgroupSize = (WARP_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + + val runtimeParams = layout.params.read + val seqLenVal: Int32 = runtimeParams.seqLen + val startPosVal: Int32 = runtimeParams.startPos + + val Tval: Int32 = T + val NHval: Int32 = NH + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val gqaRatioVal: Int32 = gqaRatio + val scaleVal: Float32 = scale + val kCacheLayerOffsetVal: Int32 = kCacheLayerOffset + val vCacheLayerOffsetVal: Int32 = vCacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + val posPerBatch = Tval * NHval + val batchIdx = workgroupId / posPerBatch + val posInBatch = workgroupId.mod(posPerBatch) + val queryPosLocal = posInBatch / NHval + val headIdx = posInBatch.mod(NHval) + val kvHeadIdx = headIdx / gqaRatioVal + + val queryPosGlobal = startPosVal + queryPosLocal + val qBase = batchIdx * Tval * NHval * headSizeVal + queryPosLocal * NHval * headSizeVal + headIdx * headSizeVal + + // Phase 1: Compute attention scores Q·K and track max for numerical stability + val computeScoresAndMax: GIO[Float32] = GIO.foldRepeat[Float32](numScoreIterations, -10000.0f): (iter, localMax) => + val kPos = tid + iter * 
WARP_SIZE + val isValid = kPos <= queryPosGlobal && kPos < seqLenVal + val kCacheBase = kCacheLayerOffsetVal + kPos * kvSizePerPosVal + kvHeadIdx * headSizeVal + + val dot = GSeq.gen[Int32](0, _ + 1).limit(headSize).unroll.fold(0.0f, (acc: Float32, d: Int32) => + val qVal = GIO.read[Float16](layout.q, qBase + d).asFloat32 + val kVal = GIO.read[Float16](layout.kCache, kCacheBase + d).asFloat32 + acc + qVal * kVal + ) + + val score = when(isValid)(dot * scaleVal).otherwise(-10000.0f) + + for _ <- scoresShared.write(kPos, score) yield + when(isValid)(max(localMax, score)).otherwise(localMax) + + for + localMax <- computeScoresAndMax + _ <- GIO.barrier + + // Phase 2: Compute softmax numerator exp(score - max) + globalMax <- GIO.pure(GIO.subgroupMax(localMax)) + localSum <- GIO.foldRepeat[Float32](numScoreIterations, 0.0f): (iter, sum) => + val kPos = tid + iter * WARP_SIZE + val isValid = kPos <= queryPosGlobal && kPos < seqLenVal + val score = scoresShared.read(kPos) + val expScore = exp(score - globalMax) + for _ <- scoresShared.write(kPos, expScore) yield + when(isValid)(sum + expScore).otherwise(sum) + + _ <- GIO.barrier + + // Phase 3: Normalize to get attention weights + globalSum <- GIO.pure(GIO.subgroupAdd(localSum) + 0.0000001f) + _ <- GIO.foldRepeat[Empty](numScoreIterations, Empty()): (iter, _) => + val kPos = tid + iter * WARP_SIZE + val isValid = kPos <= queryPosGlobal && kPos < seqLenVal + val expScore = scoresShared.read(kPos) + for _ <- GIO.when(isValid)(scoresShared.write(kPos, expScore / globalSum)) yield Empty() + + _ <- GIO.barrier + + // Phase 4: Compute weighted sum of V values + outDim1 <- GIO.pure(tid) + _ <- GIO.when(outDim1 < headSizeVal): + val weightedSum1 = GSeq.gen[Int32](0, _ + 1).limit(maxSeqLen).takeWhile(_ < seqLenVal).fold(0.0f, (sum: Float32, kPos: Int32) => + val isValid = kPos <= queryPosGlobal + val weight = scoresShared.read(kPos) + val vCacheBase = vCacheLayerOffsetVal + kPos * kvSizePerPosVal + kvHeadIdx * headSizeVal + val vVal = GIO.read[Float16](layout.vCache, vCacheBase + outDim1).asFloat32 + when(isValid)(sum + weight * vVal).otherwise(sum) + ) + val outBase = batchIdx * Tval * NHval * headSizeVal + queryPosLocal * NHval * headSizeVal + headIdx * headSizeVal + GIO.write[Float16](layout.output, outBase + outDim1, weightedSum1.asFloat16) + + // Handle dimensions beyond WARP_SIZE (for headSize > 32) + outDim2 <- GIO.pure(tid + WARP_SIZE) + _ <- GIO.when(outDim2 < headSizeVal): + val weightedSum2 = GSeq.gen[Int32](0, _ + 1).limit(maxSeqLen).takeWhile(_ < seqLenVal).fold(0.0f, (sum: Float32, kPos: Int32) => + val isValid = kPos <= queryPosGlobal + val weight = scoresShared.read(kPos) + val vCacheBase = vCacheLayerOffsetVal + kPos * kvSizePerPosVal + kvHeadIdx * headSizeVal + val vVal = GIO.read[Float16](layout.vCache, vCacheBase + outDim2).asFloat32 + when(isValid)(sum + weight * vVal).otherwise(sum) + ) + val outBase = batchIdx * Tval * NHval * headSizeVal + queryPosLocal * NHval * headSizeVal + headIdx * headSizeVal + GIO.write[Float16](layout.output, outBase + outDim2, weightedSum2.asFloat16) + yield Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulResidualAddProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulResidualAddProgram.scala new file mode 100644 index 00000000..6772ed1b --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulResidualAddProgram.scala @@ -0,0 +1,157 @@ +package 
io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** Fused F16 matrix-vector multiply + residual add. + * + * Computes: output[i] = residual[i] + dot(weight[i], input) + * + * Eliminates the need for separate Copy + ResidualAdd programs by: + * 1. Reading residual during computation (no need to copy first) + * 2. Adding residual directly to matmul result + * + * This reduces dispatch count by 2 per usage (Copy + ResidualAdd → nothing). + */ +object F16MatmulResidualAddProgram: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP + + case class Sizes( + batchSize: Int, + inFeatures: Int, + outFeatures: Int, + weightOffsetVec4: Int = 0, + totalWeightVec4: Int = -1, + ): + require(inFeatures % 4 == 0, s"inFeatures ($inFeatures) must be divisible by 4") + def inFeaturesDiv4: Int = inFeatures / 4 + def totalOutputs: Int = batchSize * outFeatures + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numVecIterations: Int = (inFeaturesDiv4 + WARP_SIZE - 1) / WARP_SIZE + def actualWeightVec4: Int = if totalWeightVec4 < 0 then outFeatures * inFeaturesDiv4 else totalWeightVec4 + + /** Layout with Vec4 input for optimal memory bandwidth. */ + case class ProgramLayoutVec4( + weight: GBuffer[Vec4[Float16]], + input: GBuffer[Vec4[Float16]], + residual: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + /** Legacy layout with scalar input. */ + case class ProgramLayout( + weight: GBuffer[Vec4[Float16]], + input: GBuffer[Float16], + residual: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + /** Optimized forward with Vec4 input reads. 
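+   *
+   * Per-output-row reference of the fused operation, as a sketch (helper name and arrays are illustrative):
+   * {{{
+   *   // output(i) = residual(i) + dot(weight row i, input)
+   *   def fusedRowRef(weightRow: Array[Float], input: Array[Float], residual: Float): Float =
+   *     residual + weightRow.indices.map(j => weightRow(j) * input(j)).sum
+   * }}}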
*/ + def forwardVec4(sizes: Sizes): GProgram[Sizes, ProgramLayoutVec4] = + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val outFeatures = sizes.outFeatures + val weightOffsetVec4 = sizes.weightOffsetVec4 + val numVecIterations = sizes.numVecIterations + + GProgram[Sizes, ProgramLayoutVec4]( + layout = s => ProgramLayoutVec4( + weight = GBuffer[Vec4[Float16]](s.actualWeightVec4), + input = GBuffer[Vec4[Float16]](s.batchSize * s.inFeaturesDiv4), + residual = GBuffer[Float16](s.totalOutputs), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + val inFeaturesDiv4Val: Int32 = inFeaturesDiv4 + val outFeaturesVal: Int32 = outFeatures + val weightOffsetVec4Val: Int32 = weightOffsetVec4 + val totalOutputsVal: Int32 = sizes.totalOutputs + + val outputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + val localSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val wVec = GIO.read[Vec4[Float16]](layout.weight, weightOffsetVec4Val + outIdx * inFeaturesDiv4Val + k) + val xVec = GIO.read[Vec4[Float16]](layout.input, batch * inFeaturesDiv4Val + k) + val wF32 = vec4(wVec.x.asFloat32, wVec.y.asFloat32, wVec.z.asFloat32, wVec.w.asFloat32) + val xF32 = vec4(xVec.x.asFloat32, xVec.y.asFloat32, xVec.z.asFloat32, xVec.w.asFloat32) + sum + wF32.dot(xF32) + .otherwise(sum) + ) + + val totalSum = GIO.subgroupAdd(localSum) + GIO.when(outputIdx < totalOutputsVal): + val residualVal = GIO.read[Float16](layout.residual, outputIdx).asFloat32 + val result = totalSum + residualVal + GIO.write[Float16](layout.output, outputIdx, result.asFloat16) + + /** Legacy forward with scalar input reads. 
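+   *
+   * The warp-strided accumulation both forwards use, as a sketch (helper name is illustrative; WARP_SIZE = 32):
+   * {{{
+   *   // lane l accumulates vec4 chunks l, l + 32, l + 64, ...; GIO.subgroupAdd then reduces the partial sums
+   *   def laneChunks(lane: Int, inFeaturesDiv4: Int): Seq[Int] =
+   *     Iterator.iterate(lane)(_ + 32).takeWhile(_ < inFeaturesDiv4).toSeq
+   * }}}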
*/ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val outFeatures = sizes.outFeatures + val weightOffsetVec4 = sizes.weightOffsetVec4 + val numVecIterations = sizes.numVecIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[Vec4[Float16]](s.actualWeightVec4), + input = GBuffer[Float16](s.batchSize * s.inFeatures), + residual = GBuffer[Float16](s.totalOutputs), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val inFeaturesDiv4Val: Int32 = inFeaturesDiv4 + val outFeaturesVal: Int32 = outFeatures + val weightOffsetVec4Val: Int32 = weightOffsetVec4 + val totalOutputsVal: Int32 = sizes.totalOutputs + + val outputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + val localSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val wVec = GIO.read[Vec4[Float16]](layout.weight, weightOffsetVec4Val + outIdx * inFeaturesDiv4Val + k) + val inputBase = batch * inFeaturesVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + wVec.x.asFloat32 * x0 + wVec.y.asFloat32 * x1 + wVec.z.asFloat32 * x2 + wVec.w.asFloat32 * x3 + .otherwise(sum) + ) + + val totalSum = GIO.subgroupAdd(localSum) + GIO.when(outputIdx < totalOutputsVal): + // Read residual and add to matmul result + val residualVal = GIO.read[Float16](layout.residual, outputIdx).asFloat32 + val result = totalSum + residualVal + GIO.write[Float16](layout.output, outputIdx, result.asFloat16) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala new file mode 100644 index 00000000..adfb395b --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala @@ -0,0 +1,89 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 matrix-vector multiply with Vec4-packed weights. + * + * Uses Vec4[Float16] weights for 4x memory bandwidth while keeping scalar input. + * Optimal for activation-weight multiplies where weights are static but activations vary. + * + * @note Requires `inFeatures` divisible by 4 for Vec4 alignment. 
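+ *
+ * Sizing example for a single-token decode step, as a sketch (dimensions are illustrative):
+ * {{{
+ *   val sizes = F16MatmulVecHybridProgram.Sizes(batchSize = 1, inFeatures = 2048, outFeatures = 2048)
+ *   // sizes.inFeaturesDiv4 == 512, sizes.numWorkgroups == 256 (8 outputs, one per warp, per workgroup)
+ * }}}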
+ */ +object F16MatmulVecHybridProgram: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP + + case class Sizes( + batchSize: Int, + inFeatures: Int, + outFeatures: Int, + weightOffsetVec4: Int = 0, + totalWeightVec4: Int = -1, + ): + require(inFeatures % 4 == 0, s"inFeatures ($inFeatures) must be divisible by 4") + def inFeaturesDiv4: Int = inFeatures / 4 + def totalOutputs: Int = batchSize * outFeatures + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numVecIterations: Int = (inFeaturesDiv4 + WARP_SIZE - 1) / WARP_SIZE + def actualWeightVec4: Int = if totalWeightVec4 < 0 then outFeatures * inFeaturesDiv4 else totalWeightVec4 + + case class ProgramLayout( + weight: GBuffer[Vec4[Float16]], + input: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val outFeatures = sizes.outFeatures + val weightOffsetVec4 = sizes.weightOffsetVec4 + val numVecIterations = sizes.numVecIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[Vec4[Float16]](s.actualWeightVec4), + input = GBuffer[Float16](s.batchSize * s.inFeatures), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val inFeaturesDiv4Val: Int32 = inFeaturesDiv4 + val outFeaturesVal: Int32 = outFeatures + val weightOffsetVec4Val: Int32 = weightOffsetVec4 + val totalOutputsVal: Int32 = sizes.totalOutputs + + val outputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + val localSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val wVec = GIO.read[Vec4[Float16]](layout.weight, weightOffsetVec4Val + outIdx * inFeaturesDiv4Val + k) + val inputBase = batch * inFeaturesVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + wVec.x.asFloat32 * x0 + wVec.y.asFloat32 * x1 + wVec.z.asFloat32 * x2 + wVec.w.asFloat32 * x3 + .otherwise(sum) + ) + + val totalSum = GIO.subgroupAdd(localSum) + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float16](layout.output, outputIdx, totalSum.asFloat16) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala new file mode 100644 index 00000000..882b8f59 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala @@ -0,0 +1,80 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import 
io.computenode.cyfra.dsl.gio.GIO + +/** F16 output projection with Vec4-packed weights. + * + * Projects hidden states to vocabulary logits using Vec4[Float16] weights. + * Output is F32 for softmax numerical stability. + * + * @note Requires `hiddenSize` divisible by 4 for Vec4 alignment. + */ +object F16OutputVec4Program: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP + + case class Sizes(batchSize: Int, hiddenSize: Int, vocabSize: Int): + require(hiddenSize % 4 == 0, s"hiddenSize ($hiddenSize) must be divisible by 4") + def hiddenSizeDiv4: Int = hiddenSize / 4 + def totalOutputs: Int = batchSize * vocabSize + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numVecIterations: Int = (hiddenSizeDiv4 + WARP_SIZE - 1) / WARP_SIZE + + case class ProgramLayout( + input: GBuffer[Float16], + weight: GBuffer[Vec4[Float16]], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val hiddenSize = sizes.hiddenSize + val hiddenSizeDiv4 = sizes.hiddenSizeDiv4 + val vocabSize = sizes.vocabSize + val numVecIterations = sizes.numVecIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float16](s.batchSize * s.hiddenSize), + weight = GBuffer[Vec4[Float16]](s.vocabSize * s.hiddenSizeDiv4), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + val hiddenSizeVal: Int32 = hiddenSize + val hiddenSizeDiv4Val: Int32 = hiddenSizeDiv4 + val vocabSizeVal: Int32 = vocabSize + val totalOutputsVal: Int32 = sizes.totalOutputs + + val outputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + val batch = outputIdx / vocabSizeVal + val vocabIdx = outputIdx.mod(vocabSizeVal) + + val localSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < hiddenSizeDiv4Val): + val wVec = GIO.read[Vec4[Float16]](layout.weight, vocabIdx * hiddenSizeDiv4Val + k) + val inputBase = batch * hiddenSizeVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + wVec.x.asFloat32 * x0 + wVec.y.asFloat32 * x1 + wVec.z.asFloat32 * x2 + wVec.w.asFloat32 * x3 + .otherwise(sum) + ) + + val totalSum = GIO.subgroupAdd(localSum) + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float32](layout.output, outputIdx, totalSum) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala new file mode 100644 index 00000000..a277cae4 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala @@ -0,0 +1,83 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 Root Mean 
Square Layer Normalization. + * + * Normalizes input by RMS: `output[i] = input[i] / rms(input) * weight[i]`. + * Accumulates in F32 for numerical precision. + */ +object F16RMSNormProgram: + val WARP_SIZE = 32 + + case class Sizes( + numRows: Int, + rowSize: Int, + eps: Float, + weightOffset: Int = 0, + totalWeightSize: Int = -1, + ): + def numIterations: Int = (rowSize + WARP_SIZE - 1) / WARP_SIZE + def actualWeightSize: Int = if totalWeightSize < 0 then rowSize else totalWeightSize + + case class ProgramLayout( + input: GBuffer[Float16], + weight: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val rowSize = sizes.rowSize + val eps = sizes.eps + val numIterations = sizes.numIterations + val weightOffset = sizes.weightOffset + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float16](s.numRows * s.rowSize), + weight = GBuffer[Float16](s.actualWeightSize), + output = GBuffer[Float16](s.numRows * s.rowSize), + ), + dispatch = (_, s) => + val warpsPerWorkgroup = 8 + val numWorkgroups = (s.numRows + warpsPerWorkgroup - 1) / warpsPerWorkgroup + StaticDispatch((numWorkgroups, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val globalId = GIO.invocationId + val laneId = globalId.mod(WARP_SIZE) + val subgroupIdx = globalId / WARP_SIZE + val rowSizeVal: Int32 = rowSize + val epsVal: Float32 = eps + val weightOffsetVal: Int32 = weightOffset + val numRowsVal: Int32 = sizes.numRows + + GIO.when(subgroupIdx < numRowsVal): + val rowIdx = subgroupIdx + val baseIdx = rowIdx * rowSizeVal + + val localSumSq = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numIterations) + .fold(0.0f, (sum: Float32, i: Int32) => + when(i < rowSizeVal): + val x = GIO.read[Float16](layout.input, baseIdx + i).asFloat32 + sum + (x * x) + .otherwise(sum) + ) + + val totalSumSq = GIO.subgroupAdd(localSumSq) + val rowSizeF32 = rowSizeVal.asFloat + val meanSq: Float32 = totalSumSq / rowSizeF32 + val scale: Float32 = 1.0f / sqrt(meanSq + epsVal) + + GIO.repeat(numIterations): j => + val i = laneId + j * WARP_SIZE + GIO.when(i < rowSizeVal): + val x = GIO.read[Float16](layout.input, baseIdx + i).asFloat32 + val w = GIO.read[Float16](layout.weight, weightOffsetVal + i).asFloat32 + GIO.write[Float16](layout.output, baseIdx + i, (x * scale * w).asFloat16) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala new file mode 100644 index 00000000..182c4948 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala @@ -0,0 +1,33 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 element-wise addition for residual connections: `output = a + b`. 
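+ * One invocation handles one element; the dispatch launches ceil(size / 256) workgroups of 256 threads.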
*/ +object F16ResidualAddProgram: + case class Sizes(size: Int) + + case class ProgramLayout( + a: GBuffer[Float16], + b: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + a = GBuffer[Float16](s.size), + b = GBuffer[Float16](s.size), + output = GBuffer[Float16](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + GIO.when(idx < sizes.size): + val aVal = GIO.read[Float16](layout.a, idx) + val bVal = GIO.read[Float16](layout.b, idx) + GIO.write[Float16](layout.output, idx, aVal + bVal) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RoPEProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RoPEProgram.scala new file mode 100644 index 00000000..32fcb36c --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RoPEProgram.scala @@ -0,0 +1,89 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** F16 Rotary Position Embedding (RoPE). + * + * Applies rotary embeddings to encode positional information via rotation. + * Operates on pairs of consecutive dimensions with position-dependent frequencies. + */ +object F16RoPEProgram: + val BLOCK_SIZE = 256 + + case class Sizes( + B: Int, + T: Int, + numHeads: Int, + headSize: Int, + theta: Float, + ): + def totalElements: Int = B * T * numHeads * headSize + def totalPairs: Int = B * T * numHeads * (headSize / 2) + + case class ProgramLayout( + input: GBuffer[Float16], + output: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val numHeads = sizes.numHeads + val headSize = sizes.headSize + val theta = sizes.theta + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float16](s.totalElements), + output = GBuffer[Float16](s.totalElements), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch(((s.totalPairs + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val totalPairsVal: Int32 = B * T * numHeads * (headSize / 2) + val Tval: Int32 = T + val numHeadsVal: Int32 = numHeads + val halfHead: Int32 = headSize / 2 + val thetaVal: Float32 = theta + val startPosVal: Int32 = layout.params.read.startPos + + GIO.when(idx < totalPairsVal): + val perHead = halfHead + val perPos = numHeadsVal * halfHead + val perBatch = Tval * perPos + + val b = idx / perBatch + val rem1 = idx.mod(perBatch) + val t = rem1 / perPos + val rem2 = rem1.mod(perPos) + val h = rem2 / perHead + val d = rem2.mod(perHead) + + val pos = startPosVal + t + val headSizeFloat: Float32 = headSize.toFloat + val freqExponent: Float32 = -2.0f * d.asFloat / headSizeFloat + val freq: Float32 = pos.asFloat * pow(thetaVal, freqExponent) + val cosFreq = cos(freq).asFloat16 + val sinFreq = sin(freq).asFloat16 + + val fullIdx: Int32 = b * Tval * numHeadsVal * headSize + t * numHeadsVal * headSize + h * headSize + 
val idx0 = fullIdx + d * 2 + val idx1 = idx0 + 1 + + val x0 = GIO.read[Float16](layout.input, idx0) + val x1 = GIO.read[Float16](layout.input, idx1) + val y0 = x0 * cosFreq - x1 * sinFreq + val y1 = x0 * sinFreq + x1 * cosFreq + + for + _ <- GIO.write[Float16](layout.output, idx0, y0) + _ <- GIO.write[Float16](layout.output, idx1, y1) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala new file mode 100644 index 00000000..b7499baf --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala @@ -0,0 +1,39 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 SwiGLU activation: `SiLU(gate) * up`. + * + * Combines gated linear unit with SiLU (swish) activation. + * Computes in F32 internally for precision. + */ +object F16SwiGLUProgram: + case class Sizes(numElements: Int) + + case class ProgramLayout( + gate: GBuffer[Float16], + up: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + gate = GBuffer[Float16](s.numElements), + up = GBuffer[Float16](s.numElements), + output = GBuffer[Float16](s.numElements), + ), + dispatch = (_, s) => StaticDispatch(((s.numElements + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val tid = GIO.invocationId + GIO.when(tid < sizes.numElements): + val g = GIO.read[Float16](layout.gate, tid).asFloat32 + val u = GIO.read[Float16](layout.up, tid).asFloat32 + val sigmoidG = 1.0f / (1.0f + exp(-g)) + val result = g * sigmoidG * u + GIO.write[Float16](layout.output, tid, result.asFloat16) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala new file mode 100644 index 00000000..893c40f4 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala @@ -0,0 +1,4 @@ +package io.computenode.cyfra.llama.programs + +/** F16 (Float16) precision programs for Llama inference */ +package object f16 diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/EmbeddingProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/EmbeddingProgram.scala new file mode 100644 index 00000000..f9dcfaf0 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/EmbeddingProgram.scala @@ -0,0 +1,49 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} + +/** Embedding lookup program: output = embeddings[tokens] */ +object EmbeddingProgram: + val BLOCK_SIZE = 256 + + case class Sizes( + seqLen: Int, + hiddenSize: Int, + vocabSize: Int, + ): + def totalOutputs: Int = seqLen * hiddenSize + def embeddingSize: Int = vocabSize * hiddenSize + + case class ProgramLayout( + tokens: GBuffer[Int32], + embeddings: GBuffer[Float32], + output: GBuffer[Float32], + ) derives 
Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val seqLen = sizes.seqLen + val hiddenSize = sizes.hiddenSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + tokens = GBuffer[Int32](s.seqLen), + embeddings = GBuffer[Float32](s.embeddingSize), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch(((s.totalOutputs + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val hiddenSizeVal: Int32 = hiddenSize + val totalVal: Int32 = seqLen * hiddenSize + + GIO.when(idx < totalVal): + val tokenPos = idx / hiddenSizeVal + val dim = idx.mod(hiddenSizeVal) + val tokenId = GIO.read[Int32](layout.tokens, tokenPos) + val embIdx = tokenId * hiddenSizeVal + dim + val value = GIO.read[Float32](layout.embeddings, embIdx) + GIO.write[Float32](layout.output, idx, value) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteK.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteK.scala new file mode 100644 index 00000000..6a21d26a --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteK.scala @@ -0,0 +1,87 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Writes K vectors to KV cache at specified positions (F32 version). + * + * Each invocation copies one element from the input K tensor to the KV cache + * at the position specified by the runtime `startPos` parameter. + * + * @note The cache is organized as (L × maxSeqLen × NKV × headSize) where L is total layers. + * This program writes to a single layer's slice using `cacheLayerOffset`. + */ +object KVCacheWriteK: + + /** Compile-time size parameters for the KV cache write K program. */ + case class Sizes( + B: Int, + T: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + layer: Int, + posOffset: Int, + cacheLayerOffset: Int, + L: Int, + ): + def totalElements: Int = B * T * NKV * headSize + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + k: GBuffer[Float32], + kCache: GBuffer[Float32], + params: GUniform[AttentionParams], + ) derives Layout + + /** Creates a GPU program that writes K vectors to the KV cache. 
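+ * The element at (b, t, h, d) in the input K tensor lands at cache index
+ * cacheLayerOffset + (startPos + t) * NKV * headSize + h * headSize + d, with startPos read from the
+ * params uniform at runtime.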
*/ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NKV = sizes.NKV + val headSize = sizes.headSize + val totalElements = sizes.totalElements + val cacheLayerOffset = sizes.cacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + k = GBuffer[Float32](s.B * s.T * s.NKV * s.headSize), + kCache = GBuffer[Float32](s.fullCacheSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.totalElements, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val posOffsetVal: Int32 = layout.params.read.startPos + + val Tval: Int32 = T + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val totalElementsVal: Int32 = totalElements + val cacheLayerOffsetVal: Int32 = cacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + GIO.when(idx < totalElementsVal): + val elementsPerBatch = Tval * NKVval * headSizeVal + val b = idx / elementsPerBatch + val remaining1 = idx.mod(elementsPerBatch) + val elementsPerPos = NKVval * headSizeVal + val t = remaining1 / elementsPerPos + val remaining2 = remaining1.mod(elementsPerPos) + val h = remaining2 / headSizeVal + val d = remaining2.mod(headSizeVal) + + val inputIdx = idx + val kVal = GIO.read[Float32](layout.k, inputIdx) + + val cachePos = posOffsetVal + t + val cacheIdx = cacheLayerOffsetVal + cachePos * kvSizePerPosVal + h * headSizeVal + d + GIO.write[Float32](layout.kCache, cacheIdx, kVal) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteV.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteV.scala new file mode 100644 index 00000000..8f79a013 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCacheWriteV.scala @@ -0,0 +1,83 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Writes V vectors to KV cache at specified positions (F32 version). + * + * Identical structure to KVCacheWriteK but operates on V vectors. + */ +object KVCacheWriteV: + + /** Compile-time size parameters for the KV cache write V program. */ + case class Sizes( + B: Int, + T: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + layer: Int, + posOffset: Int, + cacheLayerOffset: Int, + L: Int, + ): + def totalElements: Int = B * T * NKV * headSize + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + v: GBuffer[Float32], + vCache: GBuffer[Float32], + params: GUniform[AttentionParams], + ) derives Layout + + /** Creates a GPU program that writes V vectors to the KV cache. 
*/ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NKV = sizes.NKV + val headSize = sizes.headSize + val totalElements = sizes.totalElements + val cacheLayerOffset = sizes.cacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + v = GBuffer[Float32](s.B * s.T * s.NKV * s.headSize), + vCache = GBuffer[Float32](s.fullCacheSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.totalElements, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val posOffsetVal: Int32 = layout.params.read.startPos + + val Tval: Int32 = T + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val totalElementsVal: Int32 = totalElements + val cacheLayerOffsetVal: Int32 = cacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + GIO.when(idx < totalElementsVal): + val elementsPerBatch = Tval * NKVval * headSizeVal + val b = idx / elementsPerBatch + val remaining1 = idx.mod(elementsPerBatch) + val elementsPerPos = NKVval * headSizeVal + val t = remaining1 / elementsPerPos + val remaining2 = remaining1.mod(elementsPerPos) + val h = remaining2 / headSizeVal + val d = remaining2.mod(headSizeVal) + + val inputIdx = idx + val vVal = GIO.read[Float32](layout.v, inputIdx) + + val cachePos = posOffsetVal + t + val cacheIdx = cacheLayerOffsetVal + cachePos * kvSizePerPosVal + h * headSizeVal + d + GIO.write[Float32](layout.vCache, cacheIdx, vVal) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCachedAttention.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCachedAttention.scala new file mode 100644 index 00000000..b7b375c7 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/KVCachedAttention.scala @@ -0,0 +1,172 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.GShared +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.llama.programs.AttentionParams + +/** KV-cached attention for incremental inference (F32 version). + * + * Computes attention by reading Q from current tokens and K/V from the full cache. + * One workgroup handles one (batch, query_position, head) tuple. + * Supports grouped-query attention (GQA) where multiple Q heads share K/V heads. + */ +object KVCachedAttention: + val WARP_SIZE = 32 + val MAX_SEQ_LEN = 2048 + + /** Compile-time size parameters for the attention program. */ + case class Sizes( + B: Int, + T: Int, + NH: Int, + NKV: Int, + headSize: Int, + startPos: Int, + kCacheLayerOffset: Int, + vCacheLayerOffset: Int, + L: Int, + maxSeqLen: Int, + ): + def gqaRatio: Int = NH / NKV + def numScoreIterations: Int = (maxSeqLen + WARP_SIZE - 1) / WARP_SIZE + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + q: GBuffer[Float32], + kCache: GBuffer[Float32], + vCache: GBuffer[Float32], + output: GBuffer[Float32], + params: GUniform[AttentionParams], + ) derives Layout + + /** Creates a GPU program for KV-cached attention computation. 
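+ * The body runs in four barrier-separated phases: (1) Q·K scores with a running max,
+ * (2) exp(score - max) and its per-lane sum, (3) normalization of the softmax weights in shared memory,
+ * and (4) the weighted sum over cached V values written to the output buffer.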
*/ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val scoresShared = GShared[Float32](MAX_SEQ_LEN) + + val B = sizes.B + val T = sizes.T + val NH = sizes.NH + val NKV = sizes.NKV + val headSize = sizes.headSize + val gqaRatio = sizes.gqaRatio + val scale = 1.0f / math.sqrt(headSize).toFloat + val numScoreIterations = sizes.numScoreIterations + val kCacheLayerOffset = sizes.kCacheLayerOffset + val vCacheLayerOffset = sizes.vCacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + val maxSeqLen = sizes.maxSeqLen + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + q = GBuffer[Float32](s.B * s.T * s.NH * s.headSize), + kCache = GBuffer[Float32](s.fullCacheSize), + vCache = GBuffer[Float32](s.fullCacheSize), + output = GBuffer[Float32](s.B * s.T * s.NH * s.headSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.B * s.T * s.NH, 1, 1)), + workgroupSize = (WARP_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + + val runtimeParams = layout.params.read + val seqLenVal: Int32 = runtimeParams.seqLen + val startPosVal: Int32 = runtimeParams.startPos + + val Tval: Int32 = T + val NHval: Int32 = NH + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val gqaRatioVal: Int32 = gqaRatio + val scaleVal: Float32 = scale + val kCacheLayerOffsetVal: Int32 = kCacheLayerOffset + val vCacheLayerOffsetVal: Int32 = vCacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + val posPerBatch = Tval * NHval + val batchIdx = workgroupId / posPerBatch + val posInBatch = workgroupId.mod(posPerBatch) + val queryPosLocal = posInBatch / NHval + val headIdx = posInBatch.mod(NHval) + val kvHeadIdx = headIdx / gqaRatioVal + + val queryPosGlobal = startPosVal + queryPosLocal + val qBase = batchIdx * Tval * NHval * headSizeVal + queryPosLocal * NHval * headSizeVal + headIdx * headSizeVal + + // Phase 1: Compute attention scores Q·K and track max for numerical stability + val computeScoresAndMax: GIO[Float32] = GIO.foldRepeat[Float32](numScoreIterations, -10000.0f): (iter, localMax) => + val kPos = tid + iter * WARP_SIZE + val isValid = kPos <= queryPosGlobal && kPos < seqLenVal + val kCacheBase = kCacheLayerOffsetVal + kPos * kvSizePerPosVal + kvHeadIdx * headSizeVal + + val dot = GSeq.gen[Int32](0, _ + 1).limit(headSize).fold(0.0f, (acc: Float32, d: Int32) => + val qVal = GIO.read[Float32](layout.q, qBase + d) + val kVal = GIO.read[Float32](layout.kCache, kCacheBase + d) + acc + qVal * kVal + ) + + val score = when(isValid)(dot * scaleVal).otherwise(-10000.0f) + + for _ <- scoresShared.write(kPos, score) yield + when(isValid)(max(localMax, score)).otherwise(localMax) + + for + localMax <- computeScoresAndMax + _ <- GIO.barrier + + // Phase 2: Compute softmax numerator exp(score - max) + globalMax <- GIO.pure(GIO.subgroupMax(localMax)) + localSum <- GIO.foldRepeat[Float32](numScoreIterations, 0.0f): (iter, sum) => + val kPos = tid + iter * WARP_SIZE + val isValid = kPos <= queryPosGlobal && kPos < seqLenVal + val score = scoresShared.read(kPos) + val expScore = exp(score - globalMax) + for _ <- scoresShared.write(kPos, expScore) yield + when(isValid)(sum + expScore).otherwise(sum) + + _ <- GIO.barrier + + // Phase 3: Normalize to get attention weights + globalSum <- GIO.pure(GIO.subgroupAdd(localSum) + 0.0000001f) + _ <- GIO.foldRepeat[Empty](numScoreIterations, Empty()): (iter, _) => + val kPos = tid + iter * WARP_SIZE + val 
isValid = kPos <= queryPosGlobal && kPos < seqLenVal + val expScore = scoresShared.read(kPos) + for _ <- GIO.when(isValid)(scoresShared.write(kPos, expScore / globalSum)) yield Empty() + + _ <- GIO.barrier + + // Phase 4: Compute weighted sum of V values + outDim1 <- GIO.pure(tid) + _ <- GIO.when(outDim1 < headSizeVal): + val weightedSum1 = GSeq.gen[Int32](0, _ + 1).limit(maxSeqLen).takeWhile(_ < seqLenVal).fold(0.0f, (sum: Float32, kPos: Int32) => + val isValid = kPos <= queryPosGlobal + val weight = scoresShared.read(kPos) + val vCacheBase = vCacheLayerOffsetVal + kPos * kvSizePerPosVal + kvHeadIdx * headSizeVal + val vVal = GIO.read[Float32](layout.vCache, vCacheBase + outDim1) + when(isValid)(sum + weight * vVal).otherwise(sum) + ) + val outBase = batchIdx * Tval * NHval * headSizeVal + queryPosLocal * NHval * headSizeVal + headIdx * headSizeVal + GIO.write[Float32](layout.output, outBase + outDim1, weightedSum1) + + // Handle dimensions beyond WARP_SIZE (for headSize > 32) + outDim2 <- GIO.pure(tid + WARP_SIZE) + _ <- GIO.when(outDim2 < headSizeVal): + val weightedSum2 = GSeq.gen[Int32](0, _ + 1).limit(maxSeqLen).takeWhile(_ < seqLenVal).fold(0.0f, (sum: Float32, kPos: Int32) => + val isValid = kPos <= queryPosGlobal + val weight = scoresShared.read(kPos) + val vCacheBase = vCacheLayerOffsetVal + kPos * kvSizePerPosVal + kvHeadIdx * headSizeVal + val vVal = GIO.read[Float32](layout.vCache, vCacheBase + outDim2) + when(isValid)(sum + weight * vVal).otherwise(sum) + ) + val outBase = batchIdx * Tval * NHval * headSizeVal + queryPosLocal * NHval * headSizeVal + headIdx * headSizeVal + GIO.write[Float32](layout.output, outBase + outDim2, weightedSum2) + yield Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q4KMatmulVecProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q4KMatmulVecProgram.scala new file mode 100644 index 00000000..d3aee4b7 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q4KMatmulVecProgram.scala @@ -0,0 +1,342 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct.Empty + +/** Q4_K Matrix-Vector Multiplication with inline GPU dequantization. + * + * This is the high-performance version that keeps weights quantized on GPU. + * 8x memory bandwidth reduction vs F32. 
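+ * Each 4-bit value q is dequantized on the fly as d * sc * q - dmin * m, where sc and m are the
+ * per-sub-block 6-bit scale and min decoded from the packed scales bytes.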
+ * + * Q4_K format (144 bytes per 256-element super-block): + * - 2 bytes: d (fp16) - main scale + * - 2 bytes: dmin (fp16) - min scale + * - 12 bytes: scales (6-bit packed, 8 scales + 8 mins) + * - 128 bytes: quantized values (4-bit packed, 256 values) + */ +object Q4KMatmulVecProgram: + val WARP_SIZE = 32 + val QK_K = 256 // Elements per quantization super-block + val BLOCK_BYTES = 144 // Bytes per Q4_K block + val UINT32_PER_BLOCK = 36 // uint32 per block (144/4) + val BLOCK_SIZE = 256 // Workgroup size + val NUM_ROWS = 8 // Outputs per workgroup + + case class Sizes( + batchSize: Int, // B * T + inFeatures: Int, // K - must be multiple of QK_K (256) + outFeatures: Int, // N + ): + require(inFeatures % QK_K == 0, s"inFeatures must be multiple of $QK_K") + def totalOutputs: Int = batchSize * outFeatures + def numQBlocks: Int = inFeatures / QK_K + def numWorkgroups: Int = (totalOutputs + NUM_ROWS - 1) / NUM_ROWS + + case class ProgramLayout( + weight: GBuffer[UInt32], // Packed Q4_K: [outFeatures * numQBlocks * 36] + input: GBuffer[Float32], // [batchSize, inFeatures] + output: GBuffer[Float32], // [batchSize, outFeatures] + ) derives Layout + + /** Convert fp16 (half precision) to fp32. + * Handles normalized, denormalized, and zero values. + */ + private def fp16ToFp32(h: UInt32): Float32 = + val mask1: UInt32 = 1 + val zero: UInt32 = 0 + val expMask: UInt32 = 0x1F + val mantMask: UInt32 = 0x3FF + val sign = (h >> 15.unsigned) & mask1 + val exp = (h >> 10.unsigned) & expMask + val mant = h & mantMask + + // For normalized numbers (exp > 0): (1 + mant/1024) * 2^(exp-15) + // For denormalized (exp == 0): (mant/1024) * 2^(-14) = mant * 2^(-24) + // For zero: 0 + val expIsZero = exp === zero + val mantIsZero = mant === zero + + val normMantF = 1.0f + mant.asFloat / 1024.0f + val normExpF = exp.signed - 15 + val normResult = normMantF * pow(2.0f, normExpF.asFloat) + val denormResult = mant.asFloat * 5.9604645e-8f // 2^(-24) + + val magnitude = when(expIsZero): + when(mantIsZero)(0.0f).otherwise(denormResult) + .otherwise: + normResult + + when(sign === mask1)(-magnitude).otherwise(magnitude) + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val outFeatures = sizes.outFeatures + val numQBlocks = sizes.numQBlocks + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[UInt32](s.outFeatures * s.numQBlocks * UINT32_PER_BLOCK), + input = GBuffer[Float32](s.batchSize * s.inFeatures), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpInWorkgroup = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val outFeaturesVal: Int32 = outFeatures + val numQBlocksVal: Int32 = numQBlocks + val totalOutputsVal: Int32 = sizes.totalOutputs + + // Each warp computes one output + val outputIdx = workgroupId * NUM_ROWS + warpInWorkgroup + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + // Process all Q4_K blocks for this output row + // Each lane handles elements laneId, laneId+32, ... 
within each block + // Using foldRepeat over Q-blocks (compile-time count) + val partialSum = GIO.foldRepeat[Float32](numQBlocks, 0.0f): (qBlock, acc) => + val blockBase = (outIdx * numQBlocksVal + qBlock) * UINT32_PER_BLOCK + val colBase = qBlock * QK_K + + // Read d and dmin from first uint32 (two fp16 packed) + val dminPacked = GIO.read[UInt32](layout.weight, blockBase) + val mask16: UInt32 = 0xFFFF + val d = fp16ToFp32(dminPacked & mask16) + val dmin = fp16ToFp32((dminPacked >> 16.unsigned) & mask16) + + // Each lane processes 8 elements: laneId, laneId+32, ... + // Use GSeq with unroll hint to generate compact [[unroll]] loop instead of massive inline expansion + val blockSum = GSeq + .gen[Int32](0, _ + 1) + .map { i => + val localIdx = laneId + i * 32 + processQ4KElement(layout, blockBase, colBase, batch, inFeaturesVal, d, dmin, localIdx) + } + .limit(8) + .unroll // Generates [[unroll]] pragma - keeps code compact + .fold(0.0f, (acc: Float32, x: Float32) => acc + x) + + for _ <- GIO.barrier yield + acc + blockSum + + // IMPORTANT: subgroupAdd must be called by ALL lanes (it's a collective operation) + // Do NOT wrap subgroupAdd in GIO.when - that breaks the collective reduction! + partialSum.flatMap: ps => + val totalSum: Float32 = GIO.subgroupAdd(ps) + // All lanes write (safe, coalesced) - like F32 MatmulVecProgram + // Guard only the write, not the subgroupAdd + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float32](layout.output, outputIdx, totalSum) + + private def processQ4KElement( + layout: ProgramLayout, + blockBase: Int32, + colBase: Int32, + batch: Int32, + inFeaturesVal: Int32, + d: Float32, + dmin: Float32, + localIdx: Int32, + ): Float32 = + processQ4KElementGeneric(layout.weight, layout.input, blockBase, colBase, batch, inFeaturesVal, d, dmin, localIdx) + + /** Generic Q4K element processor that works with any UInt32 weight buffer. + * + * Matches llama.cpp dequant_q4_k.comp exactly: + * - 256 elements per block, processed in 8 sub-blocks of 32 + * - subBlock = localIdx / 32 (0-7), is the scale index + * - Elements in even sub-blocks (0,2,4,6) use LOW nibble + * - Elements in odd sub-blocks (1,3,5,7) use HIGH nibble + * - Scale/min extraction uses is < 4 vs is >= 4 logic (NOT even/odd!) + * + * Q4_K layout (144 bytes): + * bytes 0-1: d (fp16) + * bytes 2-3: dmin (fp16) + * bytes 4-15: scales (12 bytes, complex 6-bit packing) + * bytes 16-143: qs (128 bytes, 4-bit quantized values) + * + * Scale extraction (from llama.cpp get_scale_min_k4): + * For is < 4: sc = scales[is] & 0x3F, m = scales[is+4] & 0x3F + * For is >= 4: sc = (scales[is+4] & 0x0F) | ((scales[is-4] >> 6) << 4) + * m = ((scales[is+4] >> 4) & 0x0F) | ((scales[is] >> 6) << 4) + */ + private[f32] def processQ4KElementGeneric( + weight: GBuffer[UInt32], + input: GBuffer[Float32], + blockBase: Int32, + colBase: Int32, + batch: Int32, + inFeaturesVal: Int32, + d: Float32, + dmin: Float32, + localIdx: Int32, + ): Float32 = + val valid = localIdx < QK_K + + // Map localIdx to sub-block structure + val subBlock = localIdx / 32 // 0-7 (is) + val posInSub = localIdx.mod(32) // 0-31 + + // Is this a high-nibble sub-block? 
(odd sub-blocks: 1, 3, 5, 7) + val isHighNibble = (subBlock.mod(2)) === 1 + + // qs byte index: (subBlock / 2) * 32 + posInSub + val j: Int32 = subBlock / 2 // 0-3 (which group of 64) + val qsByteIdx = j * 32 + posInSub + + // Read qs byte from bytes 16-143 + val sixteen: Int32 = 16 + val qsOffset: Int32 = sixteen + qsByteIdx + val qsUint32Idx = qsOffset / 4 + val qsByteInUint32 = qsOffset.mod(4) + val qsWord = GIO.read[UInt32](weight, blockBase + qsUint32Idx) + val byteMask: UInt32 = 0xFF + val nibbleMask: UInt32 = 0x0F + val qsByte = (qsWord >> (qsByteInUint32 * 8).unsigned) & byteMask + + // Extract 4-bit quantized value + val q = when(isHighNibble)((qsByte >> 4.unsigned) & nibbleMask).otherwise(qsByte & nibbleMask) + + // Scale index = subBlock (is = 0-7) + val is: Int32 = subBlock + val isLt4 = is < 4 + + // Helper to read scale byte (scales at bytes 4-15) + val four: Int32 = 4 + def readScaleByte(scaleIdx: Int32): UInt32 = + val byteOffset: Int32 = four + scaleIdx + val uint32Idx = byteOffset / 4 + val byteInUint32 = byteOffset.mod(4) + val word = GIO.read[UInt32](weight, blockBase + uint32Idx) + (word >> (byteInUint32 * 8).unsigned) & byteMask + + val mask3F: UInt32 = 0x3F + val mask0F: UInt32 = 0x0F + + // Read scale bytes needed for both branches + // For is < 4: need scales[is] and scales[is+4] + // For is >= 4: need scales[is+4], scales[is-4], and scales[is] + val scalesIs = readScaleByte(is) // scales[is] + val scalesIsP4 = readScaleByte(is + 4) // scales[is+4] + val scalesIsM4 = readScaleByte(is - 4) // scales[is-4] (only valid when is >= 4) + + // Compute scale (sc) and min (m) based on is < 4 vs is >= 4 + // For is < 4: sc = scales[is] & 0x3F, m = scales[is+4] & 0x3F + val scLt4 = scalesIs & mask3F + val mLt4 = scalesIsP4 & mask3F + + // For is >= 4: sc = (scales[is+4] & 0x0F) | ((scales[is-4] >> 6) << 4) + // m = ((scales[is+4] >> 4) & 0x0F) | ((scales[is] >> 6) << 4) + val scGe4 = (scalesIsP4 & mask0F) | ((scalesIsM4 >> 6.unsigned) << 4.unsigned) + val mGe4 = ((scalesIsP4 >> 4.unsigned) & mask0F) | ((scalesIs >> 6.unsigned) << 4.unsigned) + + val sc = when(isLt4)(scLt4.asFloat).otherwise(scGe4.asFloat) + val m = when(isLt4)(mLt4.asFloat).otherwise(mGe4.asFloat) + + // Dequantize: result = d * sc * q - dmin * m + val dequantized = d * sc * q.asFloat - dmin * m + + // Read input + val col = colBase + localIdx + val inputVal = GIO.read[Float32](input, batch * inFeaturesVal + col) + + when(valid)(dequantized * inputVal).otherwise(0.0f) + + // ============= Layered Version for Pipeline ============= + + /** Layered Q4K matmul with weight offset for pipeline integration. + * + * This version allows concatenated Q4_K weights across layers. 
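+ * weightOffsetUint32 is the index (in uint32 units) of this tensor's first block inside the shared
+ * buffer; each output row then occupies uint32PerRow = numQBlocks * UINT32_PER_BLOCK uint32s.
+ *
+ * A hypothetical host-side sketch (assuming every layer's q-projection weights are packed back to back
+ * in one buffer; names and dimensions here are illustrative, not part of this program):
+ * {{{
+ * val rowsPerLayer = 4096                             // outFeatures of the q projection
+ * val uint32PerRow = (4096 / QK_K) * UINT32_PER_BLOCK
+ * val sizes = Layered.Sizes(
+ *   batchSize = 1,
+ *   inFeatures = 4096,
+ *   outFeatures = rowsPerLayer,
+ *   weightOffsetUint32 = layerIdx * rowsPerLayer * uint32PerRow,
+ *   totalWeightUint32 = numLayers * rowsPerLayer * uint32PerRow,
+ * )
+ * }}}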
+ */ + object Layered: + case class Sizes( + batchSize: Int, + inFeatures: Int, + outFeatures: Int, + weightOffsetUint32: Int, // Offset in uint32 into the concatenated weight buffer + totalWeightUint32: Int, // Total size of concatenated weight buffer + ): + require(inFeatures % QK_K == 0, s"inFeatures must be multiple of $QK_K") + def totalOutputs: Int = batchSize * outFeatures + def numQBlocks: Int = inFeatures / QK_K + def numWorkgroups: Int = (totalOutputs + NUM_ROWS - 1) / NUM_ROWS + def uint32PerRow: Int = numQBlocks * UINT32_PER_BLOCK + + case class ProgramLayout( + weight: GBuffer[UInt32], + input: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[UInt32](s.totalWeightUint32), + input = GBuffer[Float32](s.batchSize * s.inFeatures), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + forwardBody(sizes, layout) + + /** Body of the forward pass - can be called from other programs. */ + def forwardBody(sizes: Sizes, layout: ProgramLayout): GIO[Empty] = + val inFeatures = sizes.inFeatures + val outFeatures = sizes.outFeatures + val numQBlocks = sizes.numQBlocks + val weightOffset = sizes.weightOffsetUint32 + val uint32PerRow = sizes.uint32PerRow + + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpInWorkgroup = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val outFeaturesVal: Int32 = outFeatures + val numQBlocksVal: Int32 = numQBlocks + val totalOutputsVal: Int32 = sizes.totalOutputs + val weightOffsetVal: Int32 = weightOffset + val uint32PerRowVal: Int32 = uint32PerRow + + val outputIdx = workgroupId * NUM_ROWS + warpInWorkgroup + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + val partialSum = GIO.foldRepeat[Float32](numQBlocks, 0.0f): (qBlock, acc) => + // Apply weight offset for layered weights + val blockBase = weightOffsetVal + outIdx * uint32PerRowVal + qBlock * UINT32_PER_BLOCK + val colBase = qBlock * QK_K + + val dminPacked = GIO.read[UInt32](layout.weight, blockBase) + val mask16: UInt32 = 0xFFFF + val d = fp16ToFp32(dminPacked & mask16) + val dmin = fp16ToFp32((dminPacked >> 16.unsigned) & mask16) + + // Use GSeq with unroll hint for compact code generation + val blockSum = GSeq + .gen[Int32](0, _ + 1) + .map { i => + val localIdx = laneId + i * 32 + processQ4KElementGeneric(layout.weight, layout.input, blockBase, colBase, batch, inFeaturesVal, d, dmin, localIdx) + } + .limit(8) + .unroll + .fold(0.0f, (acc: Float32, x: Float32) => acc + x) + + for _ <- GIO.barrier yield + acc + blockSum + + // IMPORTANT: subgroupAdd must be called by ALL lanes (it's a collective operation) + partialSum.flatMap: ps => + val totalSum: Float32 = GIO.subgroupAdd(ps) + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float32](layout.output, outputIdx, totalSum) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q6KMatmulVecProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q6KMatmulVecProgram.scala new file mode 100644 index 00000000..97cb8471 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/Q6KMatmulVecProgram.scala @@ -0,0 +1,405 @@ +package io.computenode.cyfra.llama.programs.f32 + 
+import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct.Empty + +/** Q6_K Matrix-Vector Multiplication with inline GPU dequantization. + * + * Q6_K format (210 bytes per 256-element super-block): + * - 128 bytes: ql - low 4 bits (packed 2 per byte) + * - 64 bytes: qh - high 2 bits (packed 4 per byte) + * - 16 bytes: scales (int8, one per 16 elements) + * - 2 bytes: d (fp16) - main scale + */ +object Q6KMatmulVecProgram: + val WARP_SIZE = 32 + val QK_K = 256 // Elements per quantization super-block + val BLOCK_BYTES = 210 // Bytes per Q6_K block + val UINT32_PER_BLOCK = 53 // ceil(210/4) - actually 52.5, so we use 53 uint32 with padding + val BLOCK_SIZE = 256 // Workgroup size + val NUM_ROWS = 8 // Outputs per workgroup + + case class Sizes( + batchSize: Int, // B * T + inFeatures: Int, // K - must be multiple of QK_K (256) + outFeatures: Int, // N + ): + require(inFeatures % QK_K == 0, s"inFeatures must be multiple of $QK_K") + def totalOutputs: Int = batchSize * outFeatures + def numQBlocks: Int = inFeatures / QK_K + def numWorkgroups: Int = (totalOutputs + NUM_ROWS - 1) / NUM_ROWS + // Q6_K uses 210 bytes per block = 52.5 uint32, so we pack at byte level + def bytesPerRow: Int = numQBlocks * BLOCK_BYTES + + case class ProgramLayout( + weight: GBuffer[UInt32], // Packed Q6_K: stored as bytes packed in uint32 + input: GBuffer[Float32], // [batchSize, inFeatures] + output: GBuffer[Float32], // [batchSize, outFeatures] + ) derives Layout + + /** Convert fp16 (half precision) to fp32. + * Handles normalized, denormalized, and zero values. + */ + private def fp16ToFp32(h: UInt32): Float32 = + val mask1: UInt32 = 1 + val zero: UInt32 = 0 + val expMask: UInt32 = 0x1F + val mantMask: UInt32 = 0x3FF + val sign = (h >> 15.unsigned) & mask1 + val exp = (h >> 10.unsigned) & expMask + val mant = h & mantMask + + // For normalized numbers (exp > 0): (1 + mant/1024) * 2^(exp-15) + // For denormalized (exp == 0): (mant/1024) * 2^(-14) + // For zero: 0 + val expIsZero = exp === zero + val mantIsZero = mant === zero + + // Normalized case + val normMantF = 1.0f + mant.asFloat / 1024.0f + val normExpF = exp.signed - 15 + val normResult = normMantF * pow(2.0f, normExpF.asFloat) + + // Denormalized case: mant/1024 * 2^(-14) = mant * 2^(-24) + val denormResult = mant.asFloat * 5.9604645e-8f // 2^(-24) + + // Select result based on exp and mant + val magnitude = when(expIsZero): + when(mantIsZero)(0.0f).otherwise(denormResult) + .otherwise: + normResult + + when(sign === mask1)(-magnitude).otherwise(magnitude) + + /** Read a byte from the weight buffer at a given byte offset. */ + private def readByte(weight: GBuffer[UInt32], byteOffset: Int32): UInt32 = + val uint32Idx = byteOffset / 4 + val byteInUint32 = byteOffset.mod(4) + val word = GIO.read[UInt32](weight, uint32Idx) + val byteMask: UInt32 = 0xFF + (word >> (byteInUint32 * 8).unsigned) & byteMask + + /** Read a signed byte from the weight buffer. 
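+ * Bytes above 127 are sign-extended, e.g. 0xFF is returned as -1 and 0x80 as -128.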
*/ + private def readSignedByte(weight: GBuffer[UInt32], byteOffset: Int32): Int32 = + val b = readByte(weight, byteOffset) + // Convert unsigned byte to signed: if > 127, subtract 256 + val mask127: UInt32 = 127 + when(b > mask127)(b.signed - 256).otherwise(b.signed) + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val outFeatures = sizes.outFeatures + val numQBlocks = sizes.numQBlocks + val bytesPerRow = sizes.bytesPerRow + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[UInt32]((s.outFeatures * s.bytesPerRow + 3) / 4), + input = GBuffer[Float32](s.batchSize * s.inFeatures), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpInWorkgroup = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val outFeaturesVal: Int32 = outFeatures + val numQBlocksVal: Int32 = numQBlocks + val totalOutputsVal: Int32 = sizes.totalOutputs + val bytesPerRowVal: Int32 = bytesPerRow + + // Each warp computes one output + val outputIdx = workgroupId * NUM_ROWS + warpInWorkgroup + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + // Process all Q6_K blocks for this output row + val partialSum = GIO.foldRepeat[Float32](numQBlocks, 0.0f): (qBlock, acc) => + val blockByteBase = outIdx * bytesPerRowVal + qBlock * BLOCK_BYTES + val colBase = qBlock * QK_K + + // Q6_K block layout (210 bytes): + // - ql: bytes 0-127 (low 4 bits) + // - qh: bytes 128-191 (high 2 bits) + // - scales: bytes 192-207 (int8) + // - d: bytes 208-209 (fp16) + + // Read d (fp16 at byte 208) + val dWord = GIO.read[UInt32](layout.weight, (blockByteBase + 208) / 4) + val dBytePos = (blockByteBase + 208).mod(4) + val mask16: UInt32 = 0xFFFF + val dHalf = (dWord >> (dBytePos * 8).unsigned) & mask16 + val d = fp16ToFp32(dHalf) + + // Each lane processes 8 elements: laneId, laneId+32, ..., laneId+224 + val iter0 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId) + val iter1 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 32) + val iter2 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 64) + val iter3 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 96) + val iter4 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 128) + val iter5 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 160) + val iter6 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 192) + val iter7 = processQ6KElement(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 224) + + for _ <- GIO.barrier yield + acc + iter0 + iter1 + iter2 + iter3 + iter4 + iter5 + iter6 + iter7 + + // IMPORTANT: subgroupAdd must be called by ALL lanes (it's a collective operation) + partialSum.flatMap: ps => + val totalSum: Float32 = GIO.subgroupAdd(ps) + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float32](layout.output, outputIdx, totalSum) + + /** Process a single Q6_K element. 
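+ * For instance, under the mapping detailed below, localIdx = 170 gives ip = 1, group = 1, il = 10,
+ * so the element is the low nibble of ql byte 106, its high bits are bits 2-3 of qh byte 42,
+ * and its scale is scales byte 10.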
+ * + * Q6_K dequantization (matches llama.cpp dequant_q6_k.comp exactly): + * - Block has 256 elements split into 2 halves of 128 + * - For element localIdx in [0, 255]: + * - ip = localIdx / 128 (which half: 0 or 1) + * - il = localIdx % 128 (position within half: 0-127) + * - is = ip * 8 + il / 16 (base scale index) + * - y_idx = 128 * ip + il (output position) + * - ql_idx = 64 * ip + il (ql byte index for low nibble) + * - qh is at: 32 * ip + il (one byte per pair of ql elements) + * + * The encoding: + * - Output[y_idx + 0]: ql[ql_idx] low nibble + qh bits 0-1 + * - Output[y_idx + 32]: ql[ql_idx + 32] low nibble + qh bits 2-3 + * - Output[y_idx + 64]: ql[ql_idx] high nibble + qh bits 4-5 + * - Output[y_idx + 96]: ql[ql_idx + 32] high nibble + qh bits 6-7 + * + * Since we process elements sequentially (localIdx = 0..255), we need + * to reverse-map from localIdx to determine which case we're in. + */ + private def processQ6KElement( + weight: GBuffer[UInt32], + input: GBuffer[Float32], + blockByteBase: Int32, + colBase: Int32, + batch: Int32, + inFeaturesVal: Int32, + d: Float32, + localIdx: Int32, + ): Float32 = + val valid = localIdx < QK_K + + // Determine which half (ip) and position within the 128-value range + val ip = localIdx / 128 // 0 or 1 + val posIn128 = localIdx.mod(128) // 0-127 + + // Each 128-value half is organized as 4 groups of 32: + // Group 0: positions 0-31 (y_idx + 0) + // Group 1: positions 32-63 (y_idx + 32) + // Group 2: positions 64-95 (y_idx + 64) + // Group 3: positions 96-127 (y_idx + 96) + val group = posIn128 / 32 // 0, 1, 2, or 3 + val il = posIn128.mod(32) // 0-31 (position within group) + + // Compute ql_idx and qh_idx (the "source" indices in llama.cpp) + val ql_idx_base = ip * 64 + il + val qh_idx = ip * 32 + il + + // Read the qh byte + val qhByteOffset = blockByteBase + 128 + qh_idx + val qhByte = readByte(weight, qhByteOffset) + + // Determine which ql byte and nibble based on group + // Group 0: ql[ql_idx + 0], low nibble, qh bits 0-1 + // Group 1: ql[ql_idx + 32], low nibble, qh bits 2-3 + // Group 2: ql[ql_idx + 0], high nibble, qh bits 4-5 + // Group 3: ql[ql_idx + 32], high nibble, qh bits 6-7 + val ql_idx = when(group === 0 || group === 2)(ql_idx_base).otherwise(ql_idx_base + 32) + val useHighNibble = group >= 2 + val qhShift = group * 2 + + val qlByteOffset = blockByteBase + ql_idx + val qlByte = readByte(weight, qlByteOffset) + val nibbleMask: UInt32 = 0x0F + val ql = when(useHighNibble)((qlByte >> 4.unsigned) & nibbleMask).otherwise(qlByte & nibbleMask) + + val qhMask: UInt32 = 0x03 + val qh = (qhByte >> qhShift.unsigned) & qhMask + + // Combine to get 6-bit value and subtract 32 for signed + val q6 = ql | (qh << 4.unsigned) + val qSigned = q6.signed - 32 + + // Compute scale index: is = 8 * ip + il / 16 + // For each group, we use a different scale offset: + // Group 0 uses is + 0, Group 1 uses is + 2, Group 2 uses is + 4, Group 3 uses is + 6 + val is_base = ip * 8 + il / 16 + val scaleOffset = group * 2 + val scaleIdx = is_base + scaleOffset + val scaleByteOffset = blockByteBase + 192 + scaleIdx + val scale = readSignedByte(weight, scaleByteOffset) + + // Dequantize + val dequantized = d * scale.asFloat * qSigned.asFloat + + // Read input + val col = colBase + localIdx + val inputVal = GIO.read[Float32](input, batch * inFeaturesVal + col) + + when(valid)(dequantized * inputVal).otherwise(0.0f) + + // ============= Layered Version for Pipeline ============= + + object Layered: + case class Sizes( + batchSize: Int, + inFeatures: Int, + 
outFeatures: Int, + weightOffsetBytes: Int, // Offset in BYTES into the concatenated weight buffer + totalWeightBytes: Int, // Total size of concatenated weight buffer in bytes + ): + require(inFeatures % QK_K == 0, s"inFeatures must be multiple of $QK_K") + def totalOutputs: Int = batchSize * outFeatures + def numQBlocks: Int = inFeatures / QK_K + def numWorkgroups: Int = (totalOutputs + NUM_ROWS - 1) / NUM_ROWS + def bytesPerRow: Int = numQBlocks * BLOCK_BYTES + + case class ProgramLayout( + weight: GBuffer[UInt32], + input: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[UInt32]((s.totalWeightBytes + 3) / 4), + input = GBuffer[Float32](s.batchSize * s.inFeatures), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + forwardBody(sizes, layout) + + def forwardBody(sizes: Sizes, layout: ProgramLayout): GIO[Empty] = + val inFeatures = sizes.inFeatures + val outFeatures = sizes.outFeatures + val numQBlocks = sizes.numQBlocks + val weightOffsetBytes = sizes.weightOffsetBytes + val bytesPerRow = sizes.bytesPerRow + + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpInWorkgroup = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val outFeaturesVal: Int32 = outFeatures + val numQBlocksVal: Int32 = numQBlocks + val totalOutputsVal: Int32 = sizes.totalOutputs + val weightOffsetBytesVal: Int32 = weightOffsetBytes + val bytesPerRowVal: Int32 = bytesPerRow + + val outputIdx = workgroupId * NUM_ROWS + warpInWorkgroup + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + val partialSum = GIO.foldRepeat[Float32](numQBlocks, 0.0f): (qBlock, acc) => + val blockByteBase = weightOffsetBytesVal + outIdx * bytesPerRowVal + qBlock * BLOCK_BYTES + val colBase = qBlock * QK_K + + val dWord = GIO.read[UInt32](layout.weight, (blockByteBase + 208) / 4) + val dBytePos = (blockByteBase + 208).mod(4) + val mask16: UInt32 = 0xFFFF + val dHalf = (dWord >> (dBytePos * 8).unsigned) & mask16 + val d = fp16ToFp32(dHalf) + + val iter0 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId) + val iter1 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 32) + val iter2 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 64) + val iter3 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 96) + val iter4 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 128) + val iter5 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 160) + val iter6 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 192) + val iter7 = processQ6KElementGeneric(layout.weight, layout.input, blockByteBase, colBase, batch, inFeaturesVal, d, laneId + 224) + + for _ <- GIO.barrier yield + acc + iter0 + iter1 + iter2 + iter3 + iter4 + iter5 + iter6 + iter7 + + // IMPORTANT: subgroupAdd must be called by ALL lanes (it's a collective operation) 
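+ // Every lane therefore reaches the subgroupAdd below even when its outputIdx is out of range;
+ // only the final buffer write is guarded by outputIdx < totalOutputs.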
+ partialSum.flatMap: ps => + val totalSum: Float32 = GIO.subgroupAdd(ps) + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float32](layout.output, outputIdx, totalSum) + + private def readByteGeneric(weight: GBuffer[UInt32], byteOffset: Int32): UInt32 = + val uint32Idx = byteOffset / 4 + val byteInUint32 = byteOffset.mod(4) + val word = GIO.read[UInt32](weight, uint32Idx) + val byteMask: UInt32 = 0xFF + (word >> (byteInUint32 * 8).unsigned) & byteMask + + private def readSignedByteGeneric(weight: GBuffer[UInt32], byteOffset: Int32): Int32 = + val b = readByteGeneric(weight, byteOffset) + val mask127: UInt32 = 127 + when(b > mask127)(b.signed - 256).otherwise(b.signed) + + /** Process a single Q6_K element (matches llama.cpp dequant_q6_k.comp). */ + private def processQ6KElementGeneric( + weight: GBuffer[UInt32], + input: GBuffer[Float32], + blockByteBase: Int32, + colBase: Int32, + batch: Int32, + inFeaturesVal: Int32, + d: Float32, + localIdx: Int32, + ): Float32 = + val valid = localIdx < QK_K + + // Determine which half (ip) and position within the 128-value range + val ip = localIdx / 128 // 0 or 1 + val posIn128 = localIdx.mod(128) // 0-127 + + // Each 128-value half is organized as 4 groups of 32 + val group = posIn128 / 32 // 0, 1, 2, or 3 + val il = posIn128.mod(32) // 0-31 (position within group) + + // Compute ql_idx and qh_idx + val ql_idx_base = ip * 64 + il + val qh_idx = ip * 32 + il + + // Read the qh byte + val qhByteOffset = blockByteBase + 128 + qh_idx + val qhByte = readByteGeneric(weight, qhByteOffset) + + // Determine which ql byte and nibble based on group + val ql_idx = when(group === 0 || group === 2)(ql_idx_base).otherwise(ql_idx_base + 32) + val useHighNibble = group >= 2 + val qhShift = group * 2 + + val qlByteOffset = blockByteBase + ql_idx + val qlByte = readByteGeneric(weight, qlByteOffset) + val nibbleMask: UInt32 = 0x0F + val ql = when(useHighNibble)((qlByte >> 4.unsigned) & nibbleMask).otherwise(qlByte & nibbleMask) + + val qhMask: UInt32 = 0x03 + val qh = (qhByte >> qhShift.unsigned) & qhMask + + // Combine to get 6-bit value and subtract 32 for signed + val q6 = ql | (qh << 4.unsigned) + val qSigned = q6.signed - 32 + + // Compute scale index + val is_base = ip * 8 + il / 16 + val scaleOffset = group * 2 + val scaleIdx = is_base + scaleOffset + val scaleByteOffset = blockByteBase + 192 + scaleIdx + val scale = readSignedByteGeneric(weight, scaleByteOffset) + + val dequantized = d * scale.asFloat * qSigned.asFloat + + val col = colBase + localIdx + val inputVal = GIO.read[Float32](input, batch * inFeaturesVal + col) + + when(valid)(dequantized * inputVal).otherwise(0.0f) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RMSNormProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RMSNormProgram.scala new file mode 100644 index 00000000..9ecc0bbf --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RMSNormProgram.scala @@ -0,0 +1,88 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct.Empty + +/** RMSNorm (Root Mean Square Normalization) using subgroup reductions. 
+ * + * RMSNorm is simpler than LayerNorm - it doesn't subtract mean: + * output = x * rsqrt(mean(x^2) + eps) * weight + * + * Uses hardware subgroup operations for efficient parallel reduction. + * One subgroup (32 threads) processes one row in strided fashion. + * + * Supports layered weights via weightOffset (defaults to 0 for standalone use). + */ +object RMSNormProgram: + + val WARP_SIZE = 32 + + case class Sizes( + numRows: Int, + rowSize: Int, + eps: Float, + weightOffset: Int = 0, + totalWeightSize: Int = -1, // -1 means use rowSize (single layer) + ): + def numIterations: Int = (rowSize + WARP_SIZE - 1) / WARP_SIZE + def actualWeightSize: Int = if totalWeightSize < 0 then rowSize else totalWeightSize + + case class ProgramLayout( + input: GBuffer[Float32], + weight: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val rowSize = sizes.rowSize + val eps = sizes.eps + val numIterations = sizes.numIterations + val weightOffset = sizes.weightOffset + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float32](s.numRows * s.rowSize), + weight = GBuffer[Float32](s.actualWeightSize), + output = GBuffer[Float32](s.numRows * s.rowSize), + ), + dispatch = (_, s) => + val warpsPerWorkgroup = 8 + val numWorkgroups = (s.numRows + warpsPerWorkgroup - 1) / warpsPerWorkgroup + StaticDispatch((numWorkgroups, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val globalId = GIO.invocationId + val laneId = globalId.mod(WARP_SIZE) + val subgroupIdx = globalId / WARP_SIZE + val rowSizeVal: Int32 = rowSize + val epsVal: Float32 = eps + val weightOffsetVal: Int32 = weightOffset + val numRowsVal: Int32 = sizes.numRows + + GIO.when(subgroupIdx < numRowsVal): + val rowIdx = subgroupIdx + val baseIdx = rowIdx * rowSizeVal + + val localSumSq = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numIterations) + .fold(0.0f, (sum: Float32, i: Int32) => + when(i < rowSizeVal): + val x = GIO.read[Float32](layout.input, baseIdx + i) + sum + x * x + .otherwise(sum) + ) + + val totalSumSq = GIO.subgroupAdd(localSumSq) + val meanSq = totalSumSq / rowSizeVal.asFloat + val scale: Float32 = 1.0f / sqrt(meanSq + epsVal) + + GIO.repeat(numIterations): j => + val i = laneId + j * WARP_SIZE + GIO.when(i < rowSizeVal): + val x = GIO.read[Float32](layout.input, baseIdx + i) + val w = GIO.read[Float32](layout.weight, weightOffsetVal + i) + GIO.write[Float32](layout.output, baseIdx + i, x * scale * w) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/ResidualAddProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/ResidualAddProgram.scala new file mode 100644 index 00000000..1bee0451 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/ResidualAddProgram.scala @@ -0,0 +1,67 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} + +/** Residual add program: output = a + b */ +object ResidualAddProgram: + val BLOCK_SIZE = 512 + + case class Sizes(size: Int) + + case class ProgramLayout( + a: GBuffer[Float32], + b: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val size = sizes.size + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + a 
= GBuffer[Float32](s.size), + b = GBuffer[Float32](s.size), + output = GBuffer[Float32](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val sizeVal: Int32 = size + + GIO.when(idx < sizeVal): + val aVal = GIO.read[Float32](layout.a, idx) + val bVal = GIO.read[Float32](layout.b, idx) + GIO.write[Float32](layout.output, idx, aVal + bVal) + +/** Copy program: output = input */ +object CopyProgram: + val BLOCK_SIZE = 512 + + case class Sizes(size: Int) + + case class ProgramLayout( + input: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val size = sizes.size + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float32](s.size), + output = GBuffer[Float32](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val sizeVal: Int32 = size + + GIO.when(idx < sizeVal): + val v = GIO.read[Float32](layout.input, idx) + GIO.write[Float32](layout.output, idx, v) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RoPEProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RoPEProgram.scala new file mode 100644 index 00000000..6e8be6a8 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/RoPEProgram.scala @@ -0,0 +1,94 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Rotary Position Embeddings (RoPE). + * + * RoPE applies a rotation to pairs of elements based on their position: + * x_rot[2i] = x[2i] * cos(pos * theta^(-2i/d)) - x[2i+1] * sin(pos * theta^(-2i/d)) + * x_rot[2i+1] = x[2i] * sin(pos * theta^(-2i/d)) + x[2i+1] * cos(pos * theta^(-2i/d)) + * + * Where theta is typically 10000 (or 500000 for Llama 3). 
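+ *
+ * For example, with headSize = 64 and theta = 10000 (the values used in the shader-dump tests),
+ * the first pair (i = 0) rotates by pos * 10000^0 = pos radians, while the pair at i = 1
+ * rotates by pos * 10000^(-2/64) ≈ 0.75 * pos, so higher pair indices rotate more slowly.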
+ */ +object RoPEProgram: + val BLOCK_SIZE = 256 + + case class Sizes( + B: Int, + T: Int, + numHeads: Int, + headSize: Int, + theta: Float, + startPos: Int = 0, // Compile-time default, overridden by uniform at runtime + ): + def totalElements: Int = B * T * numHeads * headSize + def totalPairs: Int = B * T * numHeads * (headSize / 2) + + case class ProgramLayout( + input: GBuffer[Float32], + output: GBuffer[Float32], + params: GUniform[AttentionParams], // Runtime startPos + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val numHeads = sizes.numHeads + val headSize = sizes.headSize + val theta = sizes.theta + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float32](s.totalElements), + output = GBuffer[Float32](s.totalElements), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch(((s.totalPairs + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val totalPairsVal: Int32 = B * T * numHeads * (headSize / 2) + val Tval: Int32 = T + val numHeadsVal: Int32 = numHeads + val halfHead: Int32 = headSize / 2 + val thetaVal: Float32 = theta + // Read startPos from runtime uniform + val startPosVal: Int32 = layout.params.read.startPos + + GIO.when(idx < totalPairsVal): + val perHead = halfHead + val perPos = numHeadsVal * halfHead + val perBatch = Tval * perPos + + val b = idx / perBatch + val rem1 = idx.mod(perBatch) + val t = rem1 / perPos + val rem2 = rem1.mod(perPos) + val h = rem2 / perHead + val d = rem2.mod(perHead) + + val pos = startPosVal + t + val headSizeFloat: Float32 = headSize.toFloat + val freqExponent: Float32 = -2.0f * d.asFloat / headSizeFloat + val freq: Float32 = pos.asFloat * pow(thetaVal, freqExponent) + val cosFreq = cos(freq) + val sinFreq = sin(freq) + + val fullIdx: Int32 = b * Tval * numHeadsVal * headSize + t * numHeadsVal * headSize + h * headSize + val idx0 = fullIdx + d * 2 + val idx1 = idx0 + 1 + + val x0 = GIO.read[Float32](layout.input, idx0) + val x1 = GIO.read[Float32](layout.input, idx1) + val y0 = x0 * cosFreq - x1 * sinFreq + val y1 = x0 * sinFreq + x1 * cosFreq + + for + _ <- GIO.write[Float32](layout.output, idx0, y0) + _ <- GIO.write[Float32](layout.output, idx1, y1) + yield Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/SwiGLUProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/SwiGLUProgram.scala new file mode 100644 index 00000000..c81f0f77 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/SwiGLUProgram.scala @@ -0,0 +1,77 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} + +/** SiLU (Sigmoid Linear Unit) activation function. + * + * SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)) + * + * Also known as Swish. Used in Llama's MLP layers. 
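+ *
+ * For example, SiLU(0) = 0, SiLU(1) = 1 / (1 + e^-1) ≈ 0.731 and SiLU(-1) = -1 / (1 + e) ≈ -0.269,
+ * so unlike ReLU the activation is smooth and slightly negative for small negative inputs.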
+ */ +object SiLUProgram: + val BLOCK_SIZE = 512 + + case class Sizes(size: Int) + + case class ProgramLayout( + input: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val size = sizes.size + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float32](s.size), + output = GBuffer[Float32](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val sizeVal: Int32 = size + + GIO.when(idx < sizeVal): + val x = GIO.read[Float32](layout.input, idx) + val result = x / (1.0f + exp(-x)) + GIO.write[Float32](layout.output, idx, result) + +/** SwiGLU activation used in Llama MLP. + * + * SwiGLU(gate, up) = SiLU(gate) * up + */ +object SwiGLUProgram: + val BLOCK_SIZE = 512 + + case class Sizes(size: Int) + + case class ProgramLayout( + gate: GBuffer[Float32], + up: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val size = sizes.size + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + gate = GBuffer[Float32](s.size), + up = GBuffer[Float32](s.size), + output = GBuffer[Float32](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val sizeVal: Int32 = size + + GIO.when(idx < sizeVal): + val g = GIO.read[Float32](layout.gate, idx) + val u = GIO.read[Float32](layout.up, idx) + val silu_g = g / (1.0f + exp(-g)) + GIO.write[Float32](layout.output, idx, silu_g * u) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/TiledMatmulVecProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/TiledMatmulVecProgram.scala new file mode 100644 index 00000000..a0723567 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/TiledMatmulVecProgram.scala @@ -0,0 +1,85 @@ +package io.computenode.cyfra.llama.programs.f32 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct.Empty + +/** Matrix-Vector Multiplication with subgroup reduction. + * + * One warp (32 threads) computes one output element using subgroup operations. + * Multiple warps per workgroup for better occupancy. + * + * Performance optimizations: + * 1. Subgroup reduction (hardware accelerated) + * 2. Multiple warps per workgroup (8 warps = 8 outputs per workgroup) + * 3. Strided access within each warp + * + * Supports layered weights via weightOffset (defaults to 0 for standalone use). 
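+ *
+ * Work mapping: with WARP_SIZE = 32 and WARPS_PER_WORKGROUP = 8, workgroup w computes output
+ * elements w*8 .. w*8+7, one per warp; lane l of a warp accumulates products at input columns
+ * l, l+32, l+64, ... and the partial sums are combined with a single subgroupAdd.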
+ */ +object TiledMatmulVecProgram: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP // 256 + + case class Sizes( + batchSize: Int, + inFeatures: Int, + outFeatures: Int, + weightOffset: Int = 0, + totalWeightSize: Int = -1, // -1 means use outFeatures * inFeatures (single layer) + ): + def totalOutputs: Int = batchSize * outFeatures + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numIterations: Int = (inFeatures + WARP_SIZE - 1) / WARP_SIZE + def actualWeightSize: Int = if totalWeightSize < 0 then outFeatures * inFeatures else totalWeightSize + + case class ProgramLayout( + weight: GBuffer[Float32], + input: GBuffer[Float32], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val outFeatures = sizes.outFeatures + val weightOffset = sizes.weightOffset + val numIterations = sizes.numIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[Float32](s.actualWeightSize), + input = GBuffer[Float32](s.batchSize * s.inFeatures), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + val inFeaturesVal: Int32 = inFeatures + val outFeaturesVal: Int32 = outFeatures + val weightOffsetVal: Int32 = weightOffset + val totalOutputsVal: Int32 = sizes.totalOutputs + + val outputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + val batch = outputIdx / outFeaturesVal + val outIdx = outputIdx.mod(outFeaturesVal) + + val localSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numIterations) + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesVal): + val w = GIO.read[Float32](layout.weight, weightOffsetVal + outIdx * inFeaturesVal + k) + val x = GIO.read[Float32](layout.input, batch * inFeaturesVal + k) + sum + w * x + .otherwise(sum) + ) + + val totalSum = GIO.subgroupAdd(localSum) + GIO.when(outputIdx < totalOutputsVal): + GIO.write[Float32](layout.output, outputIdx, totalSum) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/package.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/package.scala new file mode 100644 index 00000000..fc37404c --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f32/package.scala @@ -0,0 +1,21 @@ +package io.computenode.cyfra.llama.programs + +/** F32 GPU programs for Llama inference. 
+ * + * Programs for transformer operations using Float32 precision: + * - EmbeddingProgram: Token embedding lookup + * - RMSNormProgram: Root mean square normalization + * - RoPEProgram: Rotary position embeddings + * - TiledMatmulVecProgram: Matrix-vector multiplication with subgroup reduction + * - ResidualAddProgram/CopyProgram: Residual connections and data copying + * - SwiGLUProgram: SwiGLU activation + * - Q4KMatmulVecProgram/Q6KMatmulVecProgram: Quantized matmul + * - KVCacheWriteK/KVCacheWriteV: KV cache write programs + * - KVCachedAttention: Cached attention computation + */ +package object f32: + /** Alias for layered Q4K matmul program */ + val Q4KMatmulLayered = Q4KMatmulVecProgram.Layered + + /** Alias for layered Q6K matmul program */ + val Q6KMatmulLayered = Q6KMatmulVecProgram.Layered diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala new file mode 100644 index 00000000..da5e9c56 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.struct.GStruct + +package object programs: + + /** Runtime parameters for attention and RoPE operations. + * + * Passed via uniform to support single compiled pipeline with runtime-varying positions. + * Used by RoPE (for position encoding) and attention (for KV cache operations). + */ + case class AttentionParams( + seqLen: Int32, // actual sequence length (startPos + T) + startPos: Int32, // position of first query token in full sequence + ) extends GStruct[AttentionParams] diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala new file mode 100644 index 00000000..3f6fbab6 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala @@ -0,0 +1,121 @@ +package io.computenode.cyfra.llama.tokenizer + +import io.computenode.cyfra.llama.gguf.GGUFReader.GGUFFile +import scala.collection.mutable + +/** Simple BPE tokenizer for Llama models. + * + * Reads vocabulary from GGUF metadata. + * Supports encoding (text -> tokens) and decoding (tokens -> text). + */ +class LlamaTokenizer(gguf: GGUFFile): + + // Special tokens + val bosToken: Int = gguf.getInt("tokenizer.ggml.bos_token_id").getOrElse(1) + val eosToken: Int = gguf.getInt("tokenizer.ggml.eos_token_id").getOrElse(2) + val padToken: Int = gguf.getInt("tokenizer.ggml.padding_token_id").getOrElse(0) + + // Load vocabulary from GGUF + private val vocab: Array[String] = gguf.getStringArray("tokenizer.ggml.tokens").getOrElse(Array.empty) + private val scores: Array[Float] = gguf.getFloatArray("tokenizer.ggml.scores").getOrElse(Array.empty) + + // Build reverse lookup for encoding + private val tokenToId: Map[String, Int] = vocab.zipWithIndex.toMap + + /** Number of tokens in vocabulary. */ + def vocabSize: Int = vocab.length + + // GPT-style special characters used by Llama 3 + private val GPT_SPACE = '\u0120' // Ġ - space marker + private val GPT_NEWLINE = '\u010A' // Ċ - newline marker + private val GPT_TAB = '\u0109' // ĉ - tab marker + + /** Decode a single token to string. + * Handles special byte tokens like <0xNN> and GPT-style markers. 
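+ * For example, the byte token "<0x41>" decodes to "A", while "Ġhello" and "▁hello"
+ * both decode to " hello".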
+ */ + def decodeToken(tokenId: Int): String = + if tokenId < 0 || tokenId >= vocab.length then + s"" + else + val token = vocab(tokenId) + // Handle byte tokens like <0xNN> + if token.startsWith("<0x") && token.endsWith(">") then + try + val byteVal = Integer.parseInt(token.drop(3).dropRight(1), 16) + new String(Array(byteVal.toByte), "UTF-8") + catch + case _: Exception => token + else + // Replace GPT-style markers with actual characters + token + .replace(GPT_SPACE.toString, " ") + .replace(GPT_NEWLINE.toString, "\n") + .replace(GPT_TAB.toString, "\t") + .replace("▁", " ") // Sentencepiece space marker + + /** Decode a sequence of tokens to string. */ + def decode(tokens: Array[Int]): String = + tokens.map(decodeToken).mkString + + // Detect which space marker the vocabulary uses (▁ for Llama 1/2, Ġ for Llama 3) + private val spaceMarker: String = + if tokenToId.contains("Ġ") || vocab.exists(_.startsWith("Ġ")) then "Ġ" + else "▁" + + /** Encode text to tokens using greedy longest-match BPE. + * + * This handles both SentencePiece (▁) and GPT (Ġ) space marker conventions: + * - Space marker represents a space before the token + * - First token of a word has space marker prefix + * + * Note: This is a simplified implementation. For production, + * use the official sentencepiece tokenizer. + */ + def encode(text: String, addBos: Boolean = true): Array[Int] = + val tokens = mutable.ArrayBuffer[Int]() + + if addBos then + tokens += bosToken + + // Replace spaces with detected space marker and prepend for start + val normalized = spaceMarker + text.replace(" ", spaceMarker) + + // Greedy longest-match tokenization + var pos = 0 + while pos < normalized.length do + var found = false + var maxLen = math.min(normalized.length - pos, 64) // Max token length + + // Try to find longest matching token + while maxLen > 0 && !found do + val candidate = normalized.substring(pos, pos + maxLen) + if tokenToId.contains(candidate) then + tokens += tokenToId(candidate) + pos += maxLen + found = true + else + maxLen -= 1 + + if !found then + // Single character fallback + val char = normalized.charAt(pos) + val charStr = char.toString + if tokenToId.contains(charStr) then + tokens += tokenToId(charStr) + else + // Unknown character - try byte fallback for UTF-8 bytes + val bytes = charStr.getBytes("UTF-8") + for b <- bytes do + val byteToken = f"<0x${b & 0xFF}%02X>" + tokenToId.get(byteToken).foreach(tokens += _) + pos += 1 + + tokens.toArray + + /** Get token string by ID (for debugging). */ + def getToken(tokenId: Int): String = + if tokenId >= 0 && tokenId < vocab.length then vocab(tokenId) + else s"" + +object LlamaTokenizer: + def apply(gguf: GGUFFile): LlamaTokenizer = new LlamaTokenizer(gguf) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala new file mode 100644 index 00000000..2e90408d --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala @@ -0,0 +1,12 @@ +package io.computenode.cyfra.llama.util + +import org.slf4j.LoggerFactory + +/** Logger for the Llama module using SLF4J. 
*/ +object Logger: + private val logger = LoggerFactory.getLogger("io.computenode.cyfra.llama") + + def info(msg: => String): Unit = if logger.isInfoEnabled then logger.info(msg) + def debug(msg: => String): Unit = if logger.isDebugEnabled then logger.debug(msg) + def warn(msg: => String): Unit = if logger.isWarnEnabled then logger.warn(msg) + def error(msg: => String): Unit = if logger.isErrorEnabled then logger.error(msg) diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala new file mode 100644 index 00000000..0dc08765 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala @@ -0,0 +1,278 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.core.GBufferRegion +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.llama.gguf.{Dequantize => CpuDequantize} +import io.computenode.cyfra.llama.programs.f32.* +import io.computenode.cyfra.runtime.VkCyfraRuntime +import munit.FunSuite + +import java.nio.{ByteBuffer, ByteOrder} + +/** Tests to verify Q4_K and Q6_K dequantization matches CPU implementation. */ +class DequantizationTest extends FunSuite: + + def allocateBuffer(floats: Int): ByteBuffer = + ByteBuffer.allocateDirect(floats * 4).order(ByteOrder.nativeOrder()) + + def copyToBuffer(arr: Array[Float], buf: ByteBuffer): Unit = + buf.clear(); buf.asFloatBuffer().put(arr); buf.rewind() + + def copyFromBuffer(buf: ByteBuffer, arr: Array[Float]): Unit = + buf.rewind(); buf.asFloatBuffer().get(arr) + + /** Create a test Q4_K block with known values. + * + * Q4_K block layout (144 bytes): + * - bytes 0-1: d (fp16) + * - bytes 2-3: dmin (fp16) + * - bytes 4-15: scales (12 bytes) + * - bytes 16-143: qs (128 bytes, 4-bit quantized values) + */ + def createQ4KBlock(d: Float, dmin: Float, scales: Array[Byte], qs: Array[Byte]): Array[Byte] = + require(scales.length == 12, "scales must be 12 bytes") + require(qs.length == 128, "qs must be 128 bytes") + val block = new Array[Byte](144) + val buf = ByteBuffer.wrap(block).order(ByteOrder.LITTLE_ENDIAN) + + // Write d and dmin as fp16 + buf.putShort(0, floatToFp16(d)) + buf.putShort(2, floatToFp16(dmin)) + + // Write scales + System.arraycopy(scales, 0, block, 4, 12) + + // Write qs + System.arraycopy(qs, 0, block, 16, 128) + + block + + /** Create a test Q6_K block with known values. + * + * Q6_K block layout (210 bytes): + * - bytes 0-127: ql (low 4 bits) + * - bytes 128-191: qh (high 2 bits) + * - bytes 192-207: scales (int8) + * - bytes 208-209: d (fp16) + */ + def createQ6KBlock(d: Float, scales: Array[Byte], ql: Array[Byte], qh: Array[Byte]): Array[Byte] = + require(scales.length == 16, "scales must be 16 bytes") + require(ql.length == 128, "ql must be 128 bytes") + require(qh.length == 64, "qh must be 64 bytes") + val block = new Array[Byte](210) + val buf = ByteBuffer.wrap(block).order(ByteOrder.LITTLE_ENDIAN) + + // Write ql + System.arraycopy(ql, 0, block, 0, 128) + + // Write qh + System.arraycopy(qh, 0, block, 128, 64) + + // Write scales + System.arraycopy(scales, 0, block, 192, 16) + + // Write d as fp16 + buf.putShort(208, floatToFp16(d)) + + block + + /** Convert float32 to fp16 (approximate, for test purposes). 
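+ * For example, 1.0f maps to 0x3C00 (sign 0, exponent 15, mantissa 0) and -2.0f maps to 0xC000;
+ * values whose exponent falls outside the fp16 range are flushed to zero or saturated to infinity.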
*/ + def floatToFp16(f: Float): Short = + val bits = java.lang.Float.floatToIntBits(f) + val sign = (bits >> 31) & 1 + val exp = (bits >> 23) & 0xFF + val mant = bits & 0x7FFFFF + + if exp == 0 then + // Zero or denormalized + (sign << 15).toShort + else if exp == 255 then + // Infinity or NaN + ((sign << 15) | 0x7C00).toShort + else + val newExp = exp - 127 + 15 + if newExp <= 0 then + // Underflow to zero + (sign << 15).toShort + else if newExp >= 31 then + // Overflow to infinity + ((sign << 15) | 0x7C00).toShort + else + val newMant = mant >> 13 + ((sign << 15) | (newExp << 10) | newMant).toShort + + test("Q4_K dequantization: GPU matches CPU for all 8 scale indices"): + VkCyfraRuntime.using: + val K = 256 // One Q4_K block + val N = 1 // One output row + + // Create Q4_K block with distinctive scale patterns + // Scales are 6-bit packed: for j < 4, scale is scales[j] & 0x3F + // For j >= 4, it's a combination of bits + val scales = Array.fill[Byte](12)(0) + + // Set scales for j=0..3 (simple 6-bit in lower bytes) + scales(0) = 1 // scale 0 = 1 + scales(1) = 2 // scale 1 = 2 + scales(2) = 3 // scale 2 = 3 + scales(3) = 4 // scale 3 = 4 + + // Set mins for j=0..3 (in bytes 4-7) + scales(4) = 0 // min 0 = 0 + scales(5) = 0 // min 1 = 0 + scales(6) = 0 // min 2 = 0 + scales(7) = 0 // min 3 = 0 + + // Scales for j=4..7 are combinations - set them to known patterns + // For j=4: sc = (scales[8] & 0x0F) | ((scales[0] >> 6) << 4) + // Let's set scales[8] = 0x05 and ensure scales[0] high bits are 0 + scales(8) = 5 // scale 4 = 5 + scales(9) = 6 // scale 5 = 6 + scales(10) = 7 // scale 6 = 7 + scales(11) = 8 // scale 7 = 8 + + // Create qs with distinctive pattern: q values 0-15 for each position + val qs = Array.tabulate[Byte](128)(i => ((i % 16) | ((i % 16) << 4)).toByte) + + val d = 1.0f + val dmin = 0.0f + val block = createQ4KBlock(d, dmin, scales, qs) + + // CPU dequantization + val cpuResult = CpuDequantize.dequantizeQ4K(block, 256) + + // GPU computation + val numUint32 = 144 / 4 + val weightBuf = ByteBuffer.allocateDirect(numUint32 * 4).order(ByteOrder.nativeOrder()) + weightBuf.put(block); weightBuf.rewind() + + val input = Array.fill(K)(1.0f) + val inputBuf = allocateBuffer(K); copyToBuffer(input, inputBuf) + val outputBuf = allocateBuffer(N) + + val sizes = Q4KMatmulVecProgram.Sizes(1, K, N) + val program = Q4KMatmulVecProgram.forward(sizes) + + val region = GBufferRegion + .allocate[Q4KMatmulVecProgram.ProgramLayout] + .map(layout => program.execute(sizes, layout)) + + val gpuResult = new Array[Float](N) + region.runUnsafe( + init = Q4KMatmulVecProgram.ProgramLayout( + weight = GBuffer[UInt32](weightBuf), + input = GBuffer[Float32](inputBuf), + output = GBuffer[Float32](outputBuf), + ), + onDone = layout => + layout.output.read(outputBuf) + copyFromBuffer(outputBuf, gpuResult), + ) + + // The GPU result should be the dot product of dequantized weights with all-ones input + // which equals the sum of all dequantized values + val cpuSum = cpuResult.sum + val gpuSum = gpuResult(0) + + println(s"Q4_K CPU sum: $cpuSum") + println(s"Q4_K GPU sum: $gpuSum") + println(s"First 32 CPU values: ${cpuResult.take(32).mkString(", ")}") + + // Allow small floating-point tolerance + assertEqualsDouble(gpuSum.toDouble, cpuSum.toDouble, 1.0, + s"Q4_K GPU sum ($gpuSum) should match CPU sum ($cpuSum)") + + test("Q6_K dequantization: GPU matches CPU"): + VkCyfraRuntime.using: + val K = 256 // One Q6_K block + val N = 1 // One output row + + // Create Q6_K block with known values + val d = 1.0f + val 
scales = Array.tabulate[Byte](16)(i => (i + 1).toByte) // scales 1-16 + + // ql: low 4 bits of each 6-bit value + // qh: high 2 bits of each 6-bit value + // For simplicity, set all q values to a known pattern + val ql = Array.fill[Byte](128)(0x55.toByte) // alternating 0101 pattern + val qh = Array.fill[Byte](64)(0x00.toByte) // high bits all 0 + + val block = createQ6KBlock(d, scales, ql, qh) + + // CPU dequantization + val cpuResult = CpuDequantize.dequantizeQ6K(block, 256) + + // GPU computation + val numBytes = 210 + val numUint32 = (numBytes + 3) / 4 + val weightBuf = ByteBuffer.allocateDirect(numUint32 * 4).order(ByteOrder.nativeOrder()) + weightBuf.put(block); weightBuf.rewind() + + val input = Array.fill(K)(1.0f) + val inputBuf = allocateBuffer(K); copyToBuffer(input, inputBuf) + val outputBuf = allocateBuffer(N) + + val sizes = Q6KMatmulVecProgram.Sizes(1, K, N) + val program = Q6KMatmulVecProgram.forward(sizes) + + val region = GBufferRegion + .allocate[Q6KMatmulVecProgram.ProgramLayout] + .map(layout => program.execute(sizes, layout)) + + val gpuResult = new Array[Float](N) + region.runUnsafe( + init = Q6KMatmulVecProgram.ProgramLayout( + weight = GBuffer[UInt32](weightBuf), + input = GBuffer[Float32](inputBuf), + output = GBuffer[Float32](outputBuf), + ), + onDone = layout => + layout.output.read(outputBuf) + copyFromBuffer(outputBuf, gpuResult), + ) + + val cpuSum = cpuResult.sum + val gpuSum = gpuResult(0) + + println(s"Q6_K CPU sum: $cpuSum") + println(s"Q6_K GPU sum: $gpuSum") + println(s"First 32 CPU values: ${cpuResult.take(32).mkString(", ")}") + + assertEqualsDouble(gpuSum.toDouble, cpuSum.toDouble, 1.0, + s"Q6_K GPU sum ($gpuSum) should match CPU sum ($cpuSum)") + + test("Q4_K scale extraction: verify is < 4 vs is >= 4 logic"): + // This test verifies the scale extraction matches llama.cpp's get_scale_min_k4 + // For is < 4: sc = scales[is] & 0x3F, m = scales[is+4] & 0x3F + // For is >= 4: sc = (scales[is+4] & 0x0F) | ((scales[is-4] >> 6) << 4) + // m = ((scales[is+4] >> 4) & 0x0F) | ((scales[is] >> 6) << 4) + + val scales = Array[Byte]( + 0x3F, 0x3E, 0x3D, 0x3C, // scales[0-3]: values 63, 62, 61, 60 + 0x10, 0x20, 0x30, 0x40.toByte, // scales[4-7]: mins for j<4 + 0x05, 0x06, 0x07, 0x08 // scales[8-11]: for j>=4 extraction + ) + + // Verify CPU extraction for j=0 + val (sc0, m0) = Dequantize.getScaleMinK4(0, scales) + assertEquals(sc0.toInt, 63, "scale for j=0 should be 63") + assertEquals(m0.toInt, 16, "min for j=0 should be 16") + + // Verify CPU extraction for j=4 + // sc4 = (scales[8] & 0x0F) | ((scales[0] >> 6) << 4) + // scales[8] = 0x05, scales[0] = 0x3F -> 0x3F >> 6 = 0 + // sc4 = (5 & 0x0F) | (0 << 4) = 5 + val (sc4, m4) = Dequantize.getScaleMinK4(4, scales) + assertEquals(sc4.toInt, 5, "scale for j=4 should be 5") + + // Helper method to expose getScaleMinK4 for testing + object Dequantize: + def getScaleMinK4(j: Int, scales: Array[Byte]): (Float, Float) = + if j < 4 then + val d = (scales(j) & 0x3F).toFloat + val m = (scales(j + 4) & 0x3F).toFloat + (d, m) + else + val d = ((scales(j + 4) & 0x0F) | ((scales(j - 4) >> 6) << 4)).toFloat + val m = ((scales(j + 4) >> 4) & 0x0F | ((scales(j) >> 6) << 4)).toFloat + (d, m) diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala new file mode 100644 index 00000000..28e13bdd --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala @@ -0,0 +1,103 @@ +package 
io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.gguf.GGUFReader +import io.computenode.cyfra.llama.inference.LlamaInference +import io.computenode.cyfra.llama.tokenizer.LlamaTokenizer +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.pipeline.LlamaF32Pipeline +import io.computenode.cyfra.runtime.VkCyfraRuntime +import munit.FunSuite + +import java.nio.file.{Files, Paths} +import scala.concurrent.duration.* + +/** Direct benchmark to verify which code path is actually running. */ +class DirectBenchmarkTest extends FunSuite: + + val modelPath = "cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + + override def munitTimeout: Duration = 10.minutes + + test("Direct KVCachedPipeline.generate benchmark"): + assume(Files.exists(Paths.get(modelPath)), s"Model not found: $modelPath") + + VkCyfraRuntime.using: + println("Loading model...") + val model = LlamaModel.fromGGUF(Paths.get(modelPath)) + + try + println("Creating inference...") + val inference = new LlamaInference(model, maxT = 2048, useQuantized = true) + val kvPipeline = inference.getF32KVCachedPipeline + + // Simple argmax sampling (greedy, deterministic) + def argmax(logits: Array[Float]): Int = + var maxIdx = 0 + var maxVal = logits(0) + var i = 1 + while i < logits.length do + if logits(i) > maxVal then + maxVal = logits(i) + maxIdx = i + i += 1 + maxIdx + + val tokenizer = LlamaTokenizer(model.gguf) + val promptText = "Once upon a time" + val promptTokens = tokenizer.encode(promptText) + + println("\n" + "=" * 60) + println(" TinyLlama 1.1B - KV Cache Benchmark (Cyfra GPU)") + println("=" * 60) + + // Warmup - 3 generations to ensure everything is compiled and cached + println("\nWarming up (3 generations)...") + for i <- 1 to 3 do + kvPipeline.generate(promptTokens, 20, argmax) + println(s" warmup $i done") + + // Benchmark with longer generation + val maxTokens = 128 + println(s"\n--- Benchmark: $maxTokens tokens ---") + println(s"Prompt: '$promptText'\n") + + // Timed generation with output + print("Output: ") + val start = System.nanoTime() + val generated = kvPipeline.generate( + promptTokens = promptTokens, + maxNewTokens = maxTokens, + sampleFn = argmax, + onToken = token => print(tokenizer.decodeToken(token)), + stopTokens = Set(2), + ) + val elapsed = (System.nanoTime() - start) / 1e6 + println("\n") + + val tokPerSec = generated.length * 1000.0 / elapsed + println(s"Generated: ${generated.length} tokens") + println(s"Time: ${elapsed.toInt} ms") + println(f"Throughput: $tokPerSec%.1f tok/s") + + // Multiple runs for consistency + println(s"\n--- Consistency check (5 runs x $maxTokens tokens) ---") + val times = (1 to 5).map: i => + val runStart = System.nanoTime() + val tokens = kvPipeline.generate(promptTokens, maxTokens, argmax) + val runElapsed = (System.nanoTime() - runStart) / 1e6 + val runTokPerSec = tokens.length * 1000.0 / runElapsed + println(f" Run $i: ${tokens.length} tokens in ${runElapsed.toInt}%5d ms = $runTokPerSec%.1f tok/s") + runElapsed + + val avgTime = times.sum / times.length + val avgTokPerSec = maxTokens * 1000.0 / avgTime + val minTime = times.min + val maxTokPerSec = maxTokens * 1000.0 / minTime + + println("\n" + "=" * 60) + println(f" Average: $avgTokPerSec%.1f tok/s") + println(f" Best: $maxTokPerSec%.1f tok/s") + println("=" * 60) + + finally + model.close() diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala new file mode 100644 
index 00000000..ef417b5a --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala @@ -0,0 +1,111 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.inference.LlamaInference +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.tokenizer.LlamaTokenizer +import io.computenode.cyfra.runtime.VkCyfraRuntime +import munit.FunSuite + +import java.nio.file.{Files, Paths} +import scala.concurrent.duration.* + +/** Tests for F16 KV Cache Pipeline with Vec4-optimized matmuls. + * + * This tests the F16KVCachedPipeline which provides O(1) per-token inference + * by maintaining an F16 KV cache on GPU. Uses Vec4 weight loads for 4x bandwidth. + */ +class F16KVCacheTest extends FunSuite: + + val modelPath = Paths.get("cyfra-llama/Llama-3.2-1B-Instruct-f16.gguf") + + override def munitTimeout: Duration = 15.minutes + + test("F16 KV Cache Pipeline - longer generation benchmark"): + assume(Files.exists(modelPath), s"Model not found: $modelPath") + + VkCyfraRuntime.using: + println("\n" + "=" * 70) + println(" F16 KV Cache Pipeline - Performance Benchmark") + println("=" * 70) + + val model = LlamaModel.fromGGUF(modelPath) + val tokenizer = LlamaTokenizer(model.gguf) + + try + val inference = new LlamaInference(model, maxT = 2048) + val f16KVPipeline = inference.getF16KVCachedPipeline + + // Temperature sampling for more varied output + val random = new scala.util.Random(42) + def sampleWithTemperature(logits: Array[Float], temperature: Float = 0.1f): Int = + val scaled = logits.map(_ / temperature) + val maxLogit = scaled.max + val exps = scaled.map(x => math.exp(x - maxLogit).toFloat) + val sumExps = exps.sum + val probs = exps.map(_ / sumExps) + val r = random.nextFloat() + var cumSum = 0.0f + var i = 0 + while i < probs.length && cumSum < r do + cumSum += probs(i) + i += 1 + math.max(0, i - 1) + + def sample(logits: Array[Float]): Int = sampleWithTemperature(logits, 0.2f) + + val promptText = "Here is a Python server that creates a new user in the database and the repository:" + val promptTokens = tokenizer.encode(promptText) + val maxTokens = 1000 + + println(s"\nPrompt: '$promptText'") + println(s"Generating $maxTokens tokens...") + + // Warmup + println("Warming up (1 generation)...") + f16KVPipeline.generate(promptTokens, 100, sample) + + + // Benchmark + println(s"\n--- Benchmark: $maxTokens tokens ---\n") + print("Output: " + promptText) + val wallStart = System.nanoTime() + val generated = f16KVPipeline.generate( + promptTokens = promptTokens, + maxNewTokens = maxTokens, + sampleFn = sample, + onToken = token => print(tokenizer.decodeToken(token)), + stopTokens = Set(tokenizer.eosToken, 128009), // EOS + end-of-turn + reportStats = true, // Print GPU execution timing + ) + val wallElapsed = (System.nanoTime() - wallStart) / 1e6 + println("\n") + + val wallTokPerSec = generated.length * 1000.0 / wallElapsed + println(f"Wall-clock time: ${wallElapsed.toInt} ms ($wallTokPerSec%.2f tok/s including sampling)") + + // Multiple runs - GPU-only timing + println(s"\n--- Consistency check (5 runs x $maxTokens tokens) ---") + println("(Reporting GPU-only decode throughput, excludes sampling)") + val gpuTimes = (1 to 5).map: i => + f16KVPipeline.generate( + promptTokens = promptTokens, + maxNewTokens = maxTokens, + sampleFn = sample, + stopTokens = Set(tokenizer.eosToken, 128009), + ) + val stats = f16KVPipeline.lastStats.get + println(f" Run $i: ${stats.generatedTokens} tokens, decode 
${stats.decodeTimeMs.toInt}%5d ms = ${stats.decodeTokPerSec}%.2f tok/s (prefill ${stats.prefillTimeMs.toInt} ms)") + stats.decodeTokPerSec + + val avgDecodeTokPerSec = gpuTimes.sum / gpuTimes.length + + println("\n" + "=" * 70) + println(f" Average GPU decode throughput: $avgDecodeTokPerSec%.2f tok/s") + println(" (This is pure GPU time, excludes CPU sampling overhead)") + println(" Memory usage: ~50% of F32 (F16 weights + F16 KV cache)") + println("=" * 70) + + finally + model.close() + +end F16KVCacheTest diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala new file mode 100644 index 00000000..456597a0 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala @@ -0,0 +1,90 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.gguf.GGUFReader +import munit.FunSuite + +import java.nio.file.{Files, Path, Paths} + +class GGUFTest extends FunSuite: + // Set this to the path of a GGUF model file for testing + val testModelPath: Path = Paths.get( + sys.env.getOrElse("LLAMA_MODEL_PATH", "models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf") + ) + + test("parse GGUF header and metadata"): + assume(Files.exists(testModelPath), s"Model file not found at $testModelPath") + + val gguf = GGUFReader.read(testModelPath) + try + println(s"GGUF Version: ${gguf.version}") + println(s"Tensors: ${gguf.tensors.size}") + println(s"Data offset: ${gguf.dataOffset}") + println() + + println("Metadata keys:") + gguf.metadata.keys.toSeq.sorted.foreach(k => println(s" $k")) + println() + + println("Architecture:") + gguf.getString("general.architecture").foreach(v => println(s" $v")) + + println("\nModel parameters:") + val arch = gguf.getString("general.architecture").getOrElse("llama") + gguf.getInt(s"$arch.embedding_length").foreach(v => println(s" embedding_length: $v")) + gguf.getInt(s"$arch.feed_forward_length").foreach(v => println(s" feed_forward_length: $v")) + gguf.getInt(s"$arch.attention.head_count").foreach(v => println(s" head_count: $v")) + gguf.getInt(s"$arch.attention.head_count_kv").foreach(v => println(s" head_count_kv: $v")) + gguf.getInt(s"$arch.block_count").foreach(v => println(s" block_count: $v")) + gguf.getInt(s"$arch.vocab_size").foreach(v => println(s" vocab_size: $v")) + gguf.getInt(s"$arch.context_length").foreach(v => println(s" context_length: $v")) + println() + + println("Tensors (first 30):") + gguf.tensors.take(30).foreach: t => + println(s" ${t.name}: shape=${t.shape.mkString("x")}, type=${t.quantType}, offset=${t.offset}") + + assert(gguf.tensors.nonEmpty) + finally + gguf.close() + + test("load LlamaModel from GGUF"): + assume(Files.exists(testModelPath), s"Model file not found at $testModelPath") + + val model = LlamaModel.fromGGUF(testModelPath) + try + model.logInfo() + + // Verify config was extracted correctly + println(s"\nExtracted config:") + println(s" hiddenSize: ${model.config.hiddenSize}") + println(s" intermediateSize: ${model.config.intermediateSize}") + println(s" numAttentionHeads: ${model.config.numAttentionHeads}") + println(s" numKeyValueHeads: ${model.config.numKeyValueHeads}") + println(s" numHiddenLayers: ${model.config.numHiddenLayers}") + println(s" vocabSize: ${model.config.vocabSize}") + println(s" maxPositionEmbeddings: ${model.config.maxPositionEmbeddings}") + println(s" headSize: ${model.config.headSize}") + println(s" gqaRatio: ${model.config.gqaRatio}") 
+ println(s" ropeTheta: ${model.config.ropeTheta}") + + // Verify expected tensor names exist + val expectedTensors = Seq( + LlamaModel.TensorNames.tokenEmbed, + LlamaModel.TensorNames.outputNorm, + LlamaModel.TensorNames.attnNorm(0), + LlamaModel.TensorNames.attnQ(0), + LlamaModel.TensorNames.ffnNorm(0), + LlamaModel.TensorNames.ffnGate(0), + ) + + println(s"\nChecking expected tensors:") + expectedTensors.foreach: name => + model.getTensor(name) match + case Some(t) => println(s" ✓ $name: ${t.shape.mkString("x")} (${t.quantType})") + case None => println(s" ✗ $name: NOT FOUND") + + assert(model.config.hiddenSize > 0) + assert(model.config.numHiddenLayers > 0) + finally + model.close() diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala new file mode 100644 index 00000000..dddea089 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala @@ -0,0 +1,217 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.core.{GProgram, GioProgram} +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.spirv.compilers.DSLCompiler +import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvTool, SpirvToolsRunner, SpirvValidator} +import io.computenode.cyfra.llama.programs.f32.* +import io.computenode.cyfra.llama.programs.f16.* +import munit.FunSuite + +import java.nio.ByteBuffer +import java.nio.file.{Files, Path, Paths} + +/** Dumps GPU program shaders to SPIR-V assembly and GLSL for inspection. + * + * This is useful for: + * - Comparing generated code with llama.cpp shaders + * - Verifying optimizations (Vec4 loads, unrolling, subgroup ops) + * - Debugging shader compilation issues + */ +class ShaderDumpTest extends FunSuite: + + val outputDir = Paths.get("cyfra-llama/output/shaders") + + override def beforeAll(): Unit = + Files.createDirectories(outputDir) + + // ============ F32 Programs ============ + + test("Dump F32 TiledMatmulVec shader"): + val program = TiledMatmulVecProgram.forward(TiledMatmulVecProgram.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048 + )) + dumpProgram("f32_tiled_matmul_vec", program) + + test("Dump F32 Q4KMatmulVec shader"): + val program = Q4KMatmulVecProgram.forward(Q4KMatmulVecProgram.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048 + )) + dumpProgram("f32_q4k_matmul_vec", program) + + test("Dump F32 Q6KMatmulVec shader"): + val program = Q6KMatmulVecProgram.forward(Q6KMatmulVecProgram.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048 + )) + dumpProgram("f32_q6k_matmul_vec", program) + + test("Dump F32 RMSNorm shader"): + val program = RMSNormProgram.forward(RMSNormProgram.Sizes( + numRows = 1, + rowSize = 2048, + eps = 1e-6f + )) + dumpProgram("f32_rmsnorm", program) + + test("Dump F32 RoPE shader"): + val program = RoPEProgram.forward(RoPEProgram.Sizes( + B = 1, + T = 2, + numHeads = 32, + headSize = 64, + theta = 10000f, + startPos = 0 + )) + dumpProgram("f32_rope", program) + + test("Dump F32 SwiGLU shader"): + val program = SwiGLUProgram.forward(SwiGLUProgram.Sizes(5632)) + dumpProgram("f32_swiglu", program) + + test("Dump F32 ResidualAdd shader"): + val program = ResidualAddProgram.forward(ResidualAddProgram.Sizes(2048)) + dumpProgram("f32_residual_add", program) + + test("Dump F32 Embedding shader"): + val program = EmbeddingProgram.forward(EmbeddingProgram.Sizes( + 
seqLen = 2, + hiddenSize = 2048, + vocabSize = 32000 + )) + dumpProgram("f32_embedding", program) + + test("Dump F32 Q4K Matmul Layered shader"): + val program = Q4KMatmulVecProgram.Layered.forward(Q4KMatmulVecProgram.Layered.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048, + weightOffsetUint32 = 0, + totalWeightUint32 = 2048 * 8 * 36 // 8 blocks per row + )) + dumpProgram("f32_q4k_matmul_vec_layered", program) + + test("Dump F32 Q6K Matmul Layered shader"): + val program = Q6KMatmulVecProgram.Layered.forward(Q6KMatmulVecProgram.Layered.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048, + weightOffsetBytes = 0, + totalWeightBytes = 2048 * 8 * 210 // 8 blocks per row + )) + dumpProgram("f32_q6k_matmul_vec_layered", program) + + test("Dump F32 KV Cached Attention shader"): + val program = KVCachedAttention.forward(KVCachedAttention.Sizes( + B = 1, + T = 1, + NH = 32, + NKV = 4, + headSize = 64, + startPos = 0, + kCacheLayerOffset = 0, + vCacheLayerOffset = 0, + L = 1, + maxSeqLen = 2048, + )) + dumpProgram("f32_kv_cached_attention", program) + + // ============ F16 Programs ============ + + test("Dump F16 Embedding shader"): + val program = F16EmbeddingProgram.forward(F16EmbeddingProgram.Sizes( + seqLen = 2, + hiddenSize = 2048, + vocabSize = 32000 + )) + dumpProgram("f16_embedding", program) + + test("Dump F16 RMSNorm shader"): + val program = F16RMSNormProgram.forward(F16RMSNormProgram.Sizes( + numRows = 1, + rowSize = 2048, + eps = 1e-6f + )) + dumpProgram("f16_rmsnorm", program) + + test("Dump F16 RoPE shader"): + val program = F16RoPEProgram.forward(F16RoPEProgram.Sizes( + B = 1, + T = 2, + numHeads = 32, + headSize = 64, + theta = 10000f + )) + dumpProgram("f16_rope", program) + + test("Dump F16 MatmulVec Hybrid shader (Vec4 weights)"): + val program = F16MatmulVecHybridProgram.forward(F16MatmulVecHybridProgram.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048 + )) + dumpProgram("f16_matmul_vec_hybrid", program) + + test("Dump F16 SwiGLU shader"): + val program = F16SwiGLUProgram.forward(F16SwiGLUProgram.Sizes(5632)) + dumpProgram("f16_swiglu", program) + + test("Dump F16 ResidualAdd shader"): + val program = F16ResidualAddProgram.forward(F16ResidualAddProgram.Sizes(2048)) + dumpProgram("f16_residual_add", program) + + test("Dump F16 KV Cached Attention shader"): + val program = F16KVCachedAttention.forward(F16KVCachedAttention.Sizes( + B = 1, + T = 1, + NH = 32, + NKV = 4, + headSize = 64, + startPos = 0, + kCacheLayerOffset = 0, + vCacheLayerOffset = 0, + L = 1, + maxSeqLen = 2048, + )) + dumpProgram("f16_kv_cached_attention", program) + + test("Dump F16 Output Vec4 shader"): + val program = F16OutputVec4Program.forward(F16OutputVec4Program.Sizes( + batchSize = 1, + hiddenSize = 2048, + vocabSize = 32000 + )) + dumpProgram("f16_output_vec4", program) + + private def dumpProgram[P, L: Layout](name: String, program: GProgram[P, L]): Unit = + program match + case gioProgram: GioProgram[P, L] => + val layout = summon[Layout[L]] + val bindings = layout.toBindings(layout.layoutRef).toList + val shaderCode = DSLCompiler.compile(gioProgram.body(layout.layoutRef), bindings, gioProgram.workgroupSize) + + // Create runner with file outputs + val runner = SpirvToolsRunner( + validator = SpirvValidator.Enable(throwOnFail = false), + disassembler = SpirvDisassembler.Enable( + throwOnFail = false, + toolOutput = SpirvTool.ToFile(outputDir.resolve(s"$name.spvasm"), hashSuffix = false) + ), + crossCompilation = SpirvCross.Enable( + throwOnFail = 
false, + toolOutput = SpirvTool.ToFile(outputDir.resolve(s"$name.glsl"), hashSuffix = false), + settings = Seq(SpirvTool.Param("--vulkan-semantics")) + ), + originalSpirvOutput = SpirvTool.ToFile(outputDir.resolve(s"$name.spv"), hashSuffix = false) + ) + + runner.processShaderCodeWithSpirvTools(shaderCode) + println(s"Dumped $name shaders to $outputDir") + case _ => + println(s"Cannot dump $name - not a GioProgram") diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index 050bae1a..0786be77 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -31,9 +31,9 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram)).asInstanceOf[VkShader[L]] private def compile[Params, L: Layout as l](program: GioProgram[Params, L]): SpirvProgram[Params, L] = - val GioProgram(_, layout, dispatch, _) = program + val GioProgram(_, layout, dispatch, workgroupSize) = program val bindings = l.toBindings(l.layoutRef).toList - val compiled = DSLCompiler.compile(program.body(l.layoutRef), bindings) + val compiled = DSLCompiler.compile(program.body(l.layoutRef), bindings, workgroupSize) val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) @@ -49,7 +49,7 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex context.destroy() object VkCyfraRuntime: - def using[T](f: VkCyfraRuntime ?=> T): T = - val runtime = new VkCyfraRuntime() + def using[T](f: VkCyfraRuntime ?=> T)(using spirvTools: SpirvToolsRunner = SpirvToolsRunner()): T = + val runtime = new VkCyfraRuntime(spirvTools) try f(using runtime) finally runtime.close()