 import com.example.loader.weights.State;
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.FloatArrayUtils;
+import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.io.IOException;
@@ -28,7 +29,7 @@ public class LlamaApp {
     public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
     public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
     public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
-
+    public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "false")); // Show performance metrics in interactive mode
     /**
      * Creates and configures a sampler for token generation based on specified parameters.
      *
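All four toggles above are plain JVM system properties read once when the class loads. A minimal sketch of how they resolve (the property keys are verbatim from the fields above; the launch line is illustrative):

```java
// Illustrative launch enabling the GPU path plus per-response metrics:
//   java -Duse.tornadovm=true -Dllama.ShowPerfInteractive=true ... LlamaApp
// An unset property falls back to the default passed to getProperty:
boolean useTornadoVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false"));
boolean showPerf = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "false"));
```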
@@ -114,6 +115,7 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
         return sampler;
     }
 
+
     static void runInteractive(Llama model, Sampler sampler, Options options) {
         State state = null;
         List<Integer> conversationTokens = new ArrayList<>();
@@ -124,51 +126,92 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
         }
         int startPosition = 0;
         Scanner in = new Scanner(System.in);
-        while (true) {
-            System.out.print("> ");
-            System.out.flush();
-            String userText = in.nextLine();
-            if (List.of("quit", "exit").contains(userText)) {
-                break;
-            }
-            if (state == null) {
-                state = model.createNewState();
-            }
-            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
-            conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-            Set<Integer> stopTokens = chatFormat.getStopTokens();
-            List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
-                    sampler, options.echo(), token -> {
-                        if (options.stream()) {
-                            if (!model.tokenizer().isSpecialToken(token)) {
-                                System.out.print(model.tokenizer().decode(List.of(token)));
-                            }
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        try {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    state = model.createNewState();
+                }
+
+                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
+                }
+
+                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+                List<Integer> responseTokens;
+                IntConsumer tokenConsumer = token -> {
+                    if (options.stream()) {
+                        if (!model.tokenizer().isSpecialToken(token)) {
+                            System.out.print(model.tokenizer().decode(List.of(token)));
                         }
-            });
-            // Include stop token in the prompt history, but not in the response displayed to the user.
-            conversationTokens.addAll(responseTokens);
-            startPosition = conversationTokens.size();
-            Integer stopToken = null;
-            if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                stopToken = responseTokens.getLast();
-                responseTokens.removeLast();
-            }
-            if (!options.stream()) {
-                String responseText = model.tokenizer().decode(responseTokens);
-                System.out.println(responseText);
+                    }
+                };
+
+                // Choose between GPU and CPU path based on configuration
+                if (USE_TORNADOVM) {
+                    // GPU path using TornadoVM
+                    responseTokens = Llama.generateTokensGPU(model, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()),
+                            stopTokens, options.maxTokens(), sampler, options.echo(),
+                            options.stream() ? tokenConsumer : null, tornadoVMPlan);
+                } else {
+                    // CPU path
+                    responseTokens = Llama.generateTokens(model, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()),
+                            stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+                }
+
+                // Include stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = model.tokenizer().decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("Ran out of context length...\nIncrease context length by passing --max-tokens XXX to llama-tornado");
+                    break;
+                }
+                System.out.print("\n");
+
+                // Optionally print performance metrics after each response
+                if (SHOW_PERF_INTERACTIVE) {
+                    Llama.LastRunMetrics.printMetrics();
+                }
             }
-            if (stopToken == null) {
-                System.err.println("Ran out of context length...");
-                break;
+        } finally {
+            // Clean up TornadoVM resources when exiting the chat loop
+            if (USE_TORNADOVM && tornadoVMPlan != null) {
+                try {
+                    tornadoVMPlan.freeTornadoExecutionPlan();
+                } catch (Exception e) {
+                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+                }
             }
-            System.out.print("\n");
-
         }
     }
 
     static void runInstructOnce(Llama model, Sampler sampler, Options options) {
         State state = model.createNewState();
         ChatFormat chatFormat = new ChatFormat(model.tokenizer());
+        TornadoVMMasterPlan tornadoVMPlan = null;
 
         List<Integer> promptTokens = new ArrayList<>();
         promptTokens.add(chatFormat.beginOfText);
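The `tokenConsumer` extracted in `runInteractive` above is just a `java.util.function.IntConsumer` that filters special tokens before printing. A self-contained sketch of that callback shape (`TokenSink` is a hypothetical stand-in for the `model.tokenizer()` calls used in the diff):

```java
import java.util.List;
import java.util.function.IntConsumer;

public class StreamingSketch {
    // Hypothetical stand-in for the tokenizer API used above.
    interface TokenSink {
        boolean isSpecialToken(int token);
        String decode(List<Integer> tokens);
    }

    static IntConsumer printer(TokenSink tokenizer, boolean streaming) {
        return token -> {
            // Only surface plain text; stop tokens and chat headers stay hidden
            if (streaming && !tokenizer.isSpecialToken(token)) {
                System.out.print(tokenizer.decode(List.of(token)));
            }
        };
    }
}
```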
@@ -190,9 +233,10 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {
 
         Set<Integer> stopTokens = chatFormat.getStopTokens();
         if (USE_TORNADOVM) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
             // Call generateTokensGPU without the token consumer parameter
             responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
-                    sampler, options.echo(), options.stream() ? tokenConsumer : null);
+                    sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
             // CPU path still uses the token consumer
             responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
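Note the asymmetry both call sites keep: the GPU path receives the consumer only when streaming (`options.stream() ? tokenConsumer : null`), so `generateTokensGPU` must tolerate a null callback, while the CPU path always passes it and relies on the `options.stream()` check inside the lambda. A hedged sketch of the null-safe emit one would expect inside the GPU generation loop (illustrative; the body of `generateTokensGPU` is not part of this diff):

```java
// Inside the generation loop: notify the caller only if a consumer was supplied.
if (onTokenGenerated != null) {
    onTokenGenerated.accept(token); // stream the token as soon as it is sampled
}
```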
@@ -208,6 +252,9 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {
 
         Llama.LastRunMetrics.printMetrics();
 
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
     }
 
     public static void main(String[] args) throws IOException {
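Both entry points now share the same resource lifecycle: create the `TornadoVMMasterPlan` lazily on first use, thread it through `generateTokensGPU`, and free it exactly once afterwards. A condensed sketch of that pattern (the loop condition is illustrative; the `TornadoVMMasterPlan` calls are the ones used in the diff):

```java
TornadoVMMasterPlan plan = null;
try {
    while (hasMoreRequests()) { // illustrative condition
        if (USE_TORNADOVM && plan == null) {
            // Lazy init: CPU-only runs never allocate GPU resources
            plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
        }
        // ... generate tokens on the GPU or CPU path ...
    }
} finally {
    if (plan != null) {
        plan.freeTornadoExecutionPlan(); // release the execution plan exactly once
    }
}
```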