
Commit d9a28ec (1 parent: a35b013)

Enhance GPU token generation by adding token consumer support and removing unused performance metrics

File tree: 2 files changed, +13 −29 lines

src/main/java/com/example/LlamaApp.java

Lines changed: 5 additions & 10 deletions
@@ -10,6 +10,7 @@
 import com.example.inference.engine.impl.Options;
 import com.example.loader.weights.ModelLoader;
 import com.example.loader.weights.State;
+import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.FloatArrayUtils;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

@@ -160,6 +161,8 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
                 System.err.println("Ran out of context length...");
                 break;
             }
+            System.out.print("\n");
+
         }
     }

@@ -188,16 +191,8 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {
         Set<Integer> stopTokens = chatFormat.getStopTokens();
         if (USE_TORNADOVM) {
             // Call generateTokensGPU without the token consumer parameter
-            responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo());
-            // Handle token output separately if needed
-            // You might need to iterate through responseTokens and process them
-            if (options.stream()) {
-                for (Integer token : responseTokens) {
-                    if (!model.tokenizer().isSpecialToken(token)) {
-                        System.out.print(model.tokenizer().decode(List.of(token)));
-                    }
-                }
-            }
+            responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
+                    sampler, options.echo(), options.stream() ? tokenConsumer : null);
         } else {
             // CPU path still uses the token consumer
             responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
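Note that the tokenConsumer passed on the streaming path above is not defined in this diff. As a rough sketch only (the actual definition in LlamaApp.java is not shown here), a consumer that reproduces the per-token printing the removed GPU-path loop used to do inline could look like this; printingConsumer is a hypothetical helper name:

import java.util.List;
import java.util.function.IntConsumer;

// Hypothetical helper (not part of the commit): builds a streaming token
// consumer that decodes and prints each non-special token as it is generated.
static IntConsumer printingConsumer(Llama model) {
    return token -> {
        if (!model.tokenizer().isSpecialToken(token)) {
            System.out.print(model.tokenizer().decode(List.of(token)));
        }
    };
}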

src/main/java/com/example/inference/engine/impl/Llama.java

Lines changed: 8 additions & 19 deletions
@@ -192,7 +192,8 @@ public static FloatArray forwardTornadoVM( //
         return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
     }

-    public static List<Integer> generateTokensGPU(Llama model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo) {
+    public static List<Integer> generateTokensGPU(Llama model, State state,
+            int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
         // 1. Pre-allocate the TornadoVM plan just once
         TornadoVMMasterPlan tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);

@@ -247,8 +248,13 @@ public static List<Integer> generateTokensGPU(Llama model, State state, int star
             // Sample next token - use GPU sampling if available
             nextToken = sampler.sampleToken(logits);

+            // Add token consumer support
+            if (onTokenGenerated != null) {
+                onTokenGenerated.accept(nextToken);
+            }
+
             // Output if needed
-            if (echo) {
+            if (echo && onTokenGenerated == null) {
                 System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
             }

@@ -359,23 +365,6 @@ public static List<Integer> generateTokens(Llama model, State state, int startPo
         return generatedTokens;
     }

-    /**
-     * Print performance metrics for the generation process
-     */
-    private static void printPerformanceMetrics(long startNanos, long inferenceStartNanos, int promptTokenCount, int generatedTokenCount) {
-        long endNanos = System.nanoTime();
-        long totalNanos = endNanos - startNanos;
-        long inferenceNanos = inferenceStartNanos > 0 ? endNanos - inferenceStartNanos : 0;
-        long promptNanos = inferenceStartNanos - startNanos;
-        int totalTokens = promptTokenCount + generatedTokenCount;
-
-        double totalTokensPerSecond = totalTokens / (totalNanos / 1_000_000_000.0);
-        double promptTokensPerSecond = promptTokenCount > 0 ? promptTokenCount / (promptNanos / 1_000_000_000.0) : 0;
-        double inferenceTokensPerSecond = generatedTokenCount > 0 ? generatedTokenCount / (inferenceNanos / 1_000_000_000.0) : 0;
-
-        System.err.printf("\n%n%.2f tokens/s (%d) [PrEval %.2f tokens/s (%d), TokGen %.2f tokens/s (%d)]%n", totalTokensPerSecond, totalTokens, promptTokensPerSecond, promptTokenCount,
-                inferenceTokensPerSecond, generatedTokenCount);
-    }

     public State createNewState() {
         State state = new State(configuration(), -1);
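The callback pattern that generateTokensGPU now follows can be summarized in isolation. The sketch below is standalone and illustrative, not code from the repository: TokenCallbackSketch and sampleNextToken are placeholder names standing in for the real generation loop and sampler call.

import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

final class TokenCallbackSketch {
    // Each sampled token is handed to the optional consumer; echoing to
    // stderr only happens when no consumer is supplied, so streamed output
    // is not printed twice.
    static List<Integer> generate(int maxTokens, boolean echo, IntConsumer onTokenGenerated) {
        List<Integer> generated = new ArrayList<>();
        for (int i = 0; i < maxTokens; i++) {
            int nextToken = sampleNextToken(); // placeholder for sampler.sampleToken(logits)
            if (onTokenGenerated != null) {
                onTokenGenerated.accept(nextToken);
            } else if (echo) {
                System.err.print(nextToken + " "); // placeholder for decoding and printing
            }
            generated.add(nextToken);
        }
        return generated;
    }

    private static int sampleNextToken() {
        return 0; // trivial stand-in for real sampling
    }
}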
