@@ -192,7 +192,8 @@ public static FloatArray forwardTornadoVM( //
192192 return tornadoVMMasterPlan .tornadoVMForwardExecuteLayered (position );
193193 }
194194
195- public static List <Integer > generateTokensGPU (Llama model , State state , int startPosition , List <Integer > promptTokens , Set <Integer > stopTokens , int maxTokens , Sampler sampler , boolean echo ) {
195+ public static List <Integer > generateTokensGPU (Llama model , State state ,
196+ int startPosition , List <Integer > promptTokens , Set <Integer > stopTokens , int maxTokens , Sampler sampler , boolean echo , IntConsumer onTokenGenerated ) {
196197 // 1. Pre-allocate the TornadoVM plan just once
197198 TornadoVMMasterPlan tornadoVMPlan = TornadoVMMasterPlan .initializeTornadoVMPlan (state , model );
198199
@@ -247,8 +248,13 @@ public static List<Integer> generateTokensGPU(Llama model, State state, int star
247248 // Sample next token - use GPU sampling if available
248249 nextToken = sampler .sampleToken (logits );
249250
251+ // Add token consumer support
252+ if (onTokenGenerated != null ) {
253+ onTokenGenerated .accept (nextToken );
254+ }
255+
250256 // Output if needed
251- if (echo ) {
257+ if (echo && onTokenGenerated == null ) {
252258 System .err .print (Tokenizer .replaceControlCharacters (model .tokenizer ().decode (List .of (nextToken ))));
253259 }
254260
@@ -359,23 +365,6 @@ public static List<Integer> generateTokens(Llama model, State state, int startPo
359365 return generatedTokens ;
360366 }
361367
362- /**
363- * Print performance metrics for the generation process
364- */
365- private static void printPerformanceMetrics (long startNanos , long inferenceStartNanos , int promptTokenCount , int generatedTokenCount ) {
366- long endNanos = System .nanoTime ();
367- long totalNanos = endNanos - startNanos ;
368- long inferenceNanos = inferenceStartNanos > 0 ? endNanos - inferenceStartNanos : 0 ;
369- long promptNanos = inferenceStartNanos - startNanos ;
370- int totalTokens = promptTokenCount + generatedTokenCount ;
371-
372- double totalTokensPerSecond = totalTokens / (totalNanos / 1_000_000_000.0 );
373- double promptTokensPerSecond = promptTokenCount > 0 ? promptTokenCount / (promptNanos / 1_000_000_000.0 ) : 0 ;
374- double inferenceTokensPerSecond = generatedTokenCount > 0 ? generatedTokenCount / (inferenceNanos / 1_000_000_000.0 ) : 0 ;
375-
376- System .err .printf ("\n %n%.2f tokens/s (%d) [PrEval %.2f tokens/s (%d), TokGen %.2f tokens/s (%d)]%n" , totalTokensPerSecond , totalTokens , promptTokensPerSecond , promptTokenCount ,
377- inferenceTokensPerSecond , generatedTokenCount );
378- }
379368
380369 public State createNewState () {
381370 State state = new State (configuration (), -1 );
0 commit comments