@@ -115,21 +115,34 @@ def _get_memory_spaces_from_avals(
115115 return memory_spaces
116116
117117
118- def _tc_lowering_rule (
118+ def pallas_call_tpu_lowering_rule (
119119 ctx : mlir .LoweringRuleContext ,
120120 * in_nodes ,
121121 jaxpr : jax_core .Jaxpr ,
122122 grid_mapping : pallas_core .GridMapping ,
123+ mesh : pallas_core .Mesh | None ,
123124 input_output_aliases : tuple [tuple [int , int ], ...],
124125 debug : bool ,
125- mesh : pallas_core .Mesh | None ,
126+ interpret : bool ,
127+ compiler_params : dict [str , pallas_core .CompilerParams ],
126128 cost_estimate : pallas_core .CostEstimate | None ,
127129 out_avals : tuple [jax_core .AbstractValue , ...],
128130 metadata : frozen_dict .FrozenDict [str , str ] | None ,
129- mosaic_params : tpu_core .CompilerParams ,
130- debug_info : jax_core .DebugInfo ,
131131 name : str | None ,
132132):
133+ """Lowers a pallas_call to a Mosaic TPU custom call."""
134+ del interpret # Unused.
135+
136+ debug_info = jaxpr .debug_info
137+ if debug :
138+ print (f"\n The kernel jaxpr for pallas_call { debug_info .func_src_info } :" )
139+ print (jaxpr )
140+
141+ if "mosaic_tpu" in compiler_params :
142+ mosaic_params = cast (tpu_core .CompilerParams , compiler_params ["mosaic_tpu" ])
143+ else :
144+ mosaic_params = tpu_core .CompilerParams ()
145+
133146 del mesh
134147 jax_mesh = None
135148 axis_context = ctx .module_context .axis_context
@@ -141,11 +154,21 @@ def _tc_lowering_rule(
141154 mlir_ctx .load_all_available_dialects ()
142155 tpu .register_dialect (mlir_ctx )
143156
157+ match mosaic_params .kernel_type :
158+ case tpu_core .KernelType .TC :
159+ lower_jaxpr_to_module = lowering .lower_jaxpr_to_module
160+ case tpu_core .KernelType .SC_SCALAR_SUBCORE | tpu_core .KernelType .SC_VECTOR_SUBCORE :
161+ lower_jaxpr_to_module = sc_lowering .lower_jaxpr_to_module
162+ case _:
163+ raise ValueError (
164+ f"Unsupported kernel type: { mosaic_params .kernel_type } "
165+ )
166+
144167 def lower_module (for_verification : bool ):
145168 if for_verification :
146169 mlir_ctx .allow_unregistered_dialects = True
147170 with mlir_ctx , ir .Location .unknown (mlir_ctx ):
148- return lowering . lower_jaxpr_to_module (
171+ return lower_jaxpr_to_module (
149172 ctx ,
150173 grid_mapping ,
151174 jaxpr ,
@@ -158,6 +181,8 @@ def lower_module(for_verification: bool):
158181
159182 mosaic_module = lower_module (for_verification = False )
160183 if debug :
184+ pm = passmanager .PassManager .parse ("builtin.module(canonicalize)" , mlir_ctx )
185+ pm .run (mosaic_module .operation )
161186 print (f"\n The Mosaic module for pallas_call { debug_info .func_src_info } :" )
162187 print (mosaic_module )
163188 num_dyn_bounds = grid_mapping .num_dynamic_grid_bounds
@@ -190,7 +215,7 @@ def lower_module(for_verification: bool):
190215 )
191216 dump_ctx = tempfile .NamedTemporaryFile (
192217 mode = "w" ,
193- prefix = mlir .sanitize_name (debug_info .func_name ) + "-" ,
218+ prefix = mlir .sanitize_name (name or debug_info .func_name ) + "-" ,
194219 suffix = ".pml" ,
195220 dir = promela_dump_path ,
196221 delete = False ,
@@ -308,123 +333,3 @@ def _maybe_cast_outputs(*args):
308333
309334 cast_ctx = ctx .replace (avals_in = kernel_out_avals )
310335 return mlir .lower_fun (_maybe_cast_outputs )(cast_ctx , * out_nodes )
311-
312-
313- # TODO(sharadmv,slebedev): Dedup with _tc_lowering_rule.
# Lowering rule for one kernel flavour: builds a Mosaic module with
# sc_lowering.lower_jaxpr_to_module and wraps it in a custom call through
# mosaic.as_tpu_kernel.
# NOTE(review): indentation is not visible in this view, so which of the
# statements below sit inside the `with` / `if` bodies is unconfirmed.
314- def _sc_lowering_rule (
315- ctx : mlir .LoweringRuleContext ,
316- * in_nodes ,
317- jaxpr : jax_core .Jaxpr ,
318- grid_mapping : pallas_core .GridMapping ,
319- mesh : pallas_core .Mesh | None ,
320- input_output_aliases : tuple [tuple [int , int ], ...],
321- debug : bool ,
322- mosaic_params : tpu_core .CompilerParams ,
323- cost_estimate : pallas_core .CostEstimate | None ,
324- out_avals : tuple [jax_core .AbstractValue , ...],
325- backend : str | None = None ,
326- metadata : frozen_dict .FrozenDict [str , str ] | None ,
327- debug_info : jax_core .DebugInfo ,
328- name : str | None ,
329- ):
330- """Lowers a pallas_call to a Mosaic SparseCore custom call."""
# The incoming mesh, out_avals and backend arguments are discarded here;
# out_avals is rebuilt further down from grid_mapping.out_shapes.
331- del mesh , out_avals , backend
332- out_shapes = grid_mapping .out_shapes
333-
# jax_mesh is taken from the module's axis context only when that context
# is an SPMDAxisContext; in every other case it stays None.
334- jax_mesh = None
335- axis_context = ctx .module_context .axis_context
336- if axis_context is not None :
337- if isinstance (axis_context , sharding_impls .SPMDAxisContext ):
338- jax_mesh = axis_context .mesh
339-
# A fresh JaxIrContext is set up with the upstream dialect registry and the
# tpu dialect before the jaxpr is lowered to a Mosaic module.
340- with mlir .JaxIrContext () as mlir_ctx , ir .Location .unknown (mlir_ctx ):
341- mlir_ctx .append_dialect_registry (mlir .upstream_dialects )
342- mlir_ctx .load_all_available_dialects ()
343- tpu .register_dialect (mlir_ctx )
344- mosaic_module = sc_lowering .lower_jaxpr_to_module (
345- ctx , jaxpr , grid_mapping , mosaic_params , mesh = jax_mesh
346- )
# Debug path: run a canonicalize pass over the module, then print it.
# NOTE(review): pm.run presumably rewrites mosaic_module in place, which
# would make the module passed to as_tpu_kernel differ when debug is set --
# confirm.
347- if debug :
348- pm = passmanager .PassManager .parse ("builtin.module(canonicalize)" , mlir_ctx )
349- pm .run (mosaic_module .operation )
350- print (
351- "\n The Mosaic SparseCore module for pallas_call"
352- f" { debug_info .func_src_info } :"
353- )
354- print (mosaic_module )
# Output avals are rebuilt as ShapedArray(shape, dtype) from the grid
# mapping's out_shapes, replacing the out_avals argument deleted above.
355- out_avals = [jax_core .ShapedArray (s .shape , s .dtype ) for s in out_shapes ]
# The cost estimate, when given, is converted with dataclasses.asdict.
# NOTE(review): `cast` is presumably typing.cast (no runtime conversion) --
# confirm against the file's imports.
356- mosaic_cost_estimate = None
357- if cost_estimate is not None :
358- mosaic_cost_estimate = cast (
359- tpu_custom_call .CostEstimate , dataclasses .asdict (cost_estimate )
360- )
361-
# _lower_fun forwards its arguments unchanged to the kernel built by
# mosaic.as_tpu_kernel; the kernel name falls back to debug_info.func_name
# when `name` is falsy.
# NOTE(review): the inline comment below mentions dynamic grid bounds, but
# nothing in this function reorders arguments for them.
362- def _lower_fun (* args ):
363- # Dynamic grid bounds have to go at the front.
364- out = mosaic .as_tpu_kernel (
365- mosaic_module ,
366- out_avals ,
367- kernel_name = mlir .sanitize_name (name or debug_info .func_name ),
368- cost_estimate = mosaic_cost_estimate ,
369- input_output_aliases = input_output_aliases ,
370- metadata = metadata ,
371- collective_id = mosaic_params .collective_id ,
372- _ir_version = tpu_custom_call .get_ir_version (ctx ),
373- )(* args )
374- return out
375-
# The result is whatever mlir.lower_fun produces for _lower_fun, lowered
# with multiple_results=True against the original ctx and input nodes.
376- return mlir .lower_fun (_lower_fun , multiple_results = True )(ctx , * in_nodes )
377-
378-
# Dispatcher: reads kernel_type from the Mosaic compiler params, selects
# _tc_lowering_rule or _sc_lowering_rule accordingly, and forwards the
# lowering arguments to the selected rule.
# NOTE(review): indentation is not visible in this view, so the nesting of
# the `if` / `match` bodies below is inferred from syntax only.
379- def pallas_call_tpu_lowering_rule (
380- ctx : mlir .LoweringRuleContext ,
381- * in_nodes ,
382- jaxpr : jax_core .Jaxpr ,
383- grid_mapping : pallas_core .GridMapping ,
384- mesh : pallas_core .Mesh | None ,
385- input_output_aliases : tuple [tuple [int , int ], ...],
386- debug : bool ,
387- interpret : bool ,
388- compiler_params : dict [str , pallas_core .CompilerParams ],
389- cost_estimate : pallas_core .CostEstimate | None ,
390- out_avals : tuple [jax_core .AbstractValue , ...],
391- metadata : frozen_dict .FrozenDict [str , str ] | None ,
392- name : str | None ,
393- ):
394- """Lowers a pallas_call to a Mosaic TPU custom call."""
# `interpret` is part of the signature but is not used by this rule.
395- del interpret # Unused.
396-
# debug_info comes from the jaxpr and is later forwarded to the selected
# rule; in debug mode the kernel jaxpr is printed first.
397- debug_info = jaxpr .debug_info
398- if debug :
399- print (f"\n The kernel jaxpr for pallas_call { debug_info .func_src_info } :" )
400- print (jaxpr )
401-
# Only the "mosaic_tpu" entry of compiler_params is consulted; when it is
# absent a default-constructed tpu_core.CompilerParams is used.
# NOTE(review): `cast` is presumably typing.cast (no runtime conversion) --
# confirm against the file's imports.
402- if "mosaic_tpu" in compiler_params :
403- mosaic_params = cast (tpu_core .CompilerParams , compiler_params ["mosaic_tpu" ])
404- else :
405- mosaic_params = tpu_core .CompilerParams ()
# KernelType.TC selects _tc_lowering_rule; both SC_SCALAR_SUBCORE and
# SC_VECTOR_SUBCORE select _sc_lowering_rule; any other value raises
# ValueError.
406- match mosaic_params .kernel_type :
407- case tpu_core .KernelType .TC :
408- lowering_rule = _tc_lowering_rule
409- case (
410- tpu_core .KernelType .SC_SCALAR_SUBCORE
411- | tpu_core .KernelType .SC_VECTOR_SUBCORE
412- ):
413- lowering_rule = _sc_lowering_rule
414- case _:
415- raise ValueError (f"Unsupported kernel type: { mosaic_params .kernel_type } " )
# The selected rule receives the resolved mosaic_params and debug_info in
# place of the raw compiler_params dict; `interpret` is not forwarded.
416- return lowering_rule (
417- ctx ,
418- * in_nodes ,
419- jaxpr = jaxpr ,
420- grid_mapping = grid_mapping ,
421- input_output_aliases = input_output_aliases ,
422- mesh = mesh ,
423- debug = debug ,
424- cost_estimate = cost_estimate ,
425- out_avals = out_avals ,
426- metadata = metadata ,
427- mosaic_params = mosaic_params ,
428- debug_info = debug_info ,
429- name = name ,
430- )
0 commit comments