
Commit f9639eb

Refactor GPU token generation to initialize TornadoVM plan once and clean up resources after use
1 parent d9a28ec commit f9639eb
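
In outline, the commit hoists the TornadoVMMasterPlan out of Llama.generateTokensGPU and makes the callers own its lifecycle: the plan is initialized lazily once per session, reused across turns, and freed in a finally block. A minimal sketch of that pattern, using the identifiers from the diff below (the loop condition and the generation step are placeholders, not the real code):

    // Caller-managed plan lifecycle (sketch; see the actual diff below).
    TornadoVMMasterPlan tornadoVMPlan = null;
    try {
        while (chatLoopRunning) {  // placeholder condition, not in the diff
            if (USE_TORNADOVM && tornadoVMPlan == null) {
                // Initialized lazily, exactly once, then reused on every turn
                tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
            }
            // ... generate tokens for this turn, reusing the same plan ...
        }
    } finally {
        if (tornadoVMPlan != null) {
            tornadoVMPlan.freeTornadoExecutionPlan(); // release GPU resources exactly once
        }
    }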

File tree

2 files changed, 87 insertions(+), 45 deletions(-)

src/main/java/com/example/LlamaApp.java

Lines changed: 85 additions & 38 deletions
@@ -12,6 +12,7 @@
 import com.example.loader.weights.State;
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.FloatArrayUtils;
+import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.io.IOException;
@@ -28,7 +29,7 @@ public class LlamaApp {
     public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
     public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
     public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
-
+    public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "false")); // Show performance metrics in interactive mode
     /**
      * Creates and configures a sampler for token generation based on specified parameters.
      *
@@ -114,6 +115,7 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
         return sampler;
     }
 
+
     static void runInteractive(Llama model, Sampler sampler, Options options) {
         State state = null;
         List<Integer> conversationTokens = new ArrayList<>();
@@ -124,51 +126,92 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
         }
         int startPosition = 0;
         Scanner in = new Scanner(System.in);
-        while (true) {
-            System.out.print("> ");
-            System.out.flush();
-            String userText = in.nextLine();
-            if (List.of("quit", "exit").contains(userText)) {
-                break;
-            }
-            if (state == null) {
-                state = model.createNewState();
-            }
-            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
-            conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-            Set<Integer> stopTokens = chatFormat.getStopTokens();
-            List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
-                    sampler, options.echo(), token -> {
-                        if (options.stream()) {
-                            if (!model.tokenizer().isSpecialToken(token)) {
-                                System.out.print(model.tokenizer().decode(List.of(token)));
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        try {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    state = model.createNewState();
+                }
+
+                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
+                }
+
+                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+                List<Integer> responseTokens;
+                IntConsumer tokenConsumer = token -> {
+                    if (options.stream()) {
+                        if (!model.tokenizer().isSpecialToken(token)) {
+                            System.out.print(model.tokenizer().decode(List.of(token)));
                         }
-                    });
-            // Include stop token in the prompt history, but not in the response displayed to the user.
-            conversationTokens.addAll(responseTokens);
-            startPosition = conversationTokens.size();
-            Integer stopToken = null;
-            if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                stopToken = responseTokens.getLast();
-                responseTokens.removeLast();
-            }
-            if (!options.stream()) {
-                String responseText = model.tokenizer().decode(responseTokens);
-                System.out.println(responseText);
+                    }
+                };
+
+                // Choose between GPU and CPU path based on configuration
+                if (USE_TORNADOVM) {
+                    // GPU path using TornadoVM
+                    responseTokens = Llama.generateTokensGPU(model, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()),
+                            stopTokens, options.maxTokens(), sampler, options.echo(),
+                            options.stream() ? tokenConsumer : null, tornadoVMPlan);
+                } else {
+                    // CPU path
+                    responseTokens = Llama.generateTokens(model, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()),
+                            stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+                }
+
+                // Include stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = model.tokenizer().decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("Ran out of context length...\n Increase context length by passing --max-tokens XXX to llama-tornado");
+                    break;
+                }
+                System.out.print("\n");
+
+                // Optionally print performance metrics after each response
+                if (SHOW_PERF_INTERACTIVE) {
+                    Llama.LastRunMetrics.printMetrics();
+                }
             }
-            if (stopToken == null) {
-                System.err.println("Ran out of context length...");
-                break;
+        } finally {
+            // Clean up TornadoVM resources when exiting the chat loop
+            if (USE_TORNADOVM && tornadoVMPlan != null) {
+                try {
+                    tornadoVMPlan.freeTornadoExecutionPlan();
+                } catch (Exception e) {
+                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+                }
             }
-            System.out.print("\n");
-
         }
     }
 
     static void runInstructOnce(Llama model, Sampler sampler, Options options) {
         State state = model.createNewState();
         ChatFormat chatFormat = new ChatFormat(model.tokenizer());
+        TornadoVMMasterPlan tornadoVMPlan = null;
 
         List<Integer> promptTokens = new ArrayList<>();
         promptTokens.add(chatFormat.beginOfText);
@@ -190,9 +233,10 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {
 
         Set<Integer> stopTokens = chatFormat.getStopTokens();
         if (USE_TORNADOVM) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
             // Call generateTokensGPU, passing the token consumer only when streaming
             responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
-                    sampler, options.echo(), options.stream() ? tokenConsumer : null);
+                    sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
             // CPU path still uses the token consumer
             responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
@@ -208,6 +252,9 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {
 
         Llama.LastRunMetrics.printMetrics();
 
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
     }
 
     public static void main(String[] args) throws IOException {

src/main/java/com/example/inference/engine/impl/Llama.java

Lines changed: 2 additions & 7 deletions
@@ -193,10 +193,8 @@ public static FloatArray forwardTornadoVM( //
     }
 
     public static List<Integer> generateTokensGPU(Llama model, State state,
-            int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
-        // 1. Pre-allocate the TornadoVM plan just once
-        TornadoVMMasterPlan tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-
+            int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated,
+            TornadoVMMasterPlan tornadoVMPlan) {
         // === Setup and Initialization ===
         long startNanos = System.nanoTime();
         long inferenceStartNanos = 0;
@@ -281,9 +279,6 @@ public static List<Integer> generateTokensGPU(Llama model, State state,
         // Set metrics for tokens achieved
         LastRunMetrics.setMetrics(totalTokens, totalSeconds);
 
-        // Release GPU resources
-        tornadoVMPlan.freeTornadoExecutionPlan();
-
         return generatedTokens;
     }
 
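With this signature change, a one-shot caller is responsible for both creating and freeing the plan around the call. A hypothetical sketch (variable names and argument values are illustrative, not from the diff):

    // Hypothetical caller of the refactored method: the plan is passed in,
    // and generateTokensGPU no longer frees it internally.
    TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
    try {
        List<Integer> tokens = Llama.generateTokensGPU(
                model, state, 0, promptTokens, stopTokens,
                maxTokens, sampler, false /* echo */,
                null /* onTokenGenerated */, plan);
    } finally {
        plan.freeTornadoExecutionPlan(); // caller owns cleanup now
    }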
