Skip to content

[WIP][models][deepseek][qwen2.5] Add support for Qwen2.5 and Deepseek-Distilled-Qwen models #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;

import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.lang.foreign.MemorySegment;
Expand Down Expand Up @@ -176,6 +179,137 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
return state.logits;
}

public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
int dim = config.dim();
int headSize = config.headSize();
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(headSize);

weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

// forward all the layers
for (int l = 0; l < config.numberOfLayers(); l++) {
// attention rmsnorm
final int curLayer = l;
rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());

// qkv matmuls for this position
weights.wq[l].matmul(state.xb, state.q, dim, dim);
weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

// qkv additions with qkv bias
state.q.addInPlace(weights.q_bias[curLayer]);
state.k.addInPlace(weights.k_bias[curLayer]);
state.v.addInPlace(weights.v_bias[curLayer]);

// RoPE relative positional encoding: complex-valued rotate q and k in each head
// GPT-NeoX style RoPE, real/imaginary components are stored with a headSize/2 offset per head, instead of consecutive.
for (int h = 0; h < config.numberOfHeads(); ++h) {
int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
int poffset = h * headSize;
for (int i0 = 0; i0 < headSize; i0 += 2) {
int ic = i0 / 2;
float fcr = weights.freq_cis_real.getFloat((position) * (headSize / 2) + ic);
float fci = weights.freq_cis_imag.getFloat((position) * (headSize / 2) + ic);
for (int vi = 0; vi < rotn; vi++) {
FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
float v0 = vec.getFloat(poffset + ic);
float v1 = vec.getFloat(poffset + ic + headSize/2);
vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr);
}
}
}

// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
state.k.copyTo(0, state.keyCache[curLayer], position * kvDim, kvDim);
state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim);

// multihead attention. iterate over all heads
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;

// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength();

// iterate over all timesteps, including the current one
for (int t = 0; t <= position; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att.setFloat(attOffset + t, score);
}

// softmax the scores to get attention weights, from 0..position inclusively
state.att.softmaxInPlace(attOffset, position + 1);

// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb.fillInPlace(xbOffset, headSize, 0f);

for (int t = 0; t <= position; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;C
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att.getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});

// final matmul to get the output of the attention
weights.wo[l].matmul(state.xb, state.xb2, dim, dim);

// residual connection back into x
state.x.addInPlace(state.xb2);

// ffn rmsnorm
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));

// elementwise multiply with w3(x)
state.hb.multiplyInPlace(state.hb2);

// final matmul to get the output of the ffn
weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());

// residual connection
state.x.addInPlace(state.xb);

}

// final rmsnorm
rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

// classifier into logits
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

return state.logits;
}

public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
// a few convenience variables
final Qwen3Configuration config = (Qwen3Configuration) model.configuration();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.beehive.gpullama3.inference.state;

import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;

import java.util.stream.Stream;

public class Qwen2State extends State {

//Qwen2 specific fields TODO

public Qwen2State(Configuration config, int batchsize) {
super(config, batchsize);
// Initialize Qwen2-specific fields TODO
Qwen2Configuration qwen2Config = (Qwen2Configuration) config;
}
@Override
protected StateFields createStateFields(Configuration configuration) {
StateFields fields = new StateFields();

Qwen2Configuration config = (Qwen2Configuration) configuration;

int nEmbdGqa = config.kvDim();

// with Qwen2-specific sizes
fields.x = ArrayFloatTensor.allocate(config.dim());
fields.xb = ArrayFloatTensor.allocate(config.dim());
fields.xb2 = ArrayFloatTensor.allocate(config.dim());
fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
fields.q = ArrayFloatTensor.allocate(config.dim());
fields.k = ArrayFloatTensor.allocate(config.kvDim());
fields.v = ArrayFloatTensor.allocate(config.kvDim());
fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());

// Key-value cache with Qwen2 dimensions
fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

return fields;

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.beehive.gpullama3.inference.weights.standard;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.inference.weights.Weights;

public class Qwen2StandardWeights extends StandardWeights {
// Qwen2-specific weights
public final FloatTensor[] q_bias, k_bias, v_bias;

public Qwen2StandardWeights(
FloatTensor token_embedding_table,
FloatTensor[] rms_att_weight,
FloatTensor[] wq,
FloatTensor[] wk,
FloatTensor[] wv,
FloatTensor[] q_bias,
FloatTensor[] k_bias,
FloatTensor[] v_bias,
FloatTensor[] wo,
FloatTensor[] rms_ffn_weight,
FloatTensor[] w1,
FloatTensor[] w2,
FloatTensor[] w3,
FloatTensor rms_final_weight,
ArrayFloatTensor freq_cis_real,
ArrayFloatTensor freq_cis_imag,
FloatTensor wcls,
GGMLType weightType) {
// call to StandardWeights constructor
super(token_embedding_table,
rms_att_weight,
wq,
wk,
wv,
wo,
rms_ffn_weight,
w1,
w2,
w3,
rms_final_weight,
freq_cis_real,
freq_cis_imag,
wcls,
weightType);
// init Qwen2-specific fields
this.q_bias = q_bias;
this.k_bias = k_bias;
this.v_bias = v_bias;
}

@Override
public GGMLType getWeightType() {
return weightType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package org.beehive.gpullama3.inference.weights.tornado;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class Qwen2TornadoWeights extends TornadoWeights {

// Qwen2-specific tornado weights
FloatArray[] q_biasLayered;
FloatArray[] k_biasLayered;
FloatArray[] v_biasLayered;

public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered,
FloatArray[] wqBiasLayered,
FloatArray[] wkBiasLayered,
FloatArray[] wvBiasLayered,
HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered,
HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray,
GGMLType weightType) {
// call to TornadoWeights constructor
super(tokenEmbeddingTable,
rms_att_weightLayered,
wqLayered,
wkLayered,
wvLayered,
woLayered,
rms_ffn_weightLayered,
w1Layered,
w2Layered,
w3Layered,
rms_final_weight_as_floatArray,
freq_cis_realFlat,
freq_cis_imagFlat,
wclsByteArray,
weightType);
// init qwen2-specific fields
this.q_biasLayered = wqBiasLayered;
this.k_biasLayered = wkBiasLayered;
this.v_biasLayered = wvBiasLayered;
}
}
Loading