diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index 9477ca5d1e41..bde9ca333096 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -115,21 +115,34 @@ def _get_memory_spaces_from_avals(
   return memory_spaces
 
 
-def _tc_lowering_rule(
+def pallas_call_tpu_lowering_rule(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
    input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
-    mesh: pallas_core.Mesh | None,
+    interpret: bool,
+    compiler_params: dict[str, pallas_core.CompilerParams],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
     metadata: frozen_dict.FrozenDict[str, str] | None,
-    mosaic_params: tpu_core.CompilerParams,
-    debug_info: jax_core.DebugInfo,
     name: str | None,
 ):
+  """Lowers a pallas_call to a Mosaic TPU custom call."""
+  del interpret  # Unused.
+
+  debug_info = jaxpr.debug_info
+  if debug:
+    print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
+    print(jaxpr)
+
+  if "mosaic_tpu" in compiler_params:
+    mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])
+  else:
+    mosaic_params = tpu_core.CompilerParams()
+
   del mesh
   jax_mesh = None
   axis_context = ctx.module_context.axis_context
@@ -141,11 +154,21 @@ def _tc_lowering_rule(
     mlir_ctx.load_all_available_dialects()
     tpu.register_dialect(mlir_ctx)
 
+  match mosaic_params.kernel_type:
+    case tpu_core.KernelType.TC:
+      lower_jaxpr_to_module = lowering.lower_jaxpr_to_module
+    case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE:
+      lower_jaxpr_to_module = sc_lowering.lower_jaxpr_to_module
+    case _:
+      raise ValueError(
+          f"Unsupported kernel type: {mosaic_params.kernel_type}"
+      )
+
   def lower_module(for_verification: bool):
     if for_verification:
       mlir_ctx.allow_unregistered_dialects = True
     with mlir_ctx, ir.Location.unknown(mlir_ctx):
-      return lowering.lower_jaxpr_to_module(
+      return lower_jaxpr_to_module(
           ctx,
           grid_mapping,
           jaxpr,
@@ -158,6 +181,8 @@ def lower_module(for_verification: bool):
 
   mosaic_module = lower_module(for_verification=False)
   if debug:
+    pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx)
+    pm.run(mosaic_module.operation)
     print(f"\nThe Mosaic module for pallas_call {debug_info.func_src_info}:")
     print(mosaic_module)
   num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
@@ -190,7 +215,7 @@ def lower_module(for_verification: bool):
       )
       dump_ctx = tempfile.NamedTemporaryFile(
           mode="w",
-          prefix=mlir.sanitize_name(debug_info.func_name) + "-",
+          prefix=mlir.sanitize_name(name or debug_info.func_name) + "-",
          suffix=".pml",
           dir=promela_dump_path,
           delete=False,
@@ -308,123 +333,3 @@ def _maybe_cast_outputs(*args):
 
   cast_ctx = ctx.replace(avals_in=kernel_out_avals)
   return mlir.lower_fun(_maybe_cast_outputs)(cast_ctx, *out_nodes)
-
-
-# TODO(sharadmv,slebedev): Dedup with _tc_lowering_rule.
-def _sc_lowering_rule(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    grid_mapping: pallas_core.GridMapping,
-    mesh: pallas_core.Mesh | None,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    debug: bool,
-    mosaic_params: tpu_core.CompilerParams,
-    cost_estimate: pallas_core.CostEstimate | None,
-    out_avals: tuple[jax_core.AbstractValue, ...],
-    backend: str | None = None,
-    metadata: frozen_dict.FrozenDict[str, str] | None,
-    debug_info: jax_core.DebugInfo,
-    name: str | None,
-):
-  """Lowers a pallas_call to a Mosaic SparseCore custom call."""
-  del mesh, out_avals, backend
-  out_shapes = grid_mapping.out_shapes
-
-  jax_mesh = None
-  axis_context = ctx.module_context.axis_context
-  if axis_context is not None:
-    if isinstance(axis_context, sharding_impls.SPMDAxisContext):
-      jax_mesh = axis_context.mesh
-
-  with mlir.JaxIrContext() as mlir_ctx, ir.Location.unknown(mlir_ctx):
-    mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
-    mlir_ctx.load_all_available_dialects()
-    tpu.register_dialect(mlir_ctx)
-    mosaic_module = sc_lowering.lower_jaxpr_to_module(
-        ctx, jaxpr, grid_mapping, mosaic_params, mesh=jax_mesh
-    )
-  if debug:
-    pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx)
-    pm.run(mosaic_module.operation)
-    print(
-        "\nThe Mosaic SparseCore module for pallas_call"
-        f" {debug_info.func_src_info}:"
-    )
-    print(mosaic_module)
-  out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
-  mosaic_cost_estimate = None
-  if cost_estimate is not None:
-    mosaic_cost_estimate = cast(
-        tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate)
-    )
-
-  def _lower_fun(*args):
-    # Dynamic grid bounds have to go at the front.
-    out = mosaic.as_tpu_kernel(
-        mosaic_module,
-        out_avals,
-        kernel_name=mlir.sanitize_name(name or debug_info.func_name),
-        cost_estimate=mosaic_cost_estimate,
-        input_output_aliases=input_output_aliases,
-        metadata=metadata,
-        collective_id=mosaic_params.collective_id,
-        _ir_version=tpu_custom_call.get_ir_version(ctx),
-    )(*args)
-    return out
-
-  return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *in_nodes)
-
-
-def pallas_call_tpu_lowering_rule(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    grid_mapping: pallas_core.GridMapping,
-    mesh: pallas_core.Mesh | None,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    debug: bool,
-    interpret: bool,
-    compiler_params: dict[str, pallas_core.CompilerParams],
-    cost_estimate: pallas_core.CostEstimate | None,
-    out_avals: tuple[jax_core.AbstractValue, ...],
-    metadata: frozen_dict.FrozenDict[str, str] | None,
-    name: str | None,
-):
-  """Lowers a pallas_call to a Mosaic TPU custom call."""
-  del interpret  # Unused.
-
-  debug_info = jaxpr.debug_info
-  if debug:
-    print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
-    print(jaxpr)
-
-  if "mosaic_tpu" in compiler_params:
-    mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])
-  else:
-    mosaic_params = tpu_core.CompilerParams()
-  match mosaic_params.kernel_type:
-    case tpu_core.KernelType.TC:
-      lowering_rule = _tc_lowering_rule
-    case (
-        tpu_core.KernelType.SC_SCALAR_SUBCORE
-        | tpu_core.KernelType.SC_VECTOR_SUBCORE
-    ):
-      lowering_rule = _sc_lowering_rule
-    case _:
-      raise ValueError(f"Unsupported kernel type: {mosaic_params.kernel_type}")
-  return lowering_rule(
-      ctx,
-      *in_nodes,
-      jaxpr=jaxpr,
-      grid_mapping=grid_mapping,
-      input_output_aliases=input_output_aliases,
-      mesh=mesh,
-      debug=debug,
-      cost_estimate=cost_estimate,
-      out_avals=out_avals,
-      metadata=metadata,
-      mosaic_params=mosaic_params,
-      debug_info=debug_info,
-      name=name,
-  )
diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py
index d0f5d41be441..d07d1ecbda7e 100644
--- a/jax/_src/pallas/mosaic/sc_lowering.py
+++ b/jax/_src/pallas/mosaic/sc_lowering.py
@@ -97,13 +97,24 @@ def dynamic_shape_replacement_fn(x):
 
 def lower_jaxpr_to_module(
     lowering_context: mlir.LoweringRuleContext,
-    jaxpr: jax_core.Jaxpr,
     grid_mapping: pallas_core.GridMapping,
-    mosaic_params: tpu_core.CompilerParams,
+    jaxpr: jax_core.Jaxpr,
+    *,
+    dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None,
+    kernel_type: tpu_core.KernelType,
     mesh: mesh_lib.Mesh | None = None,
+    for_verification: bool = False,
+    dynamic_shape_replacement_enabled: bool = False,
 ) -> ir.Module:
   """Lowers a Jaxpr to a Mosaic SparseCore module."""
-  dimension_semantics = mosaic_params.dimension_semantics
+  if dynamic_shape_replacement_enabled:
+    raise NotImplementedError(
+        "Dynamic shape replacement is not supported for SparseCore."
+    )
+  if for_verification:
+    raise NotImplementedError(
+        "Verification is not supported for SparseCore."
+    )
   if not grid_mapping.grid:
     index_map_avals, index_map_tree = jax.tree.flatten(
         ((jax_core.ShapedArray((), jnp.int32),), {})
@@ -167,7 +178,7 @@ def new_index_map(*args, bm=bm):
   func_op = lower_jaxpr_to_func(
       jaxpr,
       name="main",
-      kernel_type=mosaic_params.kernel_type,
+      kernel_type=kernel_type,
       mosaic_grid_mapping=mosaic_grid_mapping,
       forward_compatible=lowering_context.is_forward_compat(),
       backend=backend,
@@ -188,7 +199,7 @@ def new_index_map(*args, bm=bm):
           bm.block_aval,
           name=func_name,
           mosaic_grid_mapping=mosaic_grid_mapping,
-          kernel_type=mosaic_params.kernel_type,
+          kernel_type=kernel_type,
           for_verification=False,
           forward_compatible=lowering_context.is_forward_compat(),
           backend=backend,
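Reviewer note (not part of the diff): the heart of this change is that the unified
pallas_call_tpu_lowering_rule now picks a lower_jaxpr_to_module implementation per
kernel type instead of dispatching to two near-duplicate lowering rules. Below is a
minimal, self-contained sketch of that dispatch shape under stated assumptions: the
KernelType enum and the lower_tc/lower_sc/select_lower_fn names are hypothetical
stand-ins for tpu_core.KernelType, lowering.lower_jaxpr_to_module, and
sc_lowering.lower_jaxpr_to_module, not the real JAX APIs.

    from enum import Enum

    class KernelType(Enum):  # hypothetical stand-in for tpu_core.KernelType
      TC = 0
      SC_SCALAR_SUBCORE = 1
      SC_VECTOR_SUBCORE = 2

    def lower_tc():  # stand-in for lowering.lower_jaxpr_to_module (TensorCore)
      return "TensorCore module"

    def lower_sc():  # stand-in for sc_lowering.lower_jaxpr_to_module (SparseCore)
      return "SparseCore module"

    def select_lower_fn(kernel_type: KernelType):
      # Mirrors the match statement the diff adds: TC gets the TensorCore
      # lowering, both SparseCore subcore kernel types share the SparseCore
      # lowering, and anything else is rejected with a ValueError.
      match kernel_type:
        case KernelType.TC:
          return lower_tc
        case KernelType.SC_SCALAR_SUBCORE | KernelType.SC_VECTOR_SUBCORE:
          return lower_sc
        case _:
          raise ValueError(f"Unsupported kernel type: {kernel_type}")

    assert select_lower_fn(KernelType.TC)() == "TensorCore module"
    assert select_lower_fn(KernelType.SC_VECTOR_SUBCORE)() == "SparseCore module"

Because only the selected function differs, everything downstream of the match
(module construction, debug canonicalization and printing, cost estimates) is
shared, which is what lets the old _sc_lowering_rule be deleted wholesale.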