diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index c24a190..b4ad9b3 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -4,15 +4,18 @@ import org.beehive.gpullama3.core.model.tensor.FloatTensor; import org.beehive.gpullama3.inference.state.Phi3State; import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights; import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights; import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights; import org.beehive.gpullama3.inference.weights.standard.StandardWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import java.lang.foreign.MemorySegment; @@ -176,6 +179,137 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p return state.logits; } + public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) { + final Qwen2Configuration config = (Qwen2Configuration) model.configuration(); + final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery + float sqrtHeadSize = (float) Math.sqrt(headSize); + + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // forward all the layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // attention rmsnorm + final int curLayer = l; + rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps()); + + // qkv matmuls for this position + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // qkv additions with qkv bias + state.q.addInPlace(weights.q_bias[curLayer]); + state.k.addInPlace(weights.k_bias[curLayer]); + state.v.addInPlace(weights.v_bias[curLayer]); + + // RoPE relative positional encoding: complex-valued rotate q and k in each head + // GPT-NeoX style RoPE, real/imaginary components are stored with a headSize/2 offset per head, instead of consecutive. + for (int h = 0; h < config.numberOfHeads(); ++h) { + int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only + int poffset = h * headSize; + for (int i0 = 0; i0 < headSize; i0 += 2) { + int ic = i0 / 2; + float fcr = weights.freq_cis_real.getFloat((position) * (headSize / 2) + ic); + float fci = weights.freq_cis_imag.getFloat((position) * (headSize / 2) + ic); + for (int vi = 0; vi < rotn; vi++) { + FloatTensor vec = (vi == 0) ? 
state.q : state.k; // the vector to rotate (query or key) + float v0 = vec.getFloat(poffset + ic); + float v1 = vec.getFloat(poffset + ic + headSize/2); + vec.setFloat(poffset + ic, v0 * fcr - v1 * fci); + vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr); + } + } + } + + // save key,value at this time step (position) to our kv cache + //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience + state.k.copyTo(0, state.keyCache[curLayer], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim); + + // multihead attention. iterate over all heads + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + // get the query vector for this head + // float* q = s.q + h * headSize; + int qOffset = h * headSize; + + // attention scores for this head + // float* att = s.att + h * config.seq_len; + int attOffset = h * config.contextLength(); + + // iterate over all timesteps, including the current one + for (int t = 0; t <= position; t++) { + // get the key vector for this head and at this timestep + // float* k = s.key_cache + loff + t * dim + h * headSize; + int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; + // calculate the attention score as the dot product of q and k + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + // save the score to the attention buffer + state.att.setFloat(attOffset + t, score); + } + + // softmax the scores to get attention weights, from 0..position inclusively + state.att.softmaxInPlace(attOffset, position + 1); + + // weighted sum of the values, store back into xb + // float* xb = s.xb + h * headSize; + int xbOffset = h * headSize; + // memset(xb, 0, headSize * sizeof(float)); + state.xb.fillInPlace(xbOffset, headSize, 0f); + + for (int t = 0; t <= position; t++) { + // get the value vector for this head and at this timestep + // float* v = s.value_cache + loff + t * dim + h * headSize; + int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; + // get the attention weight for this timestep + float a = state.att.getFloat(attOffset + t); + // accumulate the weighted value into xb + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // final matmul to get the output of the attention + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + + // residual connection back into x + state.x.addInPlace(state.xb2); + + // ffn rmsnorm + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps()); + + // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + // first calculate self.w1(x) and self.w3(x) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + + // SwiGLU non-linearity + // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + + // elementwise multiply with w3(x) + state.hb.multiplyInPlace(state.hb2); + + // final matmul to get the output of the ffn + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // residual connection + state.x.addInPlace(state.xb); + + } + + // final rmsnorm + rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps()); + + // classifier into logits + weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); + + return state.logits; + } + public static FloatTensor
forwardJavaQwen3(Model model, State state, int token, int position) { // a few convenience variables final Qwen3Configuration config = (Qwen3Configuration) model.configuration(); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java new file mode 100644 index 0000000..3c475eb --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java @@ -0,0 +1,46 @@ +package org.beehive.gpullama3.inference.state; + +import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; +import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; + +import java.util.stream.Stream; + +public class Qwen2State extends State { + + //Qwen2 specific fields TODO + + public Qwen2State(Configuration config, int batchsize) { + super(config, batchsize); + // Initialize Qwen2-specific fields TODO + Qwen2Configuration qwen2Config = (Qwen2Configuration) config; + } + @Override + protected StateFields createStateFields(Configuration configuration) { + StateFields fields = new StateFields(); + + Qwen2Configuration config = (Qwen2Configuration) configuration; + + int nEmbdGqa = config.kvDim(); + + // with Qwen2-specific sizes + fields.x = ArrayFloatTensor.allocate(config.dim()); + fields.xb = ArrayFloatTensor.allocate(config.dim()); + fields.xb2 = ArrayFloatTensor.allocate(config.dim()); + fields.hb = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.q = ArrayFloatTensor.allocate(config.dim()); + fields.k = ArrayFloatTensor.allocate(config.kvDim()); + fields.v = ArrayFloatTensor.allocate(config.kvDim()); + fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + + // Key-value cache with Qwen2 dimensions + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + + return fields; + + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java new file mode 100644 index 0000000..fe401d0 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java @@ -0,0 +1,57 @@ +package org.beehive.gpullama3.inference.weights.standard; + +import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; +import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.inference.weights.Weights; + +public class Qwen2StandardWeights extends StandardWeights { + // Qwen2-specific weights + public final FloatTensor[] q_bias, k_bias, v_bias; + + public Qwen2StandardWeights( + FloatTensor token_embedding_table, + FloatTensor[] rms_att_weight, + FloatTensor[] wq, + FloatTensor[] wk, + FloatTensor[] wv, + FloatTensor[] q_bias, + FloatTensor[] k_bias, + FloatTensor[] v_bias, + FloatTensor[] wo, + FloatTensor[] rms_ffn_weight, + FloatTensor[] w1, + FloatTensor[] w2, + FloatTensor[] w3, + FloatTensor rms_final_weight, + 
ArrayFloatTensor freq_cis_real, + ArrayFloatTensor freq_cis_imag, + FloatTensor wcls, + GGMLType weightType) { + // call to StandardWeights constructor + super(token_embedding_table, + rms_att_weight, + wq, + wk, + wv, + wo, + rms_ffn_weight, + w1, + w2, + w3, + rms_final_weight, + freq_cis_real, + freq_cis_imag, + wcls, + weightType); + // init Qwen2-specific fields + this.q_bias = q_bias; + this.k_bias = k_bias; + this.v_bias = v_bias; + } + + @Override + public GGMLType getWeightType() { + return weightType; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java new file mode 100644 index 0000000..3d1a6ad --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.inference.weights.tornado; + +import org.beehive.gpullama3.core.model.GGMLType; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +public class Qwen2TornadoWeights extends TornadoWeights { + + // Qwen2-specific tornado weights + FloatArray[] q_biasLayered; + FloatArray[] k_biasLayered; + FloatArray[] v_biasLayered; + + public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, + FloatArray[] wqBiasLayered, + FloatArray[] wkBiasLayered, + FloatArray[] wvBiasLayered, + HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, + HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, + GGMLType weightType) { + // call to TornadoWeights constructor + super(tokenEmbeddingTable, + rms_att_weightLayered, + wqLayered, + wkLayered, + wvLayered, + woLayered, + rms_ffn_weightLayered, + w1Layered, + w2Layered, + w3Layered, + rms_final_weight_as_floatArray, + freq_cis_realFlat, + freq_cis_imagFlat, + wclsByteArray, + weightType); + // init qwen2-specific fields + this.q_biasLayered = wqBiasLayered; + this.k_biasLayered = wkBiasLayered; + this.v_biasLayered = wvBiasLayered; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/Model.java b/src/main/java/org/beehive/gpullama3/model/Model.java index d2242ea..8fa12ef 100644 --- a/src/main/java/org/beehive/gpullama3/model/Model.java +++ b/src/main/java/org/beehive/gpullama3/model/Model.java @@ -38,6 +38,18 @@ public interface Model { State createNewState(int batchsize); + default boolean shouldAddBeginOfText() { + return true; + } + + default boolean shouldAddSystemPrompt() { + return true; + } + + default boolean shouldIncludeReasoning() { + return false; + } + /** * Wrapper for invoking the model-specific forward pass via InferenceCore. 
* @@ -68,11 +80,11 @@ default void runInteractive(Sampler sampler, Options options) { ChatFormat chatFormat = chatFormat(); TornadoVMMasterPlan tornadoVMPlan = null; - if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) { + if (shouldAddBeginOfText()) { conversationTokens.add(chatFormat.getBeginOfText()); } - if (options.systemPrompt() != null) { + if (shouldAddSystemPrompt() && options.systemPrompt() != null) { conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); } @@ -95,6 +107,18 @@ default void runInteractive(Sampler sampler, Options options) { conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + // Include reasoning for Deepseek-R1-Distill-Qwen + if (shouldIncludeReasoning()) { + List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet()); + conversationTokens.addAll(thinkStartTokens); + + // If streaming, immediately output the think start + if (options.stream()) { + System.out.print("<think>\n"); + } + } + Set<Integer> stopTokens = chatFormat.getStopTokens(); List<Integer> responseTokens; @@ -127,6 +151,10 @@ default void runInteractive(Sampler sampler, Options options) { } if (!options.stream()) { String responseText = tokenizer().decode(responseTokens); + // Add the forced <think>\n prefix for non-streaming output + if (shouldIncludeReasoning()) { + responseText = "<think>\n" + responseText; + } System.out.println(responseText); } if (stopToken == null) { @@ -164,11 +192,11 @@ default void runInstructOnce(Sampler sampler, Options options) { List<Integer> promptTokens = new ArrayList<>(); - if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) { + if (shouldAddBeginOfText()) { promptTokens.add(chatFormat.getBeginOfText()); } - if (options.systemPrompt() != null) { + if (shouldAddSystemPrompt() && options.systemPrompt() != null) { promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); } @@ -180,6 +208,17 @@ default void runInstructOnce(Sampler sampler, Options options) { promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt()))); promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + // Include reasoning for Deepseek-R1-Distill-Qwen + if (shouldIncludeReasoning()) { + List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet()); + promptTokens.addAll(thinkStartTokens); + + // If streaming, immediately output the think start + if (options.stream()) { + System.out.print("<think>\n"); + } + } + List<Integer> responseTokens; IntConsumer tokenConsumer = token -> { @@ -206,6 +245,10 @@ default void runInstructOnce(Sampler sampler, Options options) { } if (!options.stream()) { String responseText = tokenizer().decode(responseTokens); + // Add the forced <think>\n prefix for non-streaming output + if (shouldIncludeReasoning()) { + responseText = "<think>\n" + responseText; + } System.out.println(responseText); } diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java index 741f48c..23d5146 100644 --- a/src/main/java/org/beehive/gpullama3/model/ModelType.java +++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java @@ -4,6 +4,7 @@ import
org.beehive.gpullama3.model.loader.LlamaModelLoader; import org.beehive.gpullama3.model.loader.MistralModelLoader; import org.beehive.gpullama3.model.loader.Phi3ModelLoader; +import org.beehive.gpullama3.model.loader.Qwen2ModelLoader; import org.beehive.gpullama3.model.loader.Qwen3ModelLoader; import java.nio.channels.FileChannel; @@ -35,6 +36,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo } }, + QWEN_2 { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel(); + } + }, + QWEN_3 { @Override public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { @@ -42,6 +50,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo } }, + DEEPSEEK_R1_DISTILL_QWEN { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel(); + } + }, + PHI_3 { @Override public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { @@ -58,4 +73,8 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo // Abstract method that each enum constant must implement public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights); + + public boolean isDeepSeekR1() { + return this == DEEPSEEK_R1_DISTILL_QWEN; + } } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index b5219c8..3da23ec 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -60,8 +60,12 @@ private static ModelType detectModelType(Map<String, Object> metadata) { return ModelType.MISTRAL; } else if (lowerName.contains("llama")) { return ModelType.LLAMA_3; + } else if (lowerName.contains("qwen2")) { return ModelType.QWEN_2; } else if (lowerName.contains("qwen3")) { return ModelType.QWEN_3; + } else if (lowerName.contains("deepseek r1 distill")) { return ModelType.DEEPSEEK_R1_DISTILL_QWEN; } else if (lowerName.contains("phi3")) { return ModelType.PHI_3; } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java new file mode 100644 index 0000000..2913129 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -0,0 +1,171 @@ +package org.beehive.gpullama3.model.loader; + +import org.beehive.gpullama3.LlamaApp; +import org.beehive.gpullama3.auxiliary.Timer; +import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.core.model.GGUF; +import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; +import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.core.types.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import
org.beehive.gpullama3.model.format.ChatFormat.ChatTokens; +import org.beehive.gpullama3.model.qwen2.Qwen2; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.util.Map; + +import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary; + +public class Qwen2ModelLoader extends ModelLoader { + + public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + super(fileChannel, gguf, contextLength, loadWeights); + } + + @Override + public Model loadModel() { + Map<String, Object> metadata = gguf.getMetadata(); + String basename = (String) metadata.get("general.basename"); + + String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename) + ? "DeepSeek-R1-Distill-Qwen" + : "Qwen2.5"; + + try (var ignored = Timer.log("Load " + modelName + " model")) { + // reuse method of Qwen3 + Vocabulary vocabulary = loadQwen3Vocabulary(metadata); + boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(basename); + Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen); + + int modelContextLength = (int) metadata.get("qwen2.context_length"); + if (contextLength < 0 || modelContextLength < contextLength) { + contextLength = modelContextLength; + } + + int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") + ? (int) metadata.get("qwen2.attention.head_count_kv") + : (int) metadata.get("qwen2.attention.head_count"); + Qwen2Configuration config = new Qwen2Configuration( + (int) metadata.get("qwen2.embedding_length"), // dim + (int) metadata.get("qwen2.feed_forward_length"), // hiddenDim + (int) metadata.get("qwen2.block_count"), // numberOfLayers + (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads + + numberOfKeyValueHeads, // numberOfKeyValueHeads + numberOfKeyValueHeads, // numberOfHeadsKey + numberOfKeyValueHeads, // numberOfHeadsValue + + vocabulary.size(), + modelContextLength, contextLength, + false, + (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), + (float) metadata.get("qwen2.rope.freq_base") + ); + + Weights weights = null; + if (loadWeights) { + Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + weights = loadWeights(tensorEntries, config); + } + // Qwen2.5-Coder uses <|endoftext|> as stop-token. + ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
+ new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") : + new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); + return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + // @formatter:off + @Override + public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) { + Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis( + config.contextLengthModel(), + config.headSize(), + config.ropeTheta(), + false, + 8, + 1, + 3, + 8192 + ); + + GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); + GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings); + + if (LlamaApp.USE_TORNADOVM) { + System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); + return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + } else { + return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + } + } + + @Override + public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + return new Qwen2StandardWeights( + loadQuantized(tokenEmbeddings), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")), + + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadQuantized(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadQuantized(outputWeight), + outputWeight.ggmlType()); + } + + @Override + public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + return new Qwen2TornadoWeights( + loadTensorAsFloatArray(tokenEmbeddings), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk."
+ i + ".attn_v.weight")), + // Qwen2-specific: qkv bias + loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")), + loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")), + loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")), + + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 + floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), + FloatArray.fromArray(ropeFreqs.first()), + FloatArray.fromArray(ropeFreqs.second()), + loadTensorAsHalfFloatArray(outputWeight), + outputWeight.ggmlType() + ); + } + // @formatter:on + +} diff --git a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java index 0d4c2f4..1ee4ce4 100644 --- a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java +++ b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java @@ -53,6 +53,14 @@ public State createNewState(int batchsize) { return state; } + /** + * No begin of text needed for Phi3 models. + */ + @Override + public boolean shouldAddBeginOfText() { + return false; + } + @Override public void forward(State state, int token, int position) { if (plan == null) { diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java new file mode 100644 index 0000000..e8fcb58 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java @@ -0,0 +1,103 @@ +package org.beehive.gpullama3.model.qwen2; + +import org.beehive.gpullama3.inference.InferenceCore; +import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.AbstractModel; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +public class Qwen2 extends AbstractModel { + + Qwen2Configuration configuration; + + public Qwen2(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(tokenizer, weights, chatFormat, null); + this.configuration = configuration; + } + + public Qwen2Configuration configuration() { + return configuration; + } + + @Override + public Tokenizer tokenizer() { + return (Qwen3Tokenizer) tokenizer; + } + + @Override + public ModelType getModelType() { + return ModelType.QWEN_2; + } + + @Override + public State createNewState() { + State state = new Qwen2State(configuration(), -1); + state.latestToken = 
tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader()); + return state; + } + + @Override + public State createNewState(int batchsize) { + State state = new Qwen2State(configuration(), batchsize); + state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader()); + return state; + } + + /** + * No <|beginoftext|> needed for Qwen models. + */ + @Override + public boolean shouldAddBeginOfText() { + return false; + } + + /** + * No system prompt for Deepseek-R1-Distill-Qwen. + * Based on Usage Recommendations + */ + @Override + public boolean shouldAddSystemPrompt() { + return !getModelType().isDeepSeekR1(); + } + + /** + * Force inclusion of <think> for Deepseek-R1-Distill-Qwen. + * Based on Usage Recommendations + */ + @Override + public boolean shouldIncludeReasoning() { + return getModelType().isDeepSeekR1(); + } + + @Override + public void forward(State state, int token, int position) { + if (plan == null) { + InferenceCore.forwardJavaQwen2(this, state, token, position); + } else { + InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan()); + } + } + + @Override + public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } + + @Override + public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2Configuration.java b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2Configuration.java new file mode 100644 index 0000000..5f4318d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2Configuration.java @@ -0,0 +1,37 @@ +package org.beehive.gpullama3.model.qwen2; + +import org.beehive.gpullama3.model.Configuration; + +public record Qwen2Configuration(int dim, + int hiddenDim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int numberOfHeadsKey, + int numberOfHeadsValue, + int vocabularySize, + int contextLengthModel, + int contextLength, + boolean sharedWeights, + float rmsNormEps, + float ropeTheta) implements Configuration { + @Override + public int headSize() { + return dim / numberOfHeads; + } + + @Override + public int kvDim() { + return (dim * numberOfKeyValueHeads) / numberOfHeads; + } + + @Override + public int kvMul() { + throw new UnsupportedOperationException("Not supported for Qwen2."); + } + + @Override + public int contextLengthModel() { + return contextLengthModel; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java index b40d4d9..bf90c13 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java @@ -53,6 +53,14 @@ public State createNewState(int batchsize) { return state; } + /** + * No begin of text needed for Qwen models.
+ */ + @Override + public boolean shouldAddBeginOfText() { + return false; + } + @Override public void forward(State state, int token, int position) { if (plan == null) { diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java index f7494e6..bbf3574 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java @@ -53,8 +53,8 @@ public boolean isSpecialToken(int tokenIndex) { @Override public boolean shouldDisplayToken(int token) { int tokenType = getTokenType(token); - - return tokenType == 1 || tokenType == 6; + // tokenType 4 allows the display of reasoning ( <think> ... </think> ) + return tokenType == 1 || tokenType == 4 || tokenType == 6; } public int getTokenType(int tokenIndex) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java new file mode 100644 index 0000000..39e6238 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java @@ -0,0 +1,46 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.Tuple2; +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; + +import java.util.List; + +public class Qwen2TornadoVMLayerPlanner extends TornadoVMLayerPlanner { + + /** + * Constructs a TornadoVMLayerPlanner for the given Qwen2 model.
* + * @param state + * The state object containing model tensors and buffers + * @param model + * The Qwen2 model instance containing configuration and weights + */ + public Qwen2TornadoVMLayerPlanner(Qwen2State state, Model model) { + super(state, model); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + throw new UnsupportedOperationException("configureLayerDataTransfers Not supported yet."); + } + + @Override + public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() { + throw new UnsupportedOperationException("setupTornadoForwardPlanLayered Not supported yet."); + } + + @Override + public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() { + return setupTornadoForwardPlanLayered(); + } + + private GridScheduler setupQwen2GridSchedulersLayeredNonNvidia() { + throw new UnsupportedOperationException("setupQwen2GridSchedulersLayeredNonNvidia Not supported yet."); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index b7e5933..c0a07fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.auxiliary.Tuple2; import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.state.Qwen2State; import org.beehive.gpullama3.inference.state.Qwen3State; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Configuration; @@ -93,7 +94,21 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod } /** - * Determines whether the NVIDIA-specific scheduler should be used based on the current hardware backend and the model type. + * Dispatcher method to select the TornadoVMLayerPlanner for the model. + */ + TornadoVMLayerPlanner createPlanner(State state, Model model) { + return switch (model.getModelType()) { + case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model); + case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model); + case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2TornadoVMLayerPlanner((Qwen2State) state, model); + case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model); + case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type"); + }; + } + + /** + * Determines whether the NVIDIA-specific scheduler should be used based on the current + * hardware backend and the model type. *

* The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the * NVIDIA-specific scheduler should not be used. @@ -115,19 +130,8 @@ public static boolean shouldUseNvidiaScheduler(Model model) { } /** - * Dispatcher method to select the TornadoVMLayerPlanner for the model. - */ - TornadoVMLayerPlanner createPlanner(State state, Model model) { - return switch (model.getModelType()) { - case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model); - case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model); - case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model); - case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type"); - }; - } - - /** - * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context + * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. + * This method processes the transformer layers in sequence for a particular token position in the context + * window. + *

* The execution happens in three phases: