Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 31 additions & 126 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,34 @@ def _get_memory_spaces_from_avals(
return memory_spaces


def _tc_lowering_rule(
def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext,
*in_nodes,
jaxpr: jax_core.Jaxpr,
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
mesh: pallas_core.Mesh | None,
interpret: bool,
compiler_params: dict[str, pallas_core.CompilerParams],
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
metadata: frozen_dict.FrozenDict[str, str] | None,
mosaic_params: tpu_core.CompilerParams,
debug_info: jax_core.DebugInfo,
name: str | None,
):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret # Unused.

debug_info = jaxpr.debug_info
if debug:
print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
print(jaxpr)

if "mosaic_tpu" in compiler_params:
mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])
else:
mosaic_params = tpu_core.CompilerParams()

del mesh
jax_mesh = None
axis_context = ctx.module_context.axis_context
Expand All @@ -141,11 +154,21 @@ def _tc_lowering_rule(
mlir_ctx.load_all_available_dialects()
tpu.register_dialect(mlir_ctx)

match mosaic_params.kernel_type:
case tpu_core.KernelType.TC:
lower_jaxpr_to_module = lowering.lower_jaxpr_to_module
case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE:
lower_jaxpr_to_module = sc_lowering.lower_jaxpr_to_module
case _:
raise ValueError(
f"Unsupported kernel type: {mosaic_params.kernel_type}"
)

def lower_module(for_verification: bool):
if for_verification:
mlir_ctx.allow_unregistered_dialects = True
with mlir_ctx, ir.Location.unknown(mlir_ctx):
return lowering.lower_jaxpr_to_module(
return lower_jaxpr_to_module(
ctx,
grid_mapping,
jaxpr,
Expand All @@ -158,6 +181,8 @@ def lower_module(for_verification: bool):

mosaic_module = lower_module(for_verification=False)
if debug:
pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx)
pm.run(mosaic_module.operation)
print(f"\nThe Mosaic module for pallas_call {debug_info.func_src_info}:")
print(mosaic_module)
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
Expand Down Expand Up @@ -190,7 +215,7 @@ def lower_module(for_verification: bool):
)
dump_ctx = tempfile.NamedTemporaryFile(
mode="w",
prefix=mlir.sanitize_name(debug_info.func_name) + "-",
prefix=mlir.sanitize_name(name or debug_info.func_name) + "-",
suffix=".pml",
dir=promela_dump_path,
delete=False,
Expand Down Expand Up @@ -308,123 +333,3 @@ def _maybe_cast_outputs(*args):

cast_ctx = ctx.replace(avals_in=kernel_out_avals)
return mlir.lower_fun(_maybe_cast_outputs)(cast_ctx, *out_nodes)


# TODO(sharadmv,slebedev): Dedup with _tc_lowering_rule.
def _sc_lowering_rule(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxpr: jax_core.Jaxpr,
    grid_mapping: pallas_core.GridMapping,
    mesh: pallas_core.Mesh | None,
    input_output_aliases: tuple[tuple[int, int], ...],
    debug: bool,
    mosaic_params: tpu_core.CompilerParams,
    cost_estimate: pallas_core.CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    backend: str | None = None,
    metadata: frozen_dict.FrozenDict[str, str] | None,
    debug_info: jax_core.DebugInfo,
    name: str | None,
):
  """Lowers a pallas_call to a Mosaic SparseCore custom call.

  Args:
    ctx: The MLIR lowering-rule context for the surrounding computation.
    *in_nodes: The already-lowered MLIR values for the pallas_call inputs.
    jaxpr: The kernel jaxpr to lower.
    grid_mapping: The grid/block mapping for the kernel; its ``out_shapes``
      determine the output avals of the generated custom call.
    mesh: Unused here (deleted below); the mesh, if any, is instead recovered
      from the module's SPMD axis context.
    input_output_aliases: Pairs of (input index, output index) to alias in the
      emitted TPU custom call.
    debug: If True, canonicalizes and prints the lowered Mosaic module.
    mosaic_params: Mosaic compiler parameters (e.g. ``collective_id``).
    cost_estimate: Optional cost estimate, forwarded to the custom call as a
      plain dict.
    out_avals: Unused (deleted below); outputs are rebuilt from
      ``grid_mapping.out_shapes``.
    backend: Unused (deleted below).
    metadata: Optional string metadata attached to the custom call.
    debug_info: Debug info of the kernel function; its ``func_src_info`` and
      ``func_name`` are used for printing and kernel naming.
    name: Optional kernel name; falls back to ``debug_info.func_name``.

  Returns:
    The MLIR results of lowering the wrapped kernel call via
    ``mlir.lower_fun``.
  """
  del mesh, out_avals, backend
  out_shapes = grid_mapping.out_shapes

  # Recover the JAX mesh (if any) from the module's SPMD axis context rather
  # than from the `mesh` argument.
  jax_mesh = None
  axis_context = ctx.module_context.axis_context
  if axis_context is not None:
    if isinstance(axis_context, sharding_impls.SPMDAxisContext):
      jax_mesh = axis_context.mesh

  # Lower the kernel jaxpr to a Mosaic SparseCore MLIR module inside a fresh
  # context with the TPU dialect registered.
  with mlir.JaxIrContext() as mlir_ctx, ir.Location.unknown(mlir_ctx):
    mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
    mlir_ctx.load_all_available_dialects()
    tpu.register_dialect(mlir_ctx)
    mosaic_module = sc_lowering.lower_jaxpr_to_module(
        ctx, jaxpr, grid_mapping, mosaic_params, mesh=jax_mesh
    )
    if debug:
      # Canonicalize before printing so the dump is easier to read.
      pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx)
      pm.run(mosaic_module.operation)
      print(
          "\nThe Mosaic SparseCore module for pallas_call"
          f" {debug_info.func_src_info}:"
      )
      print(mosaic_module)
  out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
  # The custom call expects the cost estimate as a plain dict, not a dataclass.
  mosaic_cost_estimate = None
  if cost_estimate is not None:
    mosaic_cost_estimate = cast(
        tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate)
    )

  def _lower_fun(*args):
    # Dynamic grid bounds have to go at the front.
    out = mosaic.as_tpu_kernel(
        mosaic_module,
        out_avals,
        kernel_name=mlir.sanitize_name(name or debug_info.func_name),
        cost_estimate=mosaic_cost_estimate,
        input_output_aliases=input_output_aliases,
        metadata=metadata,
        collective_id=mosaic_params.collective_id,
        _ir_version=tpu_custom_call.get_ir_version(ctx),
    )(*args)
    return out

  return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *in_nodes)


