44import  com .example .loader .weights .State ;
55import  com .example .model .Configuration ;
66import  com .example .model .Model ;
7+ import  com .example .model .ModelType ;
78import  uk .ac .manchester .tornado .api .GridScheduler ;
89import  uk .ac .manchester .tornado .api .ImmutableTaskGraph ;
910import  uk .ac .manchester .tornado .api .TornadoExecutionPlan ;
1213import  uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
1314
1415import  java .util .List ;
16+ import  java .util .Locale ;
1517
1618public  class  TornadoVMMasterPlan  {
1719    private  static  final  boolean  ENABLE_TORNADOVM_INIT_TIME  = Boolean .parseBoolean (System .getProperty ("llama.EnableTimingForTornadoVMInit" , "False" ));
@@ -22,9 +24,9 @@ public class TornadoVMMasterPlan {
2224    public  TornadoExecutionPlan  executionPlan ;
2325    List <ImmutableTaskGraph > taskGraphs ;
2426
25-     public  TornadoVMMasterPlan (State  state , Model  model ,  boolean   isNvidia ) {
27+     public  TornadoVMMasterPlan (State  state , Model  model ) {
2628        TornadoVMLayerPlanner  tornadoVMLayerPlanner  = new  TornadoVMLayerPlanner (state , model );
27-         Tuple2 <List <ImmutableTaskGraph >, GridScheduler > tornadoVMPlan  = isNvidia 
29+         Tuple2 <List <ImmutableTaskGraph >, GridScheduler > tornadoVMPlan  = shouldUseNvidiaScheduler ( model ) 
2830                ? tornadoVMLayerPlanner .setupTornadoForwardPlanLayered ()
2931                : tornadoVMLayerPlanner .setupTornadoForwardPlanLayeredNonNvidia ();
3032        this .taskGraphs  = tornadoVMPlan .getFirst ();
@@ -57,9 +59,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
5759        }
5860
5961        // 1. Pre-allocate the TornadoVM plan 
60-         TornadoRuntime  coreRuntime  = TornadoRuntimeProvider .getTornadoRuntime ();
61-         boolean  isNvidia  = coreRuntime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase ().contains ("nvidia" );
62-         TornadoVMMasterPlan  tornadoVMPlan  = new  TornadoVMMasterPlan (state , model , isNvidia );
62+         TornadoVMMasterPlan  tornadoVMPlan  = new  TornadoVMMasterPlan (state , model );
6363
6464        // Record time after plan creation 
6565        if  (ENABLE_TORNADOVM_INIT_TIME ) {
@@ -89,6 +89,29 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
8989        return  tornadoVMPlan ;
9090    }
9191
92+     /** 
93+      * Determines whether the NVIDIA-specific scheduler should be used based on the current 
94+      * hardware backend and the model type. 
95+      * <p> 
96+      * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model 
97+      * is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is 
98+      * {@code MISTRAL}, the NVIDIA-specific scheduler should not be used. 
99+      * 
100+      * @param model the model whose type may affect the scheduler decision 
101+      * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise 
102+      */ 
103+     public  static  boolean  shouldUseNvidiaScheduler (Model  model ) {
104+         TornadoRuntime  runtime  = TornadoRuntimeProvider .getTornadoRuntime ();
105+         String  platformName  = runtime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase (Locale .ROOT );
106+ 
107+         boolean  isNvidia  = platformName .contains ("nvidia" );
108+         boolean  isNotMistral  = model .getModelType () != ModelType .MISTRAL ;
109+ 
110+         boolean  result  = isNvidia  && isNotMistral ;
111+ 
112+         return  result ;
113+     }
114+ 
92115    /** 
93116     * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. 
94117     *This method processes the transformer layers in sequence for a particular token position in the context 
0 commit comments