
Commit 8b5b60a

superbobry authored and Google-ML-Automation committed
[pallas:sc] Nuked _sc_lowering in favor of _tc_lowering
PiperOrigin-RevId: 830792145
1 parent f110351 commit 8b5b60a

2 files changed: 44 additions & 132 deletions

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 31 additions & 126 deletions
@@ -115,21 +115,34 @@ def _get_memory_spaces_from_avals(
   return memory_spaces


-def _tc_lowering_rule(
+def pallas_call_tpu_lowering_rule(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
-    mesh: pallas_core.Mesh | None,
+    interpret: bool,
+    compiler_params: dict[str, pallas_core.CompilerParams],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
     metadata: frozen_dict.FrozenDict[str, str] | None,
-    mosaic_params: tpu_core.CompilerParams,
-    debug_info: jax_core.DebugInfo,
     name: str | None,
 ):
+  """Lowers a pallas_call to a Mosaic TPU custom call."""
+  del interpret  # Unused.
+
+  debug_info = jaxpr.debug_info
+  if debug:
+    print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
+    print(jaxpr)
+
+  if "mosaic_tpu" in compiler_params:
+    mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])
+  else:
+    mosaic_params = tpu_core.CompilerParams()
+
   del mesh
   jax_mesh = None
   axis_context = ctx.module_context.axis_context
@@ -141,11 +154,21 @@ def _tc_lowering_rule(
   mlir_ctx.load_all_available_dialects()
   tpu.register_dialect(mlir_ctx)

+  match mosaic_params.kernel_type:
+    case tpu_core.KernelType.TC:
+      lower_jaxpr_to_module = lowering.lower_jaxpr_to_module
+    case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE:
+      lower_jaxpr_to_module = sc_lowering.lower_jaxpr_to_module
+    case _:
+      raise ValueError(
+          f"Unsupported kernel type: {mosaic_params.kernel_type}"
+      )
+
   def lower_module(for_verification: bool):
     if for_verification:
       mlir_ctx.allow_unregistered_dialects = True
     with mlir_ctx, ir.Location.unknown(mlir_ctx):
-      return lowering.lower_jaxpr_to_module(
+      return lower_jaxpr_to_module(
           ctx,
           grid_mapping,
           jaxpr,
@@ -158,6 +181,8 @@ def lower_module(for_verification: bool):

   mosaic_module = lower_module(for_verification=False)
   if debug:
+    pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx)
+    pm.run(mosaic_module.operation)
     print(f"\nThe Mosaic module for pallas_call {debug_info.func_src_info}:")
     print(mosaic_module)
   num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
@@ -190,7 +215,7 @@ def lower_module(for_verification: bool):
     )
     dump_ctx = tempfile.NamedTemporaryFile(
         mode="w",
-        prefix=mlir.sanitize_name(debug_info.func_name) + "-",
+        prefix=mlir.sanitize_name(name or debug_info.func_name) + "-",
         suffix=".pml",
         dir=promela_dump_path,
         delete=False,
@@ -308,123 +333,3 @@ def _maybe_cast_outputs(*args):

   cast_ctx = ctx.replace(avals_in=kernel_out_avals)
   return mlir.lower_fun(_maybe_cast_outputs)(cast_ctx, *out_nodes)
-
-
-# TODO(sharadmv,slebedev): Dedup with _tc_lowering_rule.
-def _sc_lowering_rule(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    grid_mapping: pallas_core.GridMapping,
-    mesh: pallas_core.Mesh | None,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    debug: bool,
-    mosaic_params: tpu_core.CompilerParams,
-    cost_estimate: pallas_core.CostEstimate | None,
-    out_avals: tuple[jax_core.AbstractValue, ...],
-    backend: str | None = None,
-    metadata: frozen_dict.FrozenDict[str, str] | None,
-    debug_info: jax_core.DebugInfo,
-    name: str | None,
-):
-  """Lowers a pallas_call to a Mosaic SparseCore custom call."""
-  del mesh, out_avals, backend
-  out_shapes = grid_mapping.out_shapes
-
-  jax_mesh = None
-  axis_context = ctx.module_context.axis_context
-  if axis_context is not None:
-    if isinstance(axis_context, sharding_impls.SPMDAxisContext):
-      jax_mesh = axis_context.mesh
-
-  with mlir.JaxIrContext() as mlir_ctx, ir.Location.unknown(mlir_ctx):
-    mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
-    mlir_ctx.load_all_available_dialects()
-    tpu.register_dialect(mlir_ctx)
-    mosaic_module = sc_lowering.lower_jaxpr_to_module(
-        ctx, jaxpr, grid_mapping, mosaic_params, mesh=jax_mesh
-    )
-  if debug:
-    pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx)
-    pm.run(mosaic_module.operation)
-    print(
-        "\nThe Mosaic SparseCore module for pallas_call"
-        f" {debug_info.func_src_info}:"
-    )
-    print(mosaic_module)
-  out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
-  mosaic_cost_estimate = None
-  if cost_estimate is not None:
-    mosaic_cost_estimate = cast(
-        tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate)
-    )
-
-  def _lower_fun(*args):
-    # Dynamic grid bounds have to go at the front.
-    out = mosaic.as_tpu_kernel(
-        mosaic_module,
-        out_avals,
-        kernel_name=mlir.sanitize_name(name or debug_info.func_name),
-        cost_estimate=mosaic_cost_estimate,
-        input_output_aliases=input_output_aliases,
-        metadata=metadata,
-        collective_id=mosaic_params.collective_id,
-        _ir_version=tpu_custom_call.get_ir_version(ctx),
-    )(*args)
-    return out
-
-  return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *in_nodes)
-
-
-def pallas_call_tpu_lowering_rule(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    grid_mapping: pallas_core.GridMapping,
-    mesh: pallas_core.Mesh | None,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    debug: bool,
-    interpret: bool,
-    compiler_params: dict[str, pallas_core.CompilerParams],
-    cost_estimate: pallas_core.CostEstimate | None,
-    out_avals: tuple[jax_core.AbstractValue, ...],
-    metadata: frozen_dict.FrozenDict[str, str] | None,
-    name: str | None,
-):
-  """Lowers a pallas_call to a Mosaic TPU custom call."""
-  del interpret  # Unused.
-
-  debug_info = jaxpr.debug_info
-  if debug:
-    print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
-    print(jaxpr)
-
-  if "mosaic_tpu" in compiler_params:
-    mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])
-  else:
-    mosaic_params = tpu_core.CompilerParams()
-  match mosaic_params.kernel_type:
-    case tpu_core.KernelType.TC:
-      lowering_rule = _tc_lowering_rule
-    case (
-        tpu_core.KernelType.SC_SCALAR_SUBCORE
-        | tpu_core.KernelType.SC_VECTOR_SUBCORE
-    ):
-      lowering_rule = _sc_lowering_rule
-    case _:
-      raise ValueError(f"Unsupported kernel type: {mosaic_params.kernel_type}")
-  return lowering_rule(
-      ctx,
-      *in_nodes,
-      jaxpr=jaxpr,
-      grid_mapping=grid_mapping,
-      input_output_aliases=input_output_aliases,
-      mesh=mesh,
-      debug=debug,
-      cost_estimate=cost_estimate,
-      out_avals=out_avals,
-      metadata=metadata,
-      mosaic_params=mosaic_params,
-      debug_info=debug_info,
-      name=name,
-  )
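
To make the merged control flow easier to read in one place, here is a condensed sketch assembled from the added lines above. It is a sketch, not the full rule: the signature is abbreviated, the body after the dispatch is elided, and the module's existing imports (cast, lowering, sc_lowering, tpu_core) are assumed.

def pallas_call_tpu_lowering_rule(ctx, *in_nodes, compiler_params, **kwargs):
  # Compiler params are now resolved inside the single rule.
  if "mosaic_tpu" in compiler_params:
    mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"])
  else:
    mosaic_params = tpu_core.CompilerParams()

  # The kernel type selects the module-lowering backend, replacing the
  # deleted _tc_lowering_rule/_sc_lowering_rule dispatch.
  match mosaic_params.kernel_type:
    case tpu_core.KernelType.TC:
      lower_jaxpr_to_module = lowering.lower_jaxpr_to_module
    case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE:
      lower_jaxpr_to_module = sc_lowering.lower_jaxpr_to_module
    case _:
      raise ValueError(f"Unsupported kernel type: {mosaic_params.kernel_type}")
  ...  # module lowering, cost estimate, and the Mosaic custom call follow unchanged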

jax/_src/pallas/mosaic/sc_lowering.py

Lines changed: 13 additions & 6 deletions
@@ -97,13 +97,20 @@ def dynamic_shape_replacement_fn(x):

 def lower_jaxpr_to_module(
     lowering_context: mlir.LoweringRuleContext,
-    jaxpr: jax_core.Jaxpr,
     grid_mapping: pallas_core.GridMapping,
-    mosaic_params: tpu_core.CompilerParams,
+    jaxpr: jax_core.Jaxpr,
+    *,
+    dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None,
+    kernel_type: tpu_core.KernelType,
     mesh: mesh_lib.Mesh | None = None,
+    for_verification: bool = False,
+    dynamic_shape_replacement_enabled: bool = False,
 ) -> ir.Module:
   """Lowers a Jaxpr to a Mosaic SparseCore module."""
-  dimension_semantics = mosaic_params.dimension_semantics
+  if dynamic_shape_replacement_enabled:
+    raise NotImplementedError(
+        "Dynamic shape replacement is not supported for SparseCore."
+    )
   if not grid_mapping.grid:
     index_map_avals, index_map_tree = jax.tree.flatten(
         ((jax_core.ShapedArray((), jnp.int32),), {})
@@ -167,7 +174,7 @@ def new_index_map(*args, bm=bm):
   func_op = lower_jaxpr_to_func(
       jaxpr,
       name="main",
-      kernel_type=mosaic_params.kernel_type,
+      kernel_type=kernel_type,
       mosaic_grid_mapping=mosaic_grid_mapping,
       forward_compatible=lowering_context.is_forward_compat(),
       backend=backend,
@@ -188,8 +195,8 @@ def new_index_map(*args, bm=bm):
         bm.block_aval,
         name=func_name,
         mosaic_grid_mapping=mosaic_grid_mapping,
-        kernel_type=mosaic_params.kernel_type,
-        for_verification=False,
+        kernel_type=kernel_type,
+        for_verification=for_verification,
         forward_compatible=lowering_context.is_forward_compat(),
         backend=backend,
     )
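
The sc_lowering change above replaces the tpu_core.CompilerParams argument with explicit keyword-only parameters and swaps the jaxpr/grid_mapping order. Below is a minimal sketch of how a caller would invoke the new signature; the keyword values are illustrative assumptions, since the updated call site is outside the visible hunks.

# Illustrative call into the new keyword-only signature (values are assumed).
mosaic_module = sc_lowering.lower_jaxpr_to_module(
    ctx,           # mlir.LoweringRuleContext
    grid_mapping,  # pallas_core.GridMapping; note it now precedes jaxpr
    jaxpr,         # jax_core.Jaxpr
    dimension_semantics=mosaic_params.dimension_semantics,
    kernel_type=mosaic_params.kernel_type,
    mesh=jax_mesh,
    for_verification=False,
    # Passing True raises NotImplementedError: unsupported for SparseCore.
    dynamic_shape_replacement_enabled=False,
)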