def pallas_call_tpu_lowering_rule(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxpr: jax_core.Jaxpr,
    grid_mapping: pallas_core.GridMapping,
    mesh: pallas_core.Mesh | None,
    input_output_aliases: tuple[tuple[int, int], ...],
    debug: bool,
    interpret: bool,
    compiler_params: dict[str, pallas_core.CompilerParams],
    cost_estimate: pallas_core.CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    metadata: frozen_dict.FrozenDict[str, str] | None,
    name: str | None,
):
  """Lowers a pallas_call to a Mosaic TPU custom call.

  Dispatches to the TensorCore or SparseCore lowering rule based on the
  kernel type in the Mosaic compiler params.
  """
  del interpret  # Unused.

  debug_info = jaxpr.debug_info
  if debug:
    print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
    print(jaxpr)

  # Pull the Mosaic-specific params out of the per-backend dict, defaulting
  # when the caller supplied none.
  if "mosaic_tpu" not in compiler_params:
    mosaic_params = tpu_core.CompilerParams()
  else:
    mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])

  # Select the core-specific lowering rule for this kernel type.
  kernel_type = mosaic_params.kernel_type
  if kernel_type == tpu_core.KernelType.TC:
    rule = _tc_lowering_rule
  elif kernel_type in (
      tpu_core.KernelType.SC_SCALAR_SUBCORE,
      tpu_core.KernelType.SC_VECTOR_SUBCORE,
  ):
    rule = _sc_lowering_rule
  else:
    raise ValueError(f"Unsupported kernel type: {mosaic_params.kernel_type}")

  return rule(
      ctx,
      *in_nodes,
      jaxpr=jaxpr,
      grid_mapping=grid_mapping,
      input_output_aliases=input_output_aliases,
      mesh=mesh,
      debug=debug,
      cost_estimate=cost_estimate,
      out_avals=out_avals,
      metadata=metadata,
      mosaic_params=mosaic_params,
      debug_info=debug_info,
      name=name,
  )
21 changes: 16 additions & 5 deletions jax/_src/pallas/mosaic/sc_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,24 @@ def dynamic_shape_replacement_fn(x):

def lower_jaxpr_to_module(
lowering_context: mlir.LoweringRuleContext,
jaxpr: jax_core.Jaxpr,
grid_mapping: pallas_core.GridMapping,
mosaic_params: tpu_core.CompilerParams,
jaxpr: jax_core.Jaxpr,
*,
dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None,
kernel_type: tpu_core.KernelType,
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
dynamic_shape_replacement_enabled: bool = False,
) -> ir.Module:
"""Lowers a Jaxpr to a Mosaic SparseCore module."""
dimension_semantics = mosaic_params.dimension_semantics
if dynamic_shape_replacement_enabled:
raise NotImplementedError(
"Dynamic shape replacement is not supported for SparseCore."
)
if for_verification:
raise NotImplementedError(
"Verification is not supported for SparseCore."
)
if not grid_mapping.grid:
index_map_avals, index_map_tree = jax.tree.flatten(
((jax_core.ShapedArray((), jnp.int32),), {})
Expand Down Expand Up @@ -167,7 +178,7 @@ def new_index_map(*args, bm=bm):
func_op = lower_jaxpr_to_func(
jaxpr,
name="main",
kernel_type=mosaic_params.kernel_type,
kernel_type=kernel_type,
mosaic_grid_mapping=mosaic_grid_mapping,
forward_compatible=lowering_context.is_forward_compat(),
backend=backend,
Expand All @@ -188,7 +199,7 @@ def new_index_map(*args, bm=bm):
bm.block_aval,
name=func_name,
mosaic_grid_mapping=mosaic_grid_mapping,
kernel_type=mosaic_params.kernel_type,
kernel_type=kernel_type,
for_verification=False,
forward_compatible=lowering_context.is_forward_compat(),
backend=backend,
Expand Down
Loading