@@ -127,6 +127,10 @@ def parse_target(self, tgt_prop) -> dict:
127127 dev_prop ['has_subgroup_2d_block_io' ] = tgt_prop .get ('has_subgroup_2d_block_io' , False )
128128 dev_prop ['has_bfloat16_conversions' ] = tgt_prop .get ('has_bfloat16_conversions' , True )
129129
130+ if self .device_arch in self .device_props :
131+ dev_prop .update (self .device_props [self .device_arch ])
132+ return dev_prop
133+
130134 return dev_prop
131135
132136 def parse_options (self , opts ) -> Any :
@@ -203,84 +207,31 @@ def get_split_barrier_scope(opt):
203207 return split_barriers_scope
204208
205209 @classmethod
206- def create_pass_manager (cls , context , add_passes = [] ):
210+ def create_pass_manager (cls , context ):
207211 pm = ir .pass_manager (context )
208212 pm .enable_debug ()
209- for p in add_passes :
210- if p is None :
211- continue
212- elif isinstance (p , tuple ):
213- p [0 ](pm , * p [1 :])
214- else :
215- p (pm )
216213 return pm
217214
218- @classmethod
219- def get_ttir_passes (cls , opt ):
220- return [
221- passes .common .add_inliner ,
222- intel .passes .ttir .add_convert_tdesc_to_block_pointer ,
223- passes .ttir .add_rewrite_tensor_descriptor_to_pointer ,
224- passes .common .add_cse ,
225- passes .common .add_licm ,
226- intel .passes .ttir .add_remove_masks ,
227- intel .passes .ttir .add_fuse_reshape ,
228- passes .common .add_canonicalizer ,
229- passes .ttir .add_combine ,
230- passes .ttir .add_reorder_broadcast ,
231- passes .common .add_cse ,
232- passes .common .add_symbol_dce ,
233- passes .ttir .add_loop_unroll ,
234- ]
235-
236215 @classmethod
237216 @track
238217 def make_ttir (cls , mod , metadata , opt ):
239- pm = cls .create_pass_manager (mod .context , cls .get_ttir_passes (opt ))
218+ pm = cls .create_pass_manager (mod .context )
219+ passes .common .add_inliner (pm )
220+ intel .passes .ttir .add_convert_tdesc_to_block_pointer (pm )
221+ passes .ttir .add_rewrite_tensor_descriptor_to_pointer (pm )
222+ passes .common .add_cse (pm )
223+ passes .common .add_licm (pm )
224+ intel .passes .ttir .add_remove_masks (pm )
225+ intel .passes .ttir .add_fuse_reshape (pm )
226+ passes .common .add_canonicalizer (pm )
227+ passes .ttir .add_combine (pm )
228+ passes .ttir .add_reorder_broadcast (pm )
229+ passes .common .add_cse (pm )
230+ passes .common .add_symbol_dce (pm )
231+ passes .ttir .add_loop_unroll (pm )
240232 pm .run (mod , 'make_ttir' )
241233 return mod
242234
243- @classmethod
244- def get_ttgir_passes (cls , opt ):
245- # fmt: off
246- return [
247- (passes .ttir .add_convert_to_ttgpuir , "xpu" , opt .num_warps , opt .warp_size , opt .num_ctas ),
248- # optimize TTGIR
249- intel .passes .ttgpuir .add_coalesce ,
250- intel .passes .ttgpuir .add_remove_layout_conversions ,
251-
252- intel .passes .ttgpuir .add_accelerate_matmul ,
253- intel .passes .ttgpuir .add_materialize_block_pointer ,
254- intel .passes .ttgpuir .add_remove_layout_conversions ,
255- intel .passes .ttgpuir .add_optimize_dot_operands ,
256- (intel .passes .ttgpuir .add_pipeline , opt .num_stages , cls .get_split_barrier_scope (opt )),
257-
258- intel .passes .ttgpuir .add_reduce_variable_liveness if opt .reduce_variable_liveness else None ,
259-
260- passes .ttgpuir .add_fuse_nested_loops ,
261-
262- passes .common .add_canonicalizer ,
263- passes .ttir .add_triton_licm ,
264- passes .common .add_canonicalizer ,
265- passes .ttgpuir .add_combine_tensor_select_and_if ,
266-
267- passes .ttgpuir .add_optimize_thread_locality ,
268- (passes .ttgpuir .add_optimize_dot_operands , True ),
269- passes .common .add_cse ,
270- passes .ttgpuir .add_prefetch ,
271- (passes .ttgpuir .add_optimize_dot_operands , True ),
272- intel .passes .ttgpuir .add_remove_layout_conversions ,
273- intel .passes .ttgpuir .add_reduce_data_duplication ,
274- passes .ttgpuir .add_reorder_instructions ,
275- passes .common .add_cse ,
276- passes .common .add_symbol_dce ,
277- passes .common .add_sccp ,
278- passes .common .add_canonicalizer ,
279- intel .passes .ttgpuir .add_optimize_reduction_locality if knobs .intel .opt_reduction_locality else None ,
280- (intel .passes .arith .add_arith_emulate_unsupported_floats , ["bf16" ], "f32" )
281- ]
282- # fmt: on
283-
284235 @classmethod
285236 @track
286237 def make_ttgir (cls , mod , metadata , opt , properties ):
@@ -301,7 +252,43 @@ def make_ttgir(cls, mod, metadata, opt, properties):
301252 opt .warp_size = intel .get_threads_per_warp (mod )
302253 cls .validate_options (opt , properties )
303254
304- pm = cls .create_pass_manager (mod .context , cls .get_ttgir_passes (opt ))
255+ pm = cls .create_pass_manager (mod .context )
256+ passes .ttir .add_convert_to_ttgpuir (pm , "xpu" , opt .num_warps , opt .warp_size , opt .num_ctas )
257+ # optimize TTGIR
258+ intel .passes .ttgpuir .add_coalesce (pm )
259+ intel .passes .ttgpuir .add_remove_layout_conversions (pm )
260+
261+ intel .passes .ttgpuir .add_accelerate_matmul (pm )
262+ intel .passes .ttgpuir .add_materialize_block_pointer (pm )
263+ intel .passes .ttgpuir .add_remove_layout_conversions (pm )
264+ intel .passes .ttgpuir .add_optimize_dot_operands (pm )
265+ intel .passes .ttgpuir .add_pipeline (pm , opt .num_stages , XPUBackend .get_split_barrier_scope (opt ))
266+
267+ if (opt .reduce_variable_liveness ):
268+ intel .passes .ttgpuir .add_reduce_variable_liveness (pm )
269+
270+ passes .ttgpuir .add_fuse_nested_loops (pm )
271+
272+ passes .common .add_canonicalizer (pm )
273+ passes .ttir .add_triton_licm (pm )
274+ passes .common .add_canonicalizer (pm )
275+ passes .ttgpuir .add_combine_tensor_select_and_if (pm )
276+
277+ passes .ttgpuir .add_optimize_thread_locality (pm )
278+ passes .ttgpuir .add_optimize_dot_operands (pm , True )
279+ passes .common .add_cse (pm )
280+ passes .ttgpuir .add_prefetch (pm )
281+ passes .ttgpuir .add_optimize_dot_operands (pm , True )
282+ intel .passes .ttgpuir .add_remove_layout_conversions (pm )
283+ intel .passes .ttgpuir .add_reduce_data_duplication (pm )
284+ passes .ttgpuir .add_reorder_instructions (pm )
285+ passes .common .add_cse (pm )
286+ passes .common .add_symbol_dce (pm )
287+ passes .common .add_sccp (pm )
288+ passes .common .add_canonicalizer (pm )
289+ if knobs .intel .opt_reduction_locality :
290+ intel .passes .ttgpuir .add_optimize_reduction_locality (pm )
291+ intel .passes .arith .add_arith_emulate_unsupported_floats (pm , ["bf16" ], "f32" )
305292 pm .run (mod , 'make_ttgir' )
306293 metadata ["cluster_dims" ] = (cluster_info .clusterDimX , cluster_info .clusterDimY , cluster_info .clusterDimZ )
307294 return mod
@@ -322,31 +309,6 @@ def gluon_to_ttgir(self, src, metadata, options):
322309 metadata ["tensordesc_meta" ] = mod .get_tensordesc_metadata ()
323310 return mod
324311
325- @classmethod
326- def get_llir_passes (cls , opt , mod ):
327- # fmt: off
328- return [
329- passes .convert .add_scf_to_cf ,
330- passes .gluon .add_inliner ,
331- passes .convert .add_index_to_llvmir ,
332- intel .passes .ttgpuir .add_allocate_shared_memory ,
333- passes .ttgpuir .add_allocate_global_scratch_memory ,
334- # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
335- lambda pm : cls .instrumentation .patch ("ttgpuir_to_llvmir" , pm , mod .context ) if cls .instrumentation else None ,
336- intel .passes .ttgpuir .add_to_llvmir ,
337- intel .passes .ttgpuir .add_gen_to_llvm ,
338- passes .common .add_canonicalizer ,
339- intel .passes .ttgpuir .add_rewrite_stack_ptr ,
340- passes .common .add_cse ,
341- passes .convert .add_arith_to_llvmir ,
342- passes .common .add_canonicalizer ,
343- passes .common .add_cse ,
344- passes .common .add_symbol_dce ,
345- None if knobs .compilation .disable_line_info or knobs .compilation .dump_ir_extract_di_local_variables else passes .llvmir .add_di_scope ,
346- lambda pm : cls .instrumentation .patch ("llvmir_to_llvm" , pm , mod .context ) if cls .instrumentation else None ,
347- ]
348- # fmt: on
349-
350312 @classmethod
351313 def optimize_llvm_mod (cls , llvm_mod , options ):
352314 intel .set_spv_target_triple (llvm_mod )
@@ -358,21 +320,46 @@ def optimize_llvm_mod(cls, llvm_mod, options):
358320 def make_llir (cls , src , metadata , options ):
359321 mod = src
360322 # TritonGPU -> LLVM-IR (MLIR)
361- pm = cls .create_pass_manager (mod .context , cls .get_llir_passes (options , mod ))
323+ pm = cls .create_pass_manager (mod .context )
324+ passes .convert .add_scf_to_cf (pm )
325+ passes .gluon .add_inliner (pm )
326+ passes .convert .add_index_to_llvmir (pm )
327+ intel .passes .ttgpuir .add_allocate_shared_memory (pm )
328+ passes .ttgpuir .add_allocate_global_scratch_memory (pm )
329+ # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
330+ if cls .instrumentation :
331+ cls .instrumentation .patch ("ttgpuir_to_llvmir" , pm , mod .context )
332+ intel .passes .ttgpuir .add_to_llvmir (pm )
333+ intel .passes .ttgpuir .add_gen_to_llvm (pm )
334+ passes .common .add_canonicalizer (pm )
335+ intel .passes .ttgpuir .add_rewrite_stack_ptr (pm )
336+ passes .common .add_cse (pm )
337+ passes .convert .add_arith_to_llvmir (pm )
338+ passes .common .add_canonicalizer (pm )
339+ passes .common .add_cse (pm )
340+ passes .common .add_symbol_dce (pm )
341+
342+ if not knobs .compilation .disable_line_info and not knobs .compilation .dump_ir_extract_di_local_variables :
343+ passes .llvmir .add_di_scope (pm )
344+
345+ if cls .instrumentation :
346+ cls .instrumentation .patch ("llvmir_to_llvm" , pm , mod .context )
362347 pm .run (mod , 'make_llir' )
363348
364349 if knobs .compilation .dump_ir_extract_di_local_variables :
365350 # see the comments below for why these passes are run separately
366351 if not knobs .compilation .disable_line_info :
367- pm = cls .create_pass_manager (mod .context , [passes .llvmir .add_di_scope ])
352+ pm = cls .create_pass_manager (mod .context )
353+ passes .llvmir .add_di_scope (pm )
368354 pm .run (mod , 'make_llir.disable_line_info' )
369355
370356 # insert dbg intrinsics with several DI attributes, including source
371357 # var name and type info. Note: for reasons unknown so far, this
372358 # pass and add_di_scope have to be run separately; otherwise, if we
373359 # put them into the previous pipeline, it triggers a segfault without
374360 # any error message; could be due to a bug in MLIR or pybind11
375- pm = cls .create_pass_manager (mod .context , [passes .llvmir .add_di_local_variable ])
361+ pm = cls .create_pass_manager (mod .context )
362+ passes .llvmir .add_di_local_variable (pm )
376363 pm .run (mod , 'make_llir.dump_ir_extract_di_local_variables' )
377364
378365 # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
0 commit comments