diff --git a/CMakeLists.txt b/CMakeLists.txt index ec7bd6c51453..e7fe906f4f2f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -298,6 +298,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/te/*.cc src/autotvm/*.cc src/tir/*.cc + src/s_tir/*.cc src/topi/*.cc src/driver/*.cc src/support/*.cc diff --git a/include/tvm/s_tir/transform.h b/include/tvm/s_tir/transform.h new file mode 100644 index 000000000000..9914c6e49a7f --- /dev/null +++ b/include/tvm/s_tir/transform.h @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/s_tir/transform.h + * \brief S-TIR specific transformation passes. + */ +#ifndef TVM_S_TIR_TRANSFORM_H_ +#define TVM_S_TIR_TRANSFORM_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace s_tir { +namespace transform { + +using tir::transform::CreatePrimFuncPass; +using tvm::transform::Pass; +using tvm::transform::PassContext; + +/*! + * \brief Canonicalize loop to start from zero . + * \return The pass. + */ +TVM_DLL Pass CanonicalizeLoop(); + +/*! + * \brief Lower cross-thread reduction from thread + * bindings to intrinsic function calls. + * \return The pass. + */ +TVM_DLL Pass LowerCrossThreadReduction(); + +/*! + * \brief Lower block init stmt into IfThenElse stmts + * \return The pass. + */ +TVM_DLL Pass LowerInitBlock(); + +/*! + * \brief Locate the buffer allocation to the exact position (usually is + * the lca of buffer access). This pass will inject opaque block + * with alloc_buffers at the allocation site. + * \return The pass. + */ +TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); + +/*! + * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the + * corresponding iter_values in BlockRealize, for opaque blocks by removing all + *. the iter_values in BlockRealize and iter_vars in Block. + * \return The pass. + */ +TVM_DLL Pass ConvertBlocksToOpaque(); + +/*! + * \brief Lift the same thread bindings to their LCA loops + * \return The pass. + */ +TVM_DLL Pass LiftThreadBinding(); + +/*! + * \brief Compact the buffer access region by removing the buffer regions that are not accessed, + * i.e. narrowing the buffer shape and adjust the access region if necessary. + * + * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. + * + * \code + * + * for i in range(0, 16): + * with T.sblock(): + * B = T.alloc_buffer(16, 16) + * for j in range(0, 16): + * B[i, j] = A[i, j] + 1 + * for j in range(0, 16): + * C[i, j] = B[i, j] + 1 + * + * \endcode + * + * This pass narrows the buffer shape and adjust its accessed region accordingly. + * In this particular case, because only a `1 * 16` vector of `B` is accessed, + * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`. + * + * \code + * + * for i in range(0, 16): + * with T.sblock(): + * B = T.alloc_buffer(1, 16) + * for j in range(0, 16): + * B[0, j] = A[i, j] + 1 + * for j in range(0, 16): + * C[i, j] = B[0, j] + 1 + * + * \endcode + * + * \param is_strict ensure the compacted shape always smaller than the original shape. + * otherwise it allows to grow the shape to match actual accessed buffer regions. + * \return The pass. + */ +TVM_DLL Pass CompactBufferAllocation(bool is_strict = true); + +/*! + * \brief Remove match buffers inside the block. Also, it will validate the binding. + * \return The pass. + */ +TVM_DLL Pass LowerMatchBuffer(); + +/*! + * \brief Inject permuted layout for shared memory. + * \return The pass. + */ +TVM_DLL Pass InjectPermutedLayout(); + +/*! + * \brief Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation. + * \return The pass. + */ +TVM_DLL Pass TransformMmaBufferLayout(); + +/*! + * \brief Remove the block to ensure that the TIR can not be scheduled again. + * \return The pass. + */ +TVM_DLL Pass LowerOpaqueBlock(); + +/*! + * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and + * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., + * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the + * unification, we use a consolidated IterVar and a variable for them. + * \return The pass. + * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread` + * are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z` + * instead. + */ +TVM_DLL Pass UnifyThreadBinding(); + +/*! + * \brief This pass transforms annotated loops into pipelined ones where producers and consumers + * are overlapped with the information provided in loop annotations, which enables optimization + * techniques like prefetching and pipeline parallelism. + * + * The pipeline scope consists of the direct children of the annotated loop (ignoring SBlockRealize, + * SBlock, SeqStmt), and the number of children is denoted by `n` in the documentation. + * + * The following annotations are used to guide the loop transformation: + * + * 1) Loop annotation `software_pipeline_stage` defines the pipeline stage. + * An array of `n` integers, and each element should be in range [0, max_stage], + * where max_stage is the maximum (inclusive) stage. + * 2) Loop annotation `software_pipeline_order` defines the pipeline order. + * An array of `n` integers, a permutation of [0, 1, ..., num_components - 1]; + * 3) SBlock annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of + * read/write dependency. It's an integer index of the write regions of the block. + * + * Every annotated loop is transformed into a loop with three blocks as its direct children: + * + * 1) Prologue block, where components whose stage is less than `max_stage` is executed; + * + * 2) Body block, where all the components are executed; + * + * 3) Epilogue block, where only components whose stage is greater than 0 will be executed. + * The execution order is controlled by the annotation `software_pipeline_order`, + * and thus could be different than the original order. + * + * Note: For nested software pipelines, the inner software pipeline will be generated first, + * which may affect the number of the direct children of the outer loop. + * In this case, the annotations for the outer software + * pipeline should include the result of the inner software pipeline, + * which is the three blocks as discussed above. + * + * \return The IR transform pass. + */ +TVM_DLL Pass InjectSoftwarePipeline(); + +/*! + * \brief Automatically do memory optimizations for auto copy blocks + * \return The pass. + */ +TVM_DLL Pass LowerAutoCopy(); + +/*! + * \brief Add the explicit local stage for the shared memory access on GPU. + * \return The pass. + */ +TVM_DLL Pass ManifestSharedMemoryLocalStage(); + +/*! \brief Annotate irregular loop mark. */ +TVM_DLL Pass AnnotateIrregularLoop(); + +/*! + * \brief partition loops in the stmt. + * + * \return The pass. + */ +TVM_DLL Pass LoopPartition(); + +/*! + * \brief Inject virtual thread loops. + * + * \return The pass. + */ +TVM_DLL Pass InjectVirtualThread(); + +/*! + * \brief Inject double buffer statements. + * + * \return The pass. + */ +TVM_DLL Pass InjectDoubleBuffer(); + +} // namespace transform +} // namespace s_tir +} // namespace tvm + +#endif // TVM_S_TIR_TRANSFORM_H_ diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d2953f1fb48e..bdf8f99aa3b8 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -59,13 +59,6 @@ TVM_DLL Pass CreatePrimFuncPass(std::function required, bool traceable = false); -/*! - * \brief partition loops in the stmt. - * - * \return The pass. - */ -TVM_DLL Pass LoopPartition(); - /*! * \brief Lower vectorization loops. * @@ -75,20 +68,6 @@ TVM_DLL Pass LoopPartition(); */ TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); -/*! - * \brief Inject virtual thread loops. - * - * \return The pass. - */ -TVM_DLL Pass InjectVirtualThread(); - -/*! - * \brief Inject double buffer statements. - * - * \return The pass. - */ -TVM_DLL Pass InjectDoubleBuffer(); - /*! * \brief Rewrite storage allocation pattern. * Moves the allocation to outer most possible scope. @@ -414,105 +393,6 @@ TVM_DLL Pass HoistIfThenElse(); */ TVM_DLL Pass HoistExpression(); -/*! - * \brief Lower cross-thread reduction from thread - * bindings to intrinsic function calls. - * \return The pass. - */ -TVM_DLL Pass LowerCrossThreadReduction(); - -/*! - * \brief Lower block init stmt into IfThenElse stmts - * \return The pass. - */ -TVM_DLL Pass LowerInitBlock(); - -/*! - * \brief Locate the buffer allocation to the exact position (usually is - * the lca of buffer access). This pass will inject opaque block - * with alloc_buffers at the allocation site. - * \return The pass. - */ -TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); - -/*! - * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the - * corresponding iter_values in BlockRealize, for opaque blocks by removing all - *. the iter_values in BlockRealize and iter_vars in Block. - * \return The pass. - */ -TVM_DLL Pass ConvertBlocksToOpaque(); - -/*! - * \brief Lift the same thread bindings to their LCA loops - * \return The pass. - */ -TVM_DLL Pass LiftThreadBinding(); - -/*! - * \brief Compact the buffer access region by removing the buffer regions that are not accessed, - * i.e. narrowing the buffer shape and adjust the access region if necessary. - * - * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. - * - * \code - * - * for i in range(0, 16): - * with T.sblock(): - * B = T.alloc_buffer(16, 16) - * for j in range(0, 16): - * B[i, j] = A[i, j] + 1 - * for j in range(0, 16): - * C[i, j] = B[i, j] + 1 - * - * \endcode - * - * This pass narrows the buffer shape and adjust its accessed region accordingly. - * In this particular case, because only a `1 * 16` vector of `B` is accessed, - * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`. - * - * \code - * - * for i in range(0, 16): - * with T.sblock(): - * B = T.alloc_buffer(1, 16) - * for j in range(0, 16): - * B[0, j] = A[i, j] + 1 - * for j in range(0, 16): - * C[i, j] = B[0, j] + 1 - * - * \endcode - * - * \param is_strict ensure the compacted shape always smaller than the original shape. - * otherwise it allows to grow the shape to match actual accessed buffer regions. - * \return The pass. - */ -TVM_DLL Pass CompactBufferAllocation(bool is_strict = true); - -/*! - * \brief Remove match buffers inside the block. Also, it will validate the binding. - * \return The pass. - */ -TVM_DLL Pass LowerMatchBuffer(); - -/*! - * \brief Inject permuted layout for shared memory. - * \return The pass. - */ -TVM_DLL Pass InjectPermutedLayout(); - -/*! - * \brief Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation. - * \return The pass. - */ -TVM_DLL Pass TransformMmaBufferLayout(); - -/*! - * \brief Remove the block to ensure that the TIR can not be scheduled again. - * \return The pass. - */ -TVM_DLL Pass LowerOpaqueBlock(); - /*! * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional * BufferLoad/BufferStore for the TIR not contains opaque block. @@ -541,18 +421,6 @@ TVM_DLL Pass LowerAsyncDMA(); */ TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false); -/*! - * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and - * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., - * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the - * unification, we use a consolidated IterVar and a variable for them. - * \return The pass. - * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread` - * are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z` - * instead. - */ -TVM_DLL Pass UnifyThreadBinding(); - /*! * A pass to merge multiple TIR-level shared memory allocations into one */ @@ -575,107 +443,6 @@ TVM_DLL Pass ConvertForLoopsToSerial(); */ TVM_DLL Pass UnifiedStaticMemoryPlanner(); -/*! - * \brief This pass transforms annotated loops into pipelined ones where producers and consumers - * are overlapped with the information provided in loop annotations, which enables optimization - * techniques like prefetching and pipeline parallelism. - * - * The pipeline scope consists of the direct children of the annotated loop (ignoring SBlockRealize, - * SBlock, SeqStmt), and the number of children is denoted by `n` in the documentation. - * - * The following annotations are used to guide the loop transformation: - * - * 1) Loop annotation `software_pipeline_stage` defines the pipeline stage. - * An array of `n` integers, and each element should be in range [0, max_stage], - * where max_stage is the maximum (inclusive) stage. - * 2) Loop annotation `software_pipeline_order` defines the pipeline order. - * An array of `n` integers, a permutation of [0, 1, ..., num_components - 1]; - * 3) SBlock annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of - * read/write dependency. It's an integer index of the write regions of the block. - * - * Every annotated loop is transformed into a loop with three blocks as its direct children: - * - * 1) Prologue block, where components whose stage is less than `max_stage` is executed; - * - * 2) Body block, where all the components are executed; - * - * 3) Epilogue block, where only components whose stage is greater than 0 will be executed. - * The execution order is controlled by the annotation `software_pipeline_order`, - * and thus could be different than the original order. - * - * Note: For nested software pipelines, the inner software pipeline will be generated first, - * which may affect the number of the direct children of the outer loop. - * In this case, the annotations for the outer software - * pipeline should include the result of the inner software pipeline, - * which is the three blocks as discussed above. - * Example: - * - * Before this pass, the TIR is: - * - * \code{.py} - * @T.prim_func - * def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: - * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - * for i in T.serial(0, 16, - * annotations={"software_pipeline_stage": [0, 1], - * "software_pipeline_order": [0, 1]} - * ): - * with T.sblock(): - * T.reads(A[tx, i]) - * T.writes(C[tx, i]) - * B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - * with T.sblock("B"): - * T.reads(A[tx, i]) - * T.writes(B[tx, 0]) - * B[tx, 0] = A[tx, i] * T.float32(2) - * with T.sblock("C"): - * T.reads(B[tx, 0]) - * T.writes(C[tx, i]) - * C[tx, i] = B[tx, 0] + T.float32(1) - * \endcode - * - * The TIR above annotates the loop as a two-stage pipeline with no reordering. - * After applying this pass, the TIR is transformed into: - * - * \code{.py} - * @T.prim_func - * def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: - * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - * with T.sblock(): - * T.reads([A[tx, 0:16]]) - * T.writes([C[tx, 0:16]]) - * B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - * with T.sblock("prologue"): - * T.reads([A[tx, 0]]) - * T.writes([B[0, tx, 0]]) - * B[0, tx, 0] = A[tx, 0] * T.float32(2) - * with T.sblock("body"): - * T.reads([A[tx, 1:16], B[0:2, tx, 0]]) - * T.writes([B[0:2, tx, 0], C[tx, 0:15]]) - * for i in T.serial(0, 15): - * with T.sblock("B"): - * T.reads([A[tx, i + 1]]) - * T.writes([B[(i + 1) % 2, tx, 0]]) - * B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) - * with T.sblock("C"): - * T.reads([B[i % 2, tx, 0]]) - * T.writes([C[tx, i]]) - * C[tx, i] = B[i % 2, tx, 0] + T.float32(1) - * with T.sblock("epilogue"): - * T.reads([B[1, tx, 0]]) - * T.writes([C[tx, 15]]) - * C[tx, 15] = B[1, tx, 0] + T.float32(1) - * \endcode - * - * The original loop has two blocks, B and C, as its direct children. The loop annotations indicate - * that block B has stage == 0, order == 0, block C has stage == 1, order == 1. Therefore, block B - * should be executed in advance of block C by one iteration. The order 0 and 1 specifies the order - * of block B and C inside the body block inside the result TIR. - * - * \return The IR transform pass. - */ -TVM_DLL Pass InjectSoftwarePipeline(); - TVM_DLL Pass BindParams(const ffi::Array& constants); /*! @@ -685,12 +452,6 @@ TVM_DLL Pass BindParams(const ffi::Array& constants); */ TVM_DLL Pass ExtractPrimFuncConstants(); -/*! - * \brief Automatically do memory optimizations for auto copy blocks - * \return The pass. - */ -TVM_DLL Pass LowerAutoCopy(); - /*! * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) * \return The pass. @@ -741,12 +502,6 @@ TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true); */ TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite = false); -/*! - * \brief Add the explicit local stage for the shared memory access on GPU. - * \return The pass. - */ -TVM_DLL Pass ManifestSharedMemoryLocalStage(); - /*! * \brief Insert intrinsic calls to instrument function and loop level profiling. * \return The pass. diff --git a/python/tvm/s_tir/__init__.py b/python/tvm/s_tir/__init__.py index 287575c85b70..72246bee88b9 100644 --- a/python/tvm/s_tir/__init__.py +++ b/python/tvm/s_tir/__init__.py @@ -21,6 +21,7 @@ from . import backend from . import pipeline +from . import transform from . import schedule from .schedule import StmtSRef, SBlockScope, ScheduleState, Schedule, ScheduleError, Trace from .block_dependence_info import SBlockDependenceInfo diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index fe81261da7e5..e895025bc447 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -19,7 +19,7 @@ """The TIR backend compilation pipeline for Adreno""" import tvm -from tvm import tir +from tvm import tir, s_tir from tvm.tir import pipeline as tir_pipeline @@ -33,31 +33,31 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I config = pass_ctx.config passes = [ tir.backend.adreno.transform.TextureFlatten(), - tir.transform.CanonicalizeLoop(), - tir.transform.LowerCrossThreadReduction(), - tir.transform.LowerInitBlock(), - tir.transform.PlanAndUpdateBufferAllocationLocation(), - tir.transform.ConvertBlocksToOpaque(), - tir.transform.LiftThreadBinding(), - tir.transform.ManifestSharedMemoryLocalStage(), - tir.transform.CompactBufferAllocation(), - tir.transform.LowerAutoCopy(), - tir.transform.UnifyThreadBinding(), - tir.transform.LowerMatchBuffer(), + s_tir.transform.CanonicalizeLoop(), + s_tir.transform.LowerCrossThreadReduction(), + s_tir.transform.LowerInitBlock(), + s_tir.transform.PlanAndUpdateBufferAllocationLocation(), + s_tir.transform.ConvertBlocksToOpaque(), + s_tir.transform.LiftThreadBinding(), + s_tir.transform.ManifestSharedMemoryLocalStage(), + s_tir.transform.CompactBufferAllocation(), + s_tir.transform.LowerAutoCopy(), + s_tir.transform.UnifyThreadBinding(), + s_tir.transform.LowerMatchBuffer(), tir.transform.Simplify(), - tir.transform.InjectPermutedLayout(), - tir.transform.AnnotateIrregularLoop(), - tir.transform.InjectSoftwarePipeline(), - tir.transform.TransformMmaBufferLayout(), - tir.transform.LowerOpaqueBlock(), + s_tir.transform.InjectPermutedLayout(), + s_tir.transform.AnnotateIrregularLoop(), + s_tir.transform.InjectSoftwarePipeline(), + s_tir.transform.TransformMmaBufferLayout(), + s_tir.transform.LowerOpaqueBlock(), tir.backend.adreno.transform.InjectTextureAlloc(), tir.transform.FlattenBuffer(), tir.transform.BF16ComputeLegalize(), tir.transform.NarrowDataType(32), - tir.transform.LoopPartition(), + s_tir.transform.LoopPartition(), tir.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), - tir.transform.InjectVirtualThread(), - tir.transform.InjectDoubleBuffer(), + s_tir.transform.InjectVirtualThread(), + s_tir.transform.InjectDoubleBuffer(), ] if not bool(config.get("tir.disable_storage_rewrite", False)): passes.append(tir.transform.StorageRewrite()) diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index 01bb0e3cc7d9..59cec5b5827f 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -19,7 +19,7 @@ """The S-TIR backend compilation pipeline.""" import tvm -from tvm import tir +from tvm import tir, s_tir from tvm.tir import pipeline as tir_pipeline @@ -32,30 +32,30 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I pass_ctx = tvm.transform.PassContext.current() config = pass_ctx.config passes = [ - tir.transform.CanonicalizeLoop(), - tir.transform.LowerCrossThreadReduction(), - tir.transform.LowerInitBlock(), - tir.transform.PlanAndUpdateBufferAllocationLocation(), - tir.transform.ConvertBlocksToOpaque(), - tir.transform.LiftThreadBinding(), - tir.transform.ManifestSharedMemoryLocalStage(), - tir.transform.CompactBufferAllocation(), - tir.transform.LowerAutoCopy(), - tir.transform.UnifyThreadBinding(), - tir.transform.LowerMatchBuffer(), + s_tir.transform.CanonicalizeLoop(), + s_tir.transform.LowerCrossThreadReduction(), + s_tir.transform.LowerInitBlock(), + s_tir.transform.PlanAndUpdateBufferAllocationLocation(), + s_tir.transform.ConvertBlocksToOpaque(), + s_tir.transform.LiftThreadBinding(), + s_tir.transform.ManifestSharedMemoryLocalStage(), + s_tir.transform.CompactBufferAllocation(), + s_tir.transform.LowerAutoCopy(), + s_tir.transform.UnifyThreadBinding(), + s_tir.transform.LowerMatchBuffer(), tir.transform.Simplify(), - tir.transform.InjectPermutedLayout(), - tir.transform.AnnotateIrregularLoop(), - tir.transform.InjectSoftwarePipeline(), - tir.transform.TransformMmaBufferLayout(), - tir.transform.LowerOpaqueBlock(), + s_tir.transform.InjectPermutedLayout(), + s_tir.transform.AnnotateIrregularLoop(), + s_tir.transform.InjectSoftwarePipeline(), + s_tir.transform.TransformMmaBufferLayout(), + s_tir.transform.LowerOpaqueBlock(), tir.transform.FlattenBuffer(), tir.transform.BF16ComputeLegalize(), tir.transform.NarrowDataType(32), - tir.transform.LoopPartition(), + s_tir.transform.LoopPartition(), tir.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), - tir.transform.InjectVirtualThread(), - tir.transform.InjectDoubleBuffer(), + s_tir.transform.InjectVirtualThread(), + s_tir.transform.InjectDoubleBuffer(), ] if not bool(config.get("tir.disable_storage_rewrite", False)): passes.append(tir.transform.StorageRewrite()) diff --git a/python/tvm/s_tir/transform/__init__.py b/python/tvm/s_tir/transform/__init__.py new file mode 100644 index 000000000000..4529684dc2d8 --- /dev/null +++ b/python/tvm/s_tir/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Namespace of all S-TIR transformations""" +# pylint: disable=wildcard-import, invalid-name + +from .transform import * diff --git a/python/tvm/s_tir/transform/_ffi_api.py b/python/tvm/s_tir/transform/_ffi_api.py new file mode 100644 index 000000000000..f884d3d370c5 --- /dev/null +++ b/python/tvm/s_tir/transform/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.transform""" +import tvm_ffi + + +tvm_ffi.init_ffi_api("s_tir.transform", __name__) diff --git a/python/tvm/s_tir/transform/transform.py b/python/tvm/s_tir/transform/transform.py new file mode 100644 index 000000000000..d4dbb8ee86c7 --- /dev/null +++ b/python/tvm/s_tir/transform/transform.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""S-TIR specific transformations.""" +# pylint: disable=invalid-name, unsupported-binary-operation + +from . import _ffi_api +from ... import ir as _ir +from ... import ffi as _ffi + + +def CanonicalizeLoop(): + """Canonicalize the loop to start from zero and use trivial step + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.CanonicalizeLoop() # type: ignore + + +def LowerCrossThreadReduction(): + """Lower cross-thread reduction from thread bindings to + intrinsic function calls. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerCrossThreadReduction() # type: ignore + + +def LowerInitBlock(): + """Lower block init stmt into IfThenElse statements. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerInitBlock() # type: ignore + + +def PlanAndUpdateBufferAllocationLocation(): + """Locate the buffer allocation to the exact position (usually is + the lca of buffer access). This pass will inject opaque block + with alloc_buffers at the allocation site. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore + + +def ConvertBlocksToOpaque(): + """Substitute all the block vars with the PrimExprs they are bound to, indicated by + the corresponding iter_values in BlockRealize, and then convert the blocks into + opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ConvertBlocksToOpaque() # type: ignore + + +def LiftThreadBinding(): + """Lift the same thread bindings to their LCA loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LiftThreadBinding() # type: ignore + + +def CompactBufferAllocation(is_strict: bool = True): + """Compact the buffer access region by removing the buffer regions + that are not accessed, i.e. narrowing the buffer shape and adjust + the access region if necessary. + + Parameters + ---------- + is_strict : bool + Ensure the compacted shape to be always smaller than the original shape. + Otherwise it allows to grow the shape to match actual accessed buffer regions. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.CompactBufferAllocation(is_strict) # type: ignore + + +def LowerMatchBuffer(): + """Remove match buffers inside the block. Also, it will validate the binding. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerMatchBuffer() # type: ignore + + +def LowerOpaqueBlock(): + """Remove the block to ensure that the TIR can not be scheduled again. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerOpaqueBlock() # type: ignore + + +def TransformMmaBufferLayout(): + """Transform mma buffer layout + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.TransformMmaBufferLayout() # type: ignore + + +def InjectPermutedLayout(): + """Inject permuted layout in mma + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectPermutedLayout() # type: ignore + + +def UnifyThreadBinding(): + """Unify all the thread bindings for "blockIdx.x/y/z", + "threadIdx.x/y/z", and "vthread.x/y/z". + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.UnifyThreadBinding() # type: ignore + + +def InjectSoftwarePipeline(): + """Transform annotated loops into pipelined one that parallelize producers and consumers + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectSoftwarePipeline() # type: ignore + + +def LowerAutoCopy(): + """Automatically do memory optimizations for auto copy blocks + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerAutoCopy() # type: ignore + + +def ManifestSharedMemoryLocalStage(): + """Add the explicit local stage for the shared memory access on GPU. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ManifestSharedMemoryLocalStage() # type: ignore + + +def AnnotateIrregularLoop(): + """Annotate irregular loop mark. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateIrregularLoop() # type: ignore + + +@_ffi.register_object("s_tir.transform.LoopPartitionConfig") +class LoopPartitionConfig(_ir.Attrs): + """Config for loop partition pass""" + + +def LoopPartition(): + """Partition loops in the stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LoopPartition() # type: ignore + + +def InjectVirtualThread(): + """Inject virtual thread loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectVirtualThread() # type: ignore + + +@_ffi.register_object("s_tir.transform.InjectDoubleBufferConfig") +class InjectDoubleBufferConfig(_ir.Attrs): + """Config for inject double buffer pass""" + + +def InjectDoubleBuffer(): + """Inject double buffer statements. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectDoubleBuffer() # type: ignore diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 86d79dc6badd..bc33ab97e63f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -50,22 +50,6 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore -@_ffi.register_object("tir.transform.LoopPartitionConfig") -class LoopPartitionConfig(_ir.Attrs): - """Config for loop partition pass""" - - -def LoopPartition(): - """Inject virtual thread loops. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LoopPartition() # type: ignore - - def VectorizeLoop(enable_vectorize: bool = True): """Lower vectorization loops. @@ -83,33 +67,6 @@ def VectorizeLoop(enable_vectorize: bool = True): return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore -def InjectVirtualThread(): - """Inject virtual thread loops. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.InjectVirtualThread() # type: ignore - - -@_ffi.register_object("tir.transform.InjectDoubleBufferConfig") -class InjectDoubleBufferConfig(_ir.Attrs): - """Config for inject double buffer pass""" - - -def InjectDoubleBuffer(): - """Inject double buffer statements. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.InjectDoubleBuffer() # type: ignore - - def InjectRollingBuffer(): """Inject rolling buffer statements. @@ -430,19 +387,6 @@ def AnnotateDeviceRegions(): return _ffi_api.AnnotateDeviceRegions() # type: ignore -def AnnotateIrregularLoop(): - """Annotate irregular loop mark. Loop transformations like - peeling, partition, unroll, etc is not allowed on irregular - loop with internal loop continuation and breaks. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.AnnotateIrregularLoop() # type: ignore - - def SplitHostDevice(): """Split the function into a host function and device functions. @@ -760,139 +704,6 @@ def HoistExpression(): return _ffi_api.HoistExpression() # type: ignore -def LowerCrossThreadReduction(): - """Lower cross-thread reduction from thread bindings to - intrinsic function calls. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerCrossThreadReduction() # type: ignore - - -def LowerInitBlock(): - """Lower block init stmt into IfThenElse statements. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerInitBlock() # type: ignore - - -def PlanAndUpdateBufferAllocationLocation(): - """Locate the buffer allocation to the exact position (usually is - the lca of buffer access). This pass will inject opaque block - with alloc_buffers at the allocation site. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore - - -def ConvertBlocksToOpaque(): - """Substitute all the block vars with the PrimExprs they are bound to, indicated by - the corresponding iter_values in BlockRealize, and then convert the blocks into - opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.ConvertBlocksToOpaque() # type: ignore - - -def LiftThreadBinding(): - """Lift the same thread bindings to their LCA loops. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LiftThreadBinding() # type: ignore - - -def CompactBufferAllocation(is_strict: bool = True): - """Compact the buffer access region. by removing the buffer regions - that are not accessed, i.e. narrowing the buffer shape and adjust - the access region if necessary. - - Example - ------- - - Before narrowing, ``B`` is a ``[16, 16]`` buffer, but only a - skinny vector ``B[i, 0:16]`` is accessed. - - .. code-block:: python - - for i in range(0, 16): - with T.sblock(): - B = T.alloc_buffer(16, 16) - for j in range(0, 16): - B[i, j] = A[i, j] + 1 - for j in range(0, 16): - C[i, j] = B[i, j] + 1 - - This pass narrows the buffer shape and adjust its accessed region - accordingly. In this particular case, because only a ``1 * 16`` - vector of ``B`` is accessed, the pass narrows ``B`` to shape ``[1, - 16]``, and changes the access to ``B[i, j]`` to ``B[0, j]``. - - .. code-block:: python - - for i in range(0, 16): - with T.sblock(): - B = T.alloc_buffer(1, 16) - for j in range(0, 16): - B[0, j] = A[i, j] + 1 - for j in range(0, 16): - C[i, j] = B[0, j] + 1 - - Parameters - ---------- - is_strict : bool - Ensure the compacted shape to be always smaller than the original shape. - Otherwise it allows to grow the shape to match actual accessed buffer regions. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - - """ - return _ffi_api.CompactBufferAllocation(is_strict) # type: ignore - - -def LowerMatchBuffer(): - """Remove match buffers inside the block. Also, it will validate the binding. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerMatchBuffer() # type: ignore - - -def LowerOpaqueBlock(): - """Remove the block to ensure that the TIR can not be scheduled again. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerOpaqueBlock() # type: ignore - - def FlattenBuffer(): """Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore for the TIR not contains opaque block. @@ -905,50 +716,6 @@ def FlattenBuffer(): return _ffi_api.FlattenBuffer() # type: ignore -def TransformMmaBufferLayout(): - """Transform mma buffer layout - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.TransformMmaBufferLayout() # type: ignore - - -def InjectPermutedLayout(): - """Inject permuted layout in mma - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.InjectPermutedLayout() # type: ignore - - -def UnifyThreadBinding(): - """Unify all the thread bindings for "blockIdx.x/y/z", - "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, - two vars that are bound to a thread axis (e.g., "threadIdx.x") - use different IterVars and variables in their AttrStmts. After - the unification, we use a consolidated IterVar and a variable - for them. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - - Note - ---- - `vthread` is a legacy behavior that will be deprecated, though - thread bindings of `vthread` are still also unified in this - pass. Please use `vthread.x`, `vthread.y` and `vthread.z` instead. - """ - return _ffi_api.UnifyThreadBinding() # type: ignore - - def MergeSharedMemoryAllocations(): """This pass merges multiple TIR-level shared memory allocations into one allocation. @@ -972,17 +739,6 @@ def ConvertForLoopsToSerial(): return _ffi_api.ConvertForLoopsToSerial() # type: ignore -def InjectSoftwarePipeline(): - """Transform annotated loops into pipelined one that parallelize producers and consumers - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.InjectSoftwarePipeline() # type: ignore - - def ExtractPrimFuncConstants(): """Collects and unificates tir non-scalar constants to module's attr 'Constants' array. @@ -994,17 +750,6 @@ def ExtractPrimFuncConstants(): return _ffi_api.ExtractPrimFuncConstants() # type: ignore -def LowerAutoCopy(): - """Automatically do memory optimizations for auto copy blocks - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerAutoCopy() # type: ignore - - def RenormalizeSplitPattern(): """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) @@ -1087,17 +832,6 @@ def RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite=False): return _ffi_api.RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite) # type: ignore -def ManifestSharedMemoryLocalStage(): - """Add the explicit local stage for the shared memory access on GPU. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.ManifestSharedMemoryLocalStage() # type: ignore - - def InstrumentProfileIntrinsics(): """Insert intrinsic calls to instrument function and loop level profiling. @@ -1171,14 +905,3 @@ def LowerVtcmAlloc(): The result pass """ return _ffi_api.LowerVtcmAlloc() # type: ignore - - -def CanonicalizeLoop(): - """Canonicalize the loop to start from zero and use trivial step - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.CanonicalizeLoop() # type: ignore diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index aed8e21ffa42..9df517588215 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -311,15 +312,15 @@ Sequential PassListForPerStoreFeature() { return Sequential({ tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_tensor_rewrite*/ true), tir::transform::SimplifyForFeatureExtraction(), - tir::transform::LowerCrossThreadReduction(), - tir::transform::LowerInitBlock(), - tir::transform::PlanAndUpdateBufferAllocationLocation(), - tir::transform::ConvertBlocksToOpaque(), - tir::transform::CompactBufferAllocation(), + s_tir::transform::LowerCrossThreadReduction(), + s_tir::transform::LowerInitBlock(), + s_tir::transform::PlanAndUpdateBufferAllocationLocation(), + s_tir::transform::ConvertBlocksToOpaque(), + s_tir::transform::CompactBufferAllocation(), tir::transform::Simplify(), - tir::transform::LowerAutoCopy(), - tir::transform::UnifyThreadBinding(), - tir::transform::LowerMatchBuffer(), + s_tir::transform::LowerAutoCopy(), + s_tir::transform::UnifyThreadBinding(), + s_tir::transform::LowerMatchBuffer(), tir::transform::Simplify(), }); } diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 9f59404de5ef..956e5ddcb5a6 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include "../utils.h" @@ -137,19 +138,19 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { try { auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::BindTarget(this->target)); - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); + pass_list.push_back(s_tir::transform::LowerInitBlock()); + pass_list.push_back(s_tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(s_tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(s_tir::transform::CompactBufferAllocation()); + pass_list.push_back(s_tir::transform::LowerMatchBuffer()); + pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::InjectVirtualThread()); - pass_list.push_back(tir::transform::InjectDoubleBuffer()); + pass_list.push_back(s_tir::transform::InjectVirtualThread()); + pass_list.push_back(s_tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::VectorizeLoop(true)); pass_list.push_back(tir::transform::StorageRewrite()); tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index bdcb1af1fe41..f1ff28b071ff 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include "../utils.h" @@ -154,27 +155,27 @@ class VerifyGPUCodeNode : public PostprocNode { try { auto pass_list = ffi::Array(); // Phase 1 - pass_list.push_back(tir::transform::LowerCrossThreadReduction()); - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::LiftThreadBinding()); - pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(s_tir::transform::LowerCrossThreadReduction()); + pass_list.push_back(s_tir::transform::LowerInitBlock()); + pass_list.push_back(s_tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(s_tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(s_tir::transform::LiftThreadBinding()); + pass_list.push_back(s_tir::transform::ManifestSharedMemoryLocalStage()); + pass_list.push_back(s_tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::LowerAutoCopy()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); + pass_list.push_back(s_tir::transform::LowerAutoCopy()); + pass_list.push_back(s_tir::transform::UnifyThreadBinding()); + pass_list.push_back(s_tir::transform::LowerMatchBuffer()); + pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); // Phase 2 pass_list.push_back(tir::transform::VectorizeLoop(true)); - pass_list.push_back(tir::transform::InjectVirtualThread()); - pass_list.push_back(tir::transform::InjectDoubleBuffer()); + pass_list.push_back(s_tir::transform::InjectVirtualThread()); + pass_list.push_back(s_tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); pass_list.push_back(tir::transform::LowerIntrin()); diff --git a/src/tir/transforms/annotate_irregular_loop.cc b/src/s_tir/transform/annotate_irregular_loop.cc similarity index 82% rename from src/tir/transforms/annotate_irregular_loop.cc rename to src/s_tir/transform/annotate_irregular_loop.cc index c715922d60b3..76c41a25b612 100644 --- a/src/tir/transforms/annotate_irregular_loop.cc +++ b/src/s_tir/transform/annotate_irregular_loop.cc @@ -20,13 +20,14 @@ #include #include #include +#include #include #include #include -#include namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class IrregularLoopAnnotator : public StmtMutator { public: @@ -42,12 +43,13 @@ class IrregularLoopAnnotator : public StmtMutator { if (has_jump_) { CHECK(op->kind == ForKind::kSerial) << "Loop kind " << op->kind << " is invalid for irregular loop " << op->loop_var; - for (const char* key : {attr::pragma_auto_unroll_max_step, attr::pragma_unroll_explicit, - attr::pragma_loop_partition_hint, attr::software_pipeline_stage}) { + for (const char* key : + {tir::attr::pragma_auto_unroll_max_step, tir::attr::pragma_unroll_explicit, + tir::attr::pragma_loop_partition_hint, tir::attr::software_pipeline_stage}) { CHECK(!res->annotations.count(key)) << "Annotation `" << key << "` is invalid for irregular loop " << op->loop_var; } - res.CopyOnWrite()->annotations.Set(attr::irregular_loop_mark, 1); + res.CopyOnWrite()->annotations.Set(tir::attr::irregular_loop_mark, 1); } std::swap(cur_has_jump, has_jump_); return res; @@ -81,14 +83,14 @@ Pass AnnotateIrregularLoop() { return func; }; - return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateIrregularLoop", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.AnnotateIrregularLoop", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.AnnotateIrregularLoop", AnnotateIrregularLoop); + refl::GlobalDef().def("s_tir.transform.AnnotateIrregularLoop", AnnotateIrregularLoop); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/canonicalize_loop.cc b/src/s_tir/transform/canonicalize_loop.cc similarity index 80% rename from src/tir/transforms/canonicalize_loop.cc rename to src/s_tir/transform/canonicalize_loop.cc index 93511bf84bb2..2a0b09a7e151 100644 --- a/src/tir/transforms/canonicalize_loop.cc +++ b/src/s_tir/transform/canonicalize_loop.cc @@ -24,15 +24,17 @@ #include #include #include +#include #include #include #include -#include #include namespace tvm { -namespace tir { +namespace s_tir { + +using namespace tvm::tir; class LoopCanonicalizer : public StmtExprMutator { public: @@ -43,12 +45,11 @@ class LoopCanonicalizer : public StmtExprMutator { if (is_zero(op->min) && op->HasTrivialStep()) { return StmtExprMutator::VisitStmt_(op); } - arith::Analyzer analyzer; const auto* loop_var = op->loop_var.get(); PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1)); // report warning for negative step, since it would be a forever loop - if (!analyzer.CanProveGreaterEqual(step, 1)) { + if (!analyzer_.CanProveGreaterEqual(step, 1)) { // TODO(tvm): prove dynamic shaped step LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step; } @@ -57,7 +58,7 @@ class LoopCanonicalizer : public StmtExprMutator { auto n = CopyOnWrite(op); n->body = VisitStmt(op->body); n->min = make_zero(loop_var->dtype); - n->extent = analyzer.Simplify(ceildiv(op->extent, step)); + n->extent = analyzer_.Simplify(ceildiv(op->extent, step)); n->step = std::nullopt; new_iter_info_.erase(loop_var); return For(n); @@ -72,31 +73,29 @@ class LoopCanonicalizer : public StmtExprMutator { return ffi::GetRef(op); } + private: + arith::Analyzer analyzer_; /*! \brief Map iter variable `x` to `x * stride + offset`. */ std::unordered_map> new_iter_info_; }; -PrimFunc CanonicalizeLoop(PrimFunc func) { - PrimFuncNode* fptr = func.CopyOnWrite(); - fptr->body = LoopCanonicalizer()(func->body); - return func; -} - namespace transform { Pass CanonicalizeLoop() { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return CanonicalizeLoop(std::move(f)); + auto pass_func = [=](PrimFunc func, IRModule m, PassContext ctx) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = LoopCanonicalizer()(std::move(fptr->body)); + return func; }; - return CreatePrimFuncPass(pass_func, 0, "tir.CanonicalizeLoop", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.CanonicalizeLoop", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.CanonicalizeLoop", CanonicalizeLoop); + refl::GlobalDef().def("s_tir.transform.CanonicalizeLoop", CanonicalizeLoop); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc similarity index 96% rename from src/tir/transforms/compact_buffer_region.cc rename to src/s_tir/transform/compact_buffer_region.cc index cc73121b5cdf..8b9f71f4d93f 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -25,9 +25,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -35,11 +35,12 @@ #include "../../support/arena.h" #include "../../support/nd_int_set.h" #include "../../support/utils.h" -#include "../schedule/utils.h" -#include "ir_utils.h" +#include "../../tir/schedule/utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using support::NDIntSet; @@ -259,8 +260,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } }; - record_explicit_region(attr::explicit_read_region, BufferIndexType::kRead); - record_explicit_region(attr::explicit_write_region, BufferIndexType::kWrite); + record_explicit_region(tir::attr::explicit_read_region, BufferIndexType::kRead); + record_explicit_region(tir::attr::explicit_write_region, BufferIndexType::kWrite); // Step 3. Record relax position of ancestor_loops_ for (const Buffer& buffer : op->alloc_buffers) { @@ -318,7 +319,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iter = Downcast(op->node); ancestor_iters_.push_back(iter); Range dom = iter->dom; @@ -740,28 +741,24 @@ Stmt BufferCompactorCompact( return stmt; } -PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) { - PrimFuncNode* fptr = f.CopyOnWrite(); - auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict); - auto storage_align = CollectStorageAlignAnnotation(f->body); - fptr->body = BufferCompactorCompact(f, region, storage_align); - return f; -} - namespace transform { Pass CompactBufferAllocation(bool is_strict) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return CompactBufferAllocation(std::move(f), is_strict); + PrimFuncNode* fptr = f.CopyOnWrite(); + auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict); + auto storage_align = CollectStorageAlignAnnotation(f->body); + fptr->body = BufferCompactorCompact(f, region, storage_align); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.CompactBufferAllocation", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.CompactBufferAllocation", CompactBufferAllocation); + refl::GlobalDef().def("s_tir.transform.CompactBufferAllocation", CompactBufferAllocation); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/s_tir/transform/convert_blocks_to_opaque.cc similarity index 90% rename from src/tir/transforms/convert_blocks_to_opaque.cc rename to src/s_tir/transform/convert_blocks_to_opaque.cc index 546de79085d6..c799cc87cb3c 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/s_tir/transform/convert_blocks_to_opaque.cc @@ -23,13 +23,14 @@ */ #include +#include #include -#include -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Substitute expr via BlockRealize value bindings and convert each block into opaque @@ -108,26 +109,22 @@ class OpaqueBlockConverter : public StmtExprMutator { std::unordered_set forbidden_iter_vars_; }; -PrimFunc ConvertBlocksToOpaque(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = OpaqueBlockConverter::Substitute(f); - return f; -} - namespace transform { Pass ConvertBlocksToOpaque() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return ConvertBlocksToOpaque(std::move(f)); + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockConverter::Substitute(f); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.ConvertBlocksToOpaque", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.ConvertBlocksToOpaque", ConvertBlocksToOpaque); + refl::GlobalDef().def("s_tir.transform.ConvertBlocksToOpaque", ConvertBlocksToOpaque); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc similarity index 93% rename from src/tir/transforms/inject_double_buffer.cc rename to src/s_tir/transform/inject_double_buffer.cc index e874dc0564cf..ff77de10aeee 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/s_tir/transform/inject_double_buffer.cc @@ -23,14 +23,15 @@ */ #include #include +#include #include #include -#include -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter { int split_loop; @@ -41,7 +42,7 @@ struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapterattr_key == attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope) { touched_.insert(op->node.as()); StmtExprVisitor::VisitStmt_(op); } else { @@ -79,7 +80,7 @@ class DoubleBufferDetector : public StmtExprVisitor { class StripDoubleBufferWrite : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::double_buffer_write) { + if (op->attr_key == tir::attr::double_buffer_write) { return VisitStmt(op->body); } else { return StmtMutator::VisitStmt_(op); @@ -102,7 +103,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope) { return MakeProducer(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -278,7 +279,7 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[e.loop->loop_var.get()] = loop_shift; vmap[e.switch_write_var.get()] = indexmod(loop_shift, two); body = Substitute(body, vmap); - body = AttrStmt(buffer, attr::double_buffer_write, 1, body); + body = AttrStmt(buffer, tir::attr::double_buffer_write, 1, body); body = IfThenElse(loop_shift < e.loop->extent, body); return body; } @@ -316,22 +317,22 @@ namespace transform { Pass InjectDoubleBuffer() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto cfg = ctx->GetConfig("tir.InjectDoubleBuffer"); + auto cfg = ctx->GetConfig("s_tir.InjectDoubleBuffer"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } n->body = DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectDoubleBuffer", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.InjectDoubleBuffer", InjectDoubleBuffer); + refl::GlobalDef().def("s_tir.transform.InjectDoubleBuffer", InjectDoubleBuffer); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/s_tir/transform/inject_permuted_layout.cc similarity index 97% rename from src/tir/transforms/inject_permuted_layout.cc rename to src/s_tir/transform/inject_permuted_layout.cc index 5bd3fb29a88f..6483f28119d2 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/s_tir/transform/inject_permuted_layout.cc @@ -23,18 +23,19 @@ */ #include #include +#include #include #include #include -#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/utils.h" -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using namespace arith; using namespace runtime; @@ -294,15 +295,15 @@ Pass InjectPermutedLayout() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return PermutedLayoutInjector::Transform(std::move(f)); }; - return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectPermutedLayout", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.InjectPermutedLayout", InjectPermutedLayout); + refl::GlobalDef().def("s_tir.transform.InjectPermutedLayout", InjectPermutedLayout); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc similarity index 98% rename from src/tir/transforms/inject_software_pipeline.cc rename to src/s_tir/transform/inject_software_pipeline.cc index ab6d0c12d628..fbcece7ff2a7 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -22,18 +22,19 @@ * \brief Transform annotated loops into pipelined one that parallelize producers and consumers */ #include +#include #include #include -#include #include #include "../../support/utils.h" -#include "../schedule/utils.h" -#include "./ir_utils.h" +#include "../../tir/schedule/utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; namespace software_pipeline { @@ -1131,9 +1132,9 @@ class PipelineInjector : private StmtExprMutator { } auto pipeline_stages = - Downcast>(op->annotations.at(attr::software_pipeline_stage)); + Downcast>(op->annotations.at(tir::attr::software_pipeline_stage)); auto pipeline_orders = - Downcast>(op->annotations.at(attr::software_pipeline_order)); + Downcast>(op->annotations.at(tir::attr::software_pipeline_order)); CHECK_EQ(pipeline_stages.size(), original_order.size()) << "PrimFunc " << global_symbol_ << " has original order " << original_order.Map([](const auto& block) { return block->name_hint; }) @@ -1144,7 +1145,7 @@ class PipelineInjector : private StmtExprMutator { << ", but pipeline annotation is " << pipeline_orders << " with different size"; std::unordered_set pipeline_async_stages; - if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) { + if (auto annot = op->annotations.Get(tir::attr::software_pipeline_async_stages)) { for (auto s : Downcast>(annot.value())) { pipeline_async_stages.insert(s->value); } @@ -1153,8 +1154,9 @@ class PipelineInjector : private StmtExprMutator { ffi::Map preserved_annotations; for (const auto& kv : op->annotations) { const ffi::String& key = kv.first; - if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order && - kv.first != attr::software_pipeline_async_stages) { + if (kv.first != tir::attr::software_pipeline_stage && + kv.first != tir::attr::software_pipeline_order && + kv.first != tir::attr::software_pipeline_async_stages) { preserved_annotations.Set(key, kv.second); } } @@ -1206,7 +1208,7 @@ class PipelineInjector : private StmtExprMutator { buffer_data_to_buffer_.Set(buffer->data, buffer); } - auto it = op->annotations.find(attr::double_buffer_scope); + auto it = op->annotations.find(tir::attr::double_buffer_scope); if (it != op->annotations.end()) { int buffer_index = Downcast((*it).second).IntValue(); CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size()) @@ -1223,8 +1225,8 @@ class PipelineInjector : private StmtExprMutator { } bool HasPipelineAnnotation(const ForNode* op) const { - auto it1 = op->annotations.find(attr::software_pipeline_stage); - auto it2 = op->annotations.find(attr::software_pipeline_order); + auto it1 = op->annotations.find(tir::attr::software_pipeline_stage); + auto it2 = op->annotations.find(tir::attr::software_pipeline_order); bool has_stage = it1 != op->annotations.end(); bool has_order = it2 != op->annotations.end(); if (has_stage && has_order) { @@ -1260,15 +1262,15 @@ Pass InjectSoftwarePipeline() { fptr->body = ConvertSSA(std::move(fptr->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectSoftwarePipeline", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.InjectSoftwarePipeline", InjectSoftwarePipeline); + refl::GlobalDef().def("s_tir.transform.InjectSoftwarePipeline", InjectSoftwarePipeline); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc similarity index 97% rename from src/tir/transforms/inject_virtual_thread.cc rename to src/s_tir/transform/inject_virtual_thread.cc index cd7283a7ef4d..5f305907eb23 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -22,18 +22,19 @@ */ #include #include +#include #include #include #include -#include #include #include "../../arith/ir_mutator_with_analyzer.h" -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; // If expression is touched by var. class ExprTouched final : public StmtExprVisitor { @@ -289,7 +290,8 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(ffi::GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && - (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) { + (op->attr_key == tir::attr::coproc_uop_scope || + op->attr_key == tir::attr::coproc_scope)) { return InjectVTLoop(ffi::GetRef(op), true); } else { Stmt body = this->VisitStmt(op->body); @@ -495,7 +497,7 @@ class VirtualThreadInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); - if (op->attr_key == attr::virtual_thread) { + if (op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread"; int nthread = static_cast(op->value.as()->value); @@ -521,15 +523,15 @@ Pass InjectVirtualThread() { n->body = ConvertSSA(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectVirtualThread", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.InjectVirtualThread", InjectVirtualThread); + refl::GlobalDef().def("s_tir.transform.InjectVirtualThread", InjectVirtualThread); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/s_tir/transform/lift_thread_binding.cc similarity index 93% rename from src/tir/transforms/lift_thread_binding.cc rename to src/s_tir/transform/lift_thread_binding.cc index 45bbf4af52de..537737b1288b 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/s_tir/transform/lift_thread_binding.cc @@ -23,14 +23,15 @@ */ #include +#include #include -#include #include "../../runtime/thread_storage_scope.h" -#include "./ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; std::pair>>, ObjectPtrHash, ObjectPtrEqual>, @@ -169,26 +170,22 @@ class ThreadBindingLifter : public StmtExprMutator { ffi::Map var_subst; }; -PrimFunc LiftThreadBinding(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = ThreadBindingLifter()(std::move(fptr->body)); - return f; -} - namespace transform { Pass LiftThreadBinding() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return LiftThreadBinding(std::move(f)); + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingLifter()(std::move(fptr->body)); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LiftThreadBinding", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LiftThreadBinding", LiftThreadBinding); + refl::GlobalDef().def("s_tir.transform.LiftThreadBinding", LiftThreadBinding); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/loop_partition.cc b/src/s_tir/transform/loop_partition.cc similarity index 96% rename from src/tir/transforms/loop_partition.cc rename to src/s_tir/transform/loop_partition.cc index fd9bd2d6531c..e21f650b1ff7 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -24,11 +24,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include @@ -36,10 +36,11 @@ #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; struct LoopPartitionConfigNode : public AttrsNodeReflAdapter { bool partition_const_loop; @@ -59,7 +60,7 @@ struct LoopPartitionConfigNode : public AttrsNodeReflAdapterattr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { const IterVarNode* iv = op->node.as(); ICHECK(iv); Var var = iv->var; @@ -151,7 +152,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.erase(var.get()); return; } - } else if (op->attr_key == attr::pragma_loop_partition_hint) { + } else if (op->attr_key == tir::attr::pragma_loop_partition_hint) { if (analyzer_.CanProve(op->value)) { const VarNode* var = nullptr; if (op->node.as()) { @@ -253,7 +254,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // handle thread_axis - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); ICHECK(thread_axis); const VarNode* var = thread_axis->var.get(); @@ -381,7 +382,7 @@ class ThreadPartitionInserter : public StmtMutator { : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { innermost_thread_scope_ = true; Stmt stmt = StmtMutator::VisitStmt_(op); // add branch code inside the innermost thread scope @@ -436,7 +437,7 @@ class LoopPartitioner : public StmtMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key != attr::thread_extent) { + if (op->attr_key != tir::attr::thread_extent) { return StmtMutator::VisitStmt_(op); } @@ -790,7 +791,7 @@ class RemoveLikelyTagsAndHints : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::pragma_loop_partition_hint) { + if (op->attr_key == tir::attr::pragma_loop_partition_hint) { return VisitStmt(op->body); } return StmtExprMutator::VisitStmt_(op); @@ -811,24 +812,24 @@ namespace transform { Pass LoopPartition() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto cfg = ctx->GetConfig("tir.LoopPartition"); + auto cfg = ctx->GetConfig("s_tir.LoopPartition"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } - n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop, - cfg.value()->no_unroll_loop_with_extent_one, - cfg.value()->unroll_loop_with_partition_hint_no_interval); + n->body = s_tir::LoopPartition(std::move(n->body), cfg.value()->partition_const_loop, + cfg.value()->no_unroll_loop_with_extent_one, + cfg.value()->unroll_loop_with_partition_hint_no_interval); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LoopPartition", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LoopPartition", LoopPartition); + refl::GlobalDef().def("s_tir.transform.LoopPartition", LoopPartition); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc similarity index 98% rename from src/tir/transforms/lower_cross_thread_reduction.cc rename to src/s_tir/transform/lower_cross_thread_reduction.cc index fb9ffb24db6c..7338b0887970 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -22,18 +22,19 @@ */ #include #include +#include #include #include -#include #include "../../runtime/thread_storage_scope.h" #include "../../support/utils.h" -#include "../schedule/analysis.h" -#include "./ir_utils.h" +#include "../../tir/schedule/analysis.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using runtime::ThreadScope; using support::StartsWith; @@ -927,27 +928,23 @@ class CrossThreadReductionTransformer : public StmtMutator { crt_buf2threads_; }; -PrimFunc LowerCrossThreadReduction(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = CrossThreadReductionTransformer()(f->body); - return f; -} - namespace transform { Pass LowerCrossThreadReduction() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return LowerCrossThreadReduction(std::move(f)); + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = CrossThreadReductionTransformer()(fptr->body); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerCrossThreadReduction", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerCrossThreadReduction", LowerCrossThreadReduction); + refl::GlobalDef().def("s_tir.transform.LowerCrossThreadReduction", LowerCrossThreadReduction); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/lower_init_block.cc b/src/s_tir/transform/lower_init_block.cc similarity index 84% rename from src/tir/transforms/lower_init_block.cc rename to src/s_tir/transform/lower_init_block.cc index 3ccaa7cea75f..3e5bbc1652a3 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/s_tir/transform/lower_init_block.cc @@ -22,14 +22,15 @@ * \file lower_reduction.cc */ #include +#include #include #include -#include -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class InitBlockLower : public StmtMutator { private: @@ -65,27 +66,23 @@ class InitBlockLower : public StmtMutator { } }; -PrimFunc LowerInitBlock(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - fptr->body = InitBlockLower()(std::move(fptr->body)); - return func; -} - namespace transform { Pass LowerInitBlock() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - return LowerInitBlock(std::move(f)); + auto fptr = f.CopyOnWrite(); + fptr->body = InitBlockLower()(std::move(fptr->body)); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerInitBlock", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerInitBlock", LowerInitBlock); + refl::GlobalDef().def("s_tir.transform.LowerInitBlock", LowerInitBlock); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/s_tir/transform/lower_match_buffer.cc similarity index 95% rename from src/tir/transforms/lower_match_buffer.cc rename to src/s_tir/transform/lower_match_buffer.cc index b426f60a450e..6939047bd177 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/s_tir/transform/lower_match_buffer.cc @@ -24,16 +24,17 @@ #include #include +#include #include #include #include -#include -#include "../ir/functor_common.h" -#include "ir_utils.h" +#include "../../tir/ir/functor_common.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class MatchBufferLower : public StmtExprMutator { public: explicit MatchBufferLower(const PrimFunc& func) { @@ -260,27 +261,23 @@ class MatchBufferLower : public StmtExprMutator { arith::Analyzer analyzer_; }; -PrimFunc LowerMatchBuffer(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - fptr->body = MatchBufferLower(func)(std::move(fptr->body)); - return func; -} - namespace transform { Pass LowerMatchBuffer() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - return LowerMatchBuffer(std::move(f)); + auto fptr = f.CopyOnWrite(); + fptr->body = MatchBufferLower(f)(std::move(fptr->body)); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerMatchBuffer", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerMatchBuffer", LowerMatchBuffer); + refl::GlobalDef().def("s_tir.transform.LowerMatchBuffer", LowerMatchBuffer); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/s_tir/transform/lower_opaque_block.cc similarity index 91% rename from src/tir/transforms/lower_opaque_block.cc rename to src/s_tir/transform/lower_opaque_block.cc index b5d6f35eb8bc..39c76e9ce5a3 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/s_tir/transform/lower_opaque_block.cc @@ -22,13 +22,14 @@ */ #include +#include #include -#include -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Remove SBlock to ensure that the TIR can not be scheduled again. @@ -67,7 +68,7 @@ class OpaqueBlockLower : public StmtExprMutator { tuple.Set<0>(-1); allocate_aligns.push_back(tuple); } - allocate_annotations.Set(attr::buffer_dim_align, allocate_aligns); + allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns); } body = Allocate(buffer->data, buffer->dtype, allocation_shape, const_true(), std::move(body), @@ -105,7 +106,7 @@ class OpaqueBlockLower : public StmtExprMutator { ffi::String thread_tag = op->thread_binding.value()->thread_tag; body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); } else if (is_one(extent) && op->annotations.empty() && - !op->annotations.count(attr::irregular_loop_mark)) { + !op->annotations.count(tir::attr::irregular_loop_mark)) { // Case 2. Unit loop return body; } else { @@ -143,8 +144,8 @@ class OpaqueBlockLower : public StmtExprMutator { /*thread_tag=*/thread_tag); ffi::String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || thread_tag == "vthread.y" || thread_tag == "vthread.z") - ? attr::virtual_thread - : attr::thread_extent; + ? tir::attr::virtual_thread + : tir::attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), @@ -181,7 +182,7 @@ class OpaqueBlockLower : public StmtExprMutator { pragma_attrs->clear(); for (const auto& kv : annotations) { const ffi::String& key = kv.first; - if (attr::IsPragmaKey(key)) { + if (tir::attr::IsPragmaKey(key)) { pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); } else if (!is_block) { // the loop annotation is preserved @@ -203,26 +204,22 @@ class OpaqueBlockLower : public StmtExprMutator { std::unordered_map storage_align_; }; -PrimFunc LowerOpaqueBlock(PrimFunc f) { - auto fptr = f.CopyOnWrite(); - fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body)); - return f; -} - namespace transform { Pass LowerOpaqueBlock() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return LowerOpaqueBlock(std::move(f)); + auto fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body)); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerOpaqueBlock", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerOpaqueBlock", LowerOpaqueBlock); + refl::GlobalDef().def("s_tir.transform.LowerOpaqueBlock", LowerOpaqueBlock); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/s_tir/transform/manifest_shared_memory_local_stage.cc similarity index 96% rename from src/tir/transforms/manifest_shared_memory_local_stage.cc rename to src/s_tir/transform/manifest_shared_memory_local_stage.cc index 4addb7823bda..4d8ae4471952 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/s_tir/transform/manifest_shared_memory_local_stage.cc @@ -28,19 +28,20 @@ */ #include #include +#include #include #include #include -#include #include #include "../../runtime/thread_storage_scope.h" -#include "../schedule/transform.h" +#include "../../tir/schedule/transform.h" #include "tvm/tir/stmt.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! \brief Rewriter for the block storing to the target buffer. Create an intermediate cache stage * to store the result. Rewrite the original block to load from the intermediate buffer. @@ -186,7 +187,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator { } Stmt VisitStmt_(const SBlockNode* op) final { - if (op->annotations.count(attr::manifest_shared_memory_local_stage)) { + if (op->annotations.count(tir::attr::manifest_shared_memory_local_stage)) { // Rewrite the shared memory access to load from the intermediate buffer. // The annotated block must be a leaf block (will be checked during rewriting). No need to // visit its body recursively. @@ -195,7 +196,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator { auto [target_buffer, new_buffer, new_block, local_stage] = rewriter.Rewrite(op); buffer_remap_.Set(target_buffer, new_buffer); - new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage); + new_block.CopyOnWrite()->annotations.erase(tir::attr::manifest_shared_memory_local_stage); buffer_local_stage_.Set(target_buffer, local_stage); target_buffers_.push_back(target_buffer); @@ -274,15 +275,15 @@ Pass ManifestSharedMemoryLocalStage() { n->body = SharedMemoryLocalStageInserter()(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.ManifestSharedMemoryLocalStage", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.ManifestSharedMemoryLocalStage", + refl::GlobalDef().def("s_tir.transform.ManifestSharedMemoryLocalStage", ManifestSharedMemoryLocalStage); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/s_tir/transform/memhammer_coalesce.cc similarity index 99% rename from src/tir/transforms/memhammer_coalesce.cc rename to src/s_tir/transform/memhammer_coalesce.cc index 0d5b27044232..a575fd3e9626 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/s_tir/transform/memhammer_coalesce.cc @@ -20,7 +20,8 @@ #include "./memhammer_rewrite_rule.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Fuse consecutive loops @@ -232,5 +233,5 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, } return ret; } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/s_tir/transform/memhammer_intermediate_stage.cc similarity index 99% rename from src/tir/transforms/memhammer_intermediate_stage.cc rename to src/s_tir/transform/memhammer_intermediate_stage.cc index d4826e609319..d9116ac6553e 100644 --- a/src/tir/transforms/memhammer_intermediate_stage.cc +++ b/src/s_tir/transform/memhammer_intermediate_stage.cc @@ -19,7 +19,8 @@ #include "memhammer_rewrite_rule.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, Stmt* ith_loop = nullptr) { @@ -447,5 +448,5 @@ Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraint return after_caching; } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc similarity index 98% rename from src/tir/transforms/memhammer_lower_auto_copy.cc rename to src/s_tir/transform/memhammer_lower_auto_copy.cc index f4dc6579cd0b..98bb90156569 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -20,23 +20,24 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include "../../runtime/thread_storage_scope.h" -#include "../schedule/utils.h" -#include "./ir_utils.h" +#include "../../tir/schedule/utils.h" +#include "../../tir/transforms/ir_utils.h" #include "./memhammer_rewrite_rule.h" #include "tvm/tir/stmt.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using support::NDIntSet; @@ -775,14 +776,14 @@ Pass LowerAutoCopy() { n->body = mutator.RewritePaddingBody(n->body); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerAutoCopy", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerAutoCopy", LowerAutoCopy); + refl::GlobalDef().def("s_tir.transform.LowerAutoCopy", LowerAutoCopy); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/s_tir/transform/memhammer_rewrite_rule.h similarity index 97% rename from src/tir/transforms/memhammer_rewrite_rule.h rename to src/s_tir/transform/memhammer_rewrite_rule.h index 5751aa119e36..974db627c4a8 100644 --- a/src/tir/transforms/memhammer_rewrite_rule.h +++ b/src/s_tir/transform/memhammer_rewrite_rule.h @@ -16,23 +16,24 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ -#define TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ +#ifndef TVM_S_TIR_TRANSFORM_MEMHAMMER_REWRITE_RULE_H_ +#define TVM_S_TIR_TRANSFORM_MEMHAMMER_REWRITE_RULE_H_ #include #include +#include #include #include #include #include -#include #include -#include "../schedule/utils.h" +#include "../../tir/schedule/utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! \brief The set containing all possible constraints of a data copy */ struct ConstraintSet { @@ -252,7 +253,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::S ffi::Optional compute_location, const ffi::Array& outer_loops, Buffer* alloc_buffer); -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ +#endif // TVM_S_TIR_TRANSFORM_MEMHAMMER_REWRITE_RULE_H_ diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc similarity index 99% rename from src/tir/transforms/memhammer_tensorcore_rewrite.cc rename to src/s_tir/transform/memhammer_tensorcore_rewrite.cc index 4c03c155db1a..2285b3843618 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc @@ -20,7 +20,8 @@ #include "./memhammer_rewrite_rule.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. @@ -562,5 +563,5 @@ Stmt MmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, return rewriter(body); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/s_tir/transform/plan_update_buffer_allocation_location.cc similarity index 94% rename from src/tir/transforms/plan_update_buffer_allocation_location.cc rename to src/s_tir/transform/plan_update_buffer_allocation_location.cc index 65bd05975b67..265d9574f8fe 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/s_tir/transform/plan_update_buffer_allocation_location.cc @@ -23,15 +23,16 @@ */ #include +#include #include #include -#include #include -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class CollectManagedAllocations : public StmtExprVisitor { public: @@ -243,29 +244,25 @@ class BufferAllocationLocator : public StmtExprMutator { std::unordered_set managed_allocations_; }; -PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - BufferAllocationLocator locator(func); - fptr->body = locator(fptr->body); - return func; -} - namespace transform { Pass PlanAndUpdateBufferAllocationLocation() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return PlanAndUpdateBufferAllocationLocation(std::move(f)); + auto fptr = f.CopyOnWrite(); + BufferAllocationLocator locator(f); + fptr->body = locator(fptr->body); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.PlanAndUpdateBufferAllocationLocation", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.PlanAndUpdateBufferAllocationLocation", + refl::GlobalDef().def("s_tir.transform.PlanAndUpdateBufferAllocationLocation", PlanAndUpdateBufferAllocationLocation); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/s_tir/transform/transform_mma_buffer_layout.cc similarity index 96% rename from src/tir/transforms/transform_mma_buffer_layout.cc rename to src/s_tir/transform/transform_mma_buffer_layout.cc index 31e249394524..27d87f4a80ea 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/s_tir/transform/transform_mma_buffer_layout.cc @@ -19,15 +19,16 @@ #include #include +#include #include #include #include -#include -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Rewriter for all m16n8k8.matrix[A/B/C] buffer. This pass mainly do two things: @@ -184,14 +185,14 @@ Pass TransformMmaBufferLayout() { n->body = MmaBufferLayoutTransformer()(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.TransformMmaBufferLayout", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.TransformMmaBufferLayout", TransformMmaBufferLayout); + refl::GlobalDef().def("s_tir.transform.TransformMmaBufferLayout", TransformMmaBufferLayout); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc similarity index 93% rename from src/tir/transforms/unify_thread_binding.cc rename to src/s_tir/transform/unify_thread_binding.cc index 502acd5a467e..e0e6ec82ad06 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/s_tir/transform/unify_thread_binding.cc @@ -23,15 +23,16 @@ #include #include +#include #include #include -#include #include "../../support/utils.h" -#include "ir_utils.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using support::StartsWith; @@ -47,7 +48,7 @@ class ThreadBindingUnifier : public StmtExprMutator { private: Stmt VisitStmt_(const AttrStmtNode* op) final { // If this AttrStmt is not thread binding attribute, return as usual. - if (op->attr_key != attr::thread_extent && op->attr_key != attr::virtual_thread) { + if (op->attr_key != tir::attr::thread_extent && op->attr_key != tir::attr::virtual_thread) { return StmtMutator::VisitStmt_(op); } IterVar old_iter_var = Downcast(op->node); @@ -188,27 +189,23 @@ class ThreadBindingUnifier : public StmtExprMutator { arith::Analyzer ana; }; -PrimFunc UnifyThreadBinding(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); - return f; -} - namespace transform { Pass UnifyThreadBinding() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return UnifyThreadBinding(std::move(f)); + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); + return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); + return CreatePrimFuncPass(pass_func, 0, "s_tir.UnifyThreadBinding", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.UnifyThreadBinding", UnifyThreadBinding); + refl::GlobalDef().def("s_tir.transform.UnifyThreadBinding", UnifyThreadBinding); } } // namespace transform -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 557f42c5ba10..1741eff9375e 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -148,13 +149,13 @@ int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { ffi::Array GetVTCMCompactionPasses() { auto pass_list = ffi::Array(); - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); + pass_list.push_back(s_tir::transform::LowerInitBlock()); + pass_list.push_back(s_tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(s_tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(s_tir::transform::CompactBufferAllocation()); + pass_list.push_back(s_tir::transform::LowerMatchBuffer()); + pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::VectorizeLoop(true)); diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py index 329dfac35d45..108e212ed351 100644 --- a/tests/python/codegen/test_target_codegen.py +++ b/tests/python/codegen/test_target_codegen.py @@ -99,7 +99,7 @@ def test_loop_step( for i in T.serial(3, 1024, step=96): C[i] = A[i] + B[i] - with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]): + with tvm.transform.PassContext(disabled_pass=["s_tir.CanonicalizeLoop"]): lib = tvm.compile(test_loop_step, target=target) src = lib.mod.inspect_source() diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 433a7ed0e2e0..e689eabf15f5 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -1049,7 +1049,7 @@ def cuda_loop_step( C[i] = A[i] + B[i] target = tvm.target.Target({"kind": "cuda"}) - with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]): + with tvm.transform.PassContext(disabled_pass=["s_tir.CanonicalizeLoop"]): lib = tvm.compile(cuda_loop_step, target=target) cuda_src = lib.mod.imports[0].inspect_source() diff --git a/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py b/tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py similarity index 93% rename from tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py rename to tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py index 9f2815919630..f70ea151cb65 100644 --- a/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tir, s_tir from tvm.script import ir as I, tir as T @@ -52,14 +52,14 @@ def expected(A: T.Buffer((10,), "int32")): A[0] = A[0] + 1 mod = tvm.IRModule.from_expr(before) - mod = tvm.tir.transform.AnnotateIrregularLoop()(mod) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.AnnotateIrregularLoop()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) tvm.ir.assert_structural_equal(mod["before"].with_attr("global_symbol", "expected"), expected) def test_annotate_loop_with_break(): """Test that loops containing break statements are annotated as irregular.""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: @@ -85,7 +85,7 @@ def main(A: T.Buffer((10,), "int32")): def test_annotate_loop_with_continue(): """Test that loops containing continue statements are annotated as irregular.""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: @@ -111,7 +111,7 @@ def main(A: T.Buffer((10,), "int32")): def test_nested_irregular_both_loops(): """Test nested loops where both loops have break/continue.""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: @@ -143,7 +143,7 @@ def main(A: T.Buffer((10, 10), "int32")): def test_while_loop_with_break(): """Test that while loops with break/continue are not annotated (while loops don't have annotations).""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: @@ -173,7 +173,7 @@ def main(A: T.Buffer((10,), "int32")): def test_break_in_nested_conditional(): """Test break statement deeply nested in conditional blocks.""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: @@ -203,7 +203,7 @@ def main(A: T.Buffer((10,), "int32"), flag1: T.int32, flag2: T.int32): def test_while_loop_with_break_standalone(): """Test that while loops with break/continue are not annotated (while loops don't have annotations).""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: @@ -233,7 +233,7 @@ def main(A: T.Buffer((10,), "int32")): def test_nested_irregular_loop_standalone(): """Test deeply nested loops with irregular control flow only in innermost loop.""" - transform = tir.transform.AnnotateIrregularLoop() + transform = s_tir.transform.AnnotateIrregularLoop() @I.ir_module class Before: diff --git a/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py b/tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py similarity index 76% rename from tests/python/tir-transform/test_tir_transform_canonicalize_loop.py rename to tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py index 6f6d88137c20..9b356f64f735 100644 --- a/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py @@ -16,72 +16,72 @@ # under the License. import pytest import tvm -from tvm import tir +from tvm import tir, s_tir from tvm.script import tir as T def test_canonicalize_loop(): @T.prim_func - def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): T.func_attr({"global_symbol": "main"}) for i in range(1, 128, 5): B[i] = A[i] + 1.0 @T.prim_func - def expected(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + def expected(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): T.func_attr({"global_symbol": "main"}) for i in T.serial(0, 26): B[i * 5 + 1] = A[i * 5 + 1] + 1.0 mod = tvm.IRModule.from_expr(before) - mod = tir.transform.CanonicalizeLoop()(mod) + mod = s_tir.transform.CanonicalizeLoop()(mod) tvm.ir.assert_structural_equal(mod["main"], expected) def test_canonicalize_nested_loop(): @T.prim_func - def before(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): T.func_attr({"global_symbol": "main"}) for i in range(1, 128, 5): for j in range(2, 128, 3): B[i, j] = A[i, j] + 1.0 @T.prim_func - def expected(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): T.func_attr({"global_symbol": "main"}) for i in T.serial(0, 26): for j in T.serial(0, 42): B[i * 5 + 1, j * 3 + 2] = A[i * 5 + 1, j * 3 + 2] + 1.0 mod = tvm.IRModule.from_expr(before) - mod = tir.transform.CanonicalizeLoop()(mod) + mod = s_tir.transform.CanonicalizeLoop()(mod) tvm.ir.assert_structural_equal(mod["main"], expected) def test_canonicalize_negative_step(): @T.prim_func - def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): T.func_attr({"global_symbol": "main"}) for i in T.serial(0, 127, step=-3): B[i] = A[i] + 1.0 mod = tvm.IRModule.from_expr(before) with pytest.raises(tvm.error.InternalError): - mod = tir.transform.CanonicalizeLoop()(mod) + mod = s_tir.transform.CanonicalizeLoop()(mod) def test_canonicalize_dynamic_step(): """Currently we report error for dynamic step since we could not prove it is positive""" @T.prim_func - def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"], step: T.int32): + def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), step: T.int32): T.func_attr({"global_symbol": "main"}) for i in T.serial(0, 128, step=step): B[i] = A[i] + 1.0 mod = tvm.IRModule.from_expr(before) with pytest.raises(tvm.error.InternalError): - mod = tir.transform.CanonicalizeLoop()(mod) + mod = s_tir.transform.CanonicalizeLoop()(mod) if __name__ == "__main__": diff --git a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_compact_buffer_region.py rename to tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py index a7f93f83a4d3..05ddfa4b67f0 100644 --- a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm import tir +from tvm import tir, s_tir from tvm.script import tir as T @@ -37,7 +37,7 @@ def test_compact(self): before = tvm.IRModule.from_expr(self.before.with_attr("global_symbol", "main")) expected = tvm.IRModule.from_expr(self.expected.with_attr("global_symbol", "main")) simplify = tvm.transform.Sequential([tir.transform.Simplify(), tir.transform.RemoveNoOp()]) - after = simplify(tir.transform.CompactBufferAllocation(is_strict=is_strict)(before)) + after = simplify(s_tir.transform.CompactBufferAllocation(is_strict=is_strict)(before)) expected = simplify(expected) try: tvm.ir.assert_structural_equal(after, expected) @@ -51,12 +51,12 @@ def test_compact(self): if not is_lower_order_free: return - lower_before_compact = tir.transform.LowerOpaqueBlock()(before) - lower_before_compact = tir.transform.CompactBufferAllocation(is_strict=is_strict)( + lower_before_compact = s_tir.transform.LowerOpaqueBlock()(before) + lower_before_compact = s_tir.transform.CompactBufferAllocation(is_strict=is_strict)( lower_before_compact ) lower_before_compact = simplify(lower_before_compact) - lower_after_compact = tir.transform.LowerOpaqueBlock()(after) + lower_after_compact = s_tir.transform.LowerOpaqueBlock()(after) lower_after_compact = simplify(lower_after_compact) try: tvm.ir.assert_structural_equal(lower_before_compact, lower_after_compact) @@ -1061,10 +1061,10 @@ def main( # Get partitioned workload to compact mod = tvm.IRModule.from_expr(main) with tvm.transform.PassContext( - config={"tir.LoopPartition": {"partition_const_loop": True}} + config={"s_tir.LoopPartition": {"partition_const_loop": True}} ): - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LoopPartition()(mod) return mod["main"] diff --git a/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py similarity index 95% rename from tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py rename to tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py index 1647aeefc93d..565ba87462f1 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py @@ -18,14 +18,14 @@ import tvm import tvm.testing -from tvm import tir, te +from tvm import tir, te, s_tir from tvm.script import ir as I, tir as T def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.s_tir.transform.ConvertBlocksToOpaque()(mod) mod = tvm.tir.transform.Simplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) @@ -88,7 +88,7 @@ def main(A: T.Buffer(8, "int32")): T.evaluate(0) with pytest.raises(tvm.TVMError): - tvm.tir.transform.ConvertBlocksToOpaque()(Before) + tvm.s_tir.transform.ConvertBlocksToOpaque()(Before) if __name__ == "__main__": diff --git a/tests/python/tir-transform/test_tir_transform_inject_double_buffer.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py similarity index 95% rename from tests/python/tir-transform/test_tir_transform_inject_double_buffer.py rename to tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py index 3c9047679462..d99981a9565b 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_double_buffer.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py @@ -44,10 +44,10 @@ def db(A: T.handle("float32"), C: T.handle("float32")): mod = Module opt = tvm.transform.Sequential( - [tvm.tir.transform.InjectDoubleBuffer(), tvm.tir.transform.Simplify()] + [tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tir.transform.Simplify()] ) - with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}): + with tvm.transform.PassContext(config={"s_tir.InjectDoubleBuffer": {"split_loop": 2}}): mod = opt(mod) stmt = mod["db"].body @@ -77,7 +77,7 @@ def count_sync(op): def test_double_buffer_transform(): transform = tvm.ir.transform.Sequential( [ - tvm.tir.transform.InjectDoubleBuffer(), + tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tir.transform.Simplify(), ] ) @@ -129,7 +129,7 @@ def test_double_buffer_with_decl_buffer(): transform = tvm.ir.transform.Sequential( [ - tvm.tir.transform.InjectDoubleBuffer(), + tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tir.transform.Simplify(), ] ) diff --git a/tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py rename to tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py index 0c8013984e18..690ac86d00e0 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py @@ -16,6 +16,7 @@ # under the License. import tvm import tvm.testing +import tvm.s_tir from tvm import IRModule from tvm.script import tir as T from tvm.tir import PrimFunc @@ -23,7 +24,7 @@ def _check_primfunc_transform(before: PrimFunc, expected: PrimFunc): before_module = IRModule.from_expr(before) - after_module = tvm.tir.transform.InjectPermutedLayout()(before_module) + after_module = tvm.s_tir.transform.InjectPermutedLayout()(before_module) after = after_module["before"].without_attr("global_symbol") expected = expected.without_attr("global_symbol") diff --git a/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py similarity index 98% rename from tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py rename to tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py index 6bcacc63a499..a2334b9bd27f 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py @@ -38,7 +38,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + mod = tvm.s_tir.transform.InjectSoftwarePipeline()(mod) mod = tvm.tir.transform.Simplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True @@ -48,7 +48,7 @@ def _check(original, transformed): def _check_error(func): mod = tvm.IRModule.from_expr(func) with pytest.raises(ValueError): - tvm.tir.transform.InjectSoftwarePipeline()(mod) + tvm.s_tir.transform.InjectSoftwarePipeline()(mod) @T.prim_func @@ -1187,7 +1187,7 @@ def test_simple_compute_async(): _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) - mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + mod = tvm.s_tir.transform.InjectSoftwarePipeline()(sch.mod) @T.prim_func def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): @@ -1234,7 +1234,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) - mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + mod = tvm.s_tir.transform.InjectSoftwarePipeline()(sch.mod) @T.prim_func def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: @@ -1321,7 +1321,7 @@ def simple_compute( sch.annotate(loop, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) sch.annotate(loop, ann_key="software_pipeline_order", ann_val=[0, 2, 1]) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) - mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + mod = tvm.s_tir.transform.InjectSoftwarePipeline()(sch.mod) @T.prim_func def ref( @@ -1401,7 +1401,7 @@ def test_three_stage_compute_two_stage_async(): _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0, 1]) - mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + mod = tvm.s_tir.transform.InjectSoftwarePipeline()(sch.mod) @T.prim_func def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> None: @@ -1557,11 +1557,11 @@ def test_async_pipelined_mma_gemm_simple(): seq = tvm.transform.Sequential( [ - tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), - tvm.tir.transform.ConvertBlocksToOpaque(), - tvm.tir.transform.UnifyThreadBinding(), - tvm.tir.transform.LowerMatchBuffer(), - tvm.tir.transform.InjectSoftwarePipeline(), + tvm.s_tir.transform.PlanAndUpdateBufferAllocationLocation(), + tvm.s_tir.transform.ConvertBlocksToOpaque(), + tvm.s_tir.transform.UnifyThreadBinding(), + tvm.s_tir.transform.LowerMatchBuffer(), + tvm.s_tir.transform.InjectSoftwarePipeline(), ] ) mod = seq(sch.mod) @@ -1602,11 +1602,11 @@ def test_async_nested_pipeline_mma_gemm_ideal_annotation(): seq = tvm.transform.Sequential( [ - tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), - tvm.tir.transform.ConvertBlocksToOpaque(), - tvm.tir.transform.UnifyThreadBinding(), - tvm.tir.transform.LowerMatchBuffer(), - tvm.tir.transform.InjectSoftwarePipeline(), + tvm.s_tir.transform.PlanAndUpdateBufferAllocationLocation(), + tvm.s_tir.transform.ConvertBlocksToOpaque(), + tvm.s_tir.transform.UnifyThreadBinding(), + tvm.s_tir.transform.LowerMatchBuffer(), + tvm.s_tir.transform.InjectSoftwarePipeline(), ] ) mod = seq(sch.mod) diff --git a/tests/python/tir-transform/test_tir_transform_inject_virtual_thread.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py similarity index 95% rename from tests/python/tir-transform/test_tir_transform_inject_virtual_thread.py rename to tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py index a6d1d4b58d7b..9c9bf8d888c4 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py @@ -50,7 +50,7 @@ def main(A: T.handle("float32"), C: T.handle("float32")): # For vthread, expected allocation is m * nthread B_expected_alloc = m * nthread - stmt = tvm.tir.transform.InjectVirtualThread()(Module)["main"] + stmt = tvm.s_tir.transform.InjectVirtualThread()(Module)["main"] # Find allocate nodes allocates = [] @@ -102,7 +102,7 @@ def main(): # C expected allocation is m * nthread * nthread (used in extern with both vthreads) C_expected_alloc = m * nthread * nthread - stmt = tvm.tir.transform.InjectVirtualThread()(Module)["main"] + stmt = tvm.s_tir.transform.InjectVirtualThread()(Module)["main"] # Find allocate nodes allocates = [] @@ -139,7 +139,7 @@ def main(A: T.handle("float32")): if i == 0: B[i] = A_buf[i * nthread + vt] + T.float32(2) - stmt = tvm.tir.transform.InjectVirtualThread()(Module)["main"] + stmt = tvm.s_tir.transform.InjectVirtualThread()(Module)["main"] # Find IfThenElse nodes if_nodes = [] @@ -183,7 +183,7 @@ def expected_func(): B[T.Mul(3, 4) : T.Mul(3, 4) + 4] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func.with_attr("global_symbol", "main")) - after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) + after_mod = tvm.s_tir.transform.InjectVirtualThread()(before_mod) after_func = after_mod["main"] tvm.ir.assert_structural_equal(after_func, expected_func.with_attr("global_symbol", "main")) @@ -210,7 +210,7 @@ def expected_func(): B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func.with_attr("global_symbol", "main")) - intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) + intermediate_mod = tvm.s_tir.transform.InjectVirtualThread()(before_mod) after_mod = tvm.tir.transform.StorageRewrite()(intermediate_mod) after_func = after_mod["main"] diff --git a/tests/python/tir-transform/test_tir_transform_lift_thread_binding.py b/tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_lift_thread_binding.py rename to tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py index eacd3bbcfe4c..a54c52d0cb80 100644 --- a/tests/python/tir-transform/test_tir_transform_lift_thread_binding.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, s_tir from tvm.script import tir as T @@ -131,7 +131,7 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): C[blockIdx_x // n, 0, blockIdx_x % n] = D_local[blockIdx_x // n, 0, blockIdx_x % n] * T.float32(0.088397790055248615) # fmt: on mod = tvm.IRModule({"main": before.with_attr("global_symbol", "main")}) - after = tir.transform.LiftThreadBinding()(mod) + after = s_tir.transform.LiftThreadBinding()(mod) tvm.ir.assert_structural_equal(expected.with_attr("global_symbol", "main"), after["main"]) diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py similarity index 95% rename from tests/python/tir-transform/test_tir_transform_loop_partition.py rename to tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py index 63bbf457e51d..9691f94f13dd 100644 --- a/tests/python/tir-transform/test_tir_transform_loop_partition.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py @@ -41,7 +41,7 @@ def func(n: T.int64, m: T.int64): T.evaluate(n) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -63,7 +63,7 @@ def func(n: T.int64, m: T.int64): T.evaluate(n) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -77,7 +77,7 @@ def func(m: T.int64, n: T.int64): T.evaluate(T.Select(T.likely(i * 4 + j < n), m, n)) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))) @@ -90,8 +90,8 @@ def func(m: T.int64, n: T.int64): T.evaluate(T.Select(T.likely(i == 5), m, n)) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - mod = tvm.tir.transform.LoopPartition()(mod) + with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))) @@ -107,7 +107,7 @@ def func(m: T.int64, n: T.int64): T.evaluate(m) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert isinstance(stmt.body.body, tvm.tir.IfThenElse) @@ -136,8 +136,8 @@ def func(m: T.int64, data: T.handle("float32"), out: T.handle("float32")): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - mod = tvm.tir.transform.LoopPartition()(mod) + with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -157,8 +157,8 @@ def func(A: T.Buffer((n * m,), "float16"), B: T.Buffer((n * m,), "float16")): ) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - mod = tvm.tir.transform.LoopPartition()(mod) + with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -178,8 +178,8 @@ def func(): T.evaluate(T.call_extern("float32", "cce_intrisic", i * tile, i * tile + tile)) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - mod = tvm.tir.transform.LoopPartition()(mod) + with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -199,8 +199,8 @@ def func(): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - mod = tvm.tir.transform.LoopPartition()(mod) + with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.s_tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -220,10 +220,10 @@ def partitioned_concat( def partition_from_scheduled_tir(prim_func, pass_cfg, do_flatten=True): with tvm.transform.PassContext(config=pass_cfg): mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) if do_flatten: mod = tvm.tir.transform.FlattenBuffer()(mod) - mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.s_tir.transform.LoopPartition()(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) return mod @@ -275,7 +275,7 @@ def concat_func_3( def test_condition_mutually_exclusive(): mod = partition_from_scheduled_tir( - concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}} + concat_func_3, {"s_tir.LoopPartition": {"partition_const_loop": True}} ) tvm.ir.assert_structural_equal( mod["main"], partitioned_concat_3.with_attr("global_symbol", "main") @@ -320,7 +320,7 @@ def partitioned_main( mod = partition_from_scheduled_tir( main, { - "tir.LoopPartition": { + "s_tir.LoopPartition": { "partition_const_loop": True, "unroll_loop_with_partition_hint_no_interval": True, } @@ -392,7 +392,7 @@ def partitioned_main(): mod = partition_from_scheduled_tir( main, { - "tir.LoopPartition": { + "s_tir.LoopPartition": { "partition_const_loop": True, "unroll_loop_with_partition_hint_no_interval": True, } @@ -427,7 +427,7 @@ def after(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None: mod = partition_from_scheduled_tir( before, { - "tir.LoopPartition": { + "s_tir.LoopPartition": { "partition_const_loop": True, } }, @@ -475,7 +475,7 @@ def after( mod = partition_from_scheduled_tir( before, { - "tir.LoopPartition": { + "s_tir.LoopPartition": { "partition_const_loop": True, } }, @@ -725,7 +725,7 @@ def test_single_point_partition(origin, expected): mod = partition_from_scheduled_tir( origin, { - "tir.LoopPartition": { + "s_tir.LoopPartition": { "partition_const_loop": True, "unroll_loop_with_partition_hint_no_interval": True, } diff --git a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py rename to tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py index 8137ae674a8a..692952b84a48 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py @@ -20,7 +20,7 @@ import pytest import tvm import tvm.testing -from tvm import te +from tvm import te, s_tir from tvm.script import tir as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -28,7 +28,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) + mod = tvm.s_tir.transform.LowerCrossThreadReduction()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) @@ -37,7 +37,7 @@ def _check(original, transformed): def _check_fail(original): mod = tvm.IRModule.from_expr(original) with pytest.raises(ValueError): - tvm.tir.transform.LowerCrossThreadReduction()(mod) + tvm.s_tir.transform.LowerCrossThreadReduction()(mod) @T.prim_func diff --git a/tests/python/tir-transform/test_tir_transform_lower_init_block.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py similarity index 96% rename from tests/python/tir-transform/test_tir_transform_lower_init_block.py rename to tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py index 4e1a1cbfa6d7..147ca6dcc8ca 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_init_block.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te +from tvm import te, s_tir from tvm.script import tir as T # pylint: disable=no-self-argument @@ -95,13 +95,13 @@ def main(a: T.handle, b: T.handle) -> None: def test_lower_reduction(): origin_mod = WithInit - mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + mod = tvm.s_tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, WithBranch, True) def test_lower_match_buffer(): origin_mod = InitWithMatchBuffer - mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + mod = tvm.s_tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer, True) diff --git a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_lower_match_buffer.py rename to tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py index 41c53437e98c..5bdc3a4c4cea 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py @@ -19,12 +19,13 @@ import tvm import tvm.testing +import tvm.s_tir from tvm.script import tir as T def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerMatchBuffer()(mod) + mod = tvm.s_tir.transform.LowerMatchBuffer()(mod) mod = tvm.tir.transform.Simplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) @@ -32,7 +33,7 @@ def _check(original, transformed): def _check_fail(original): mod = tvm.IRModule.from_expr(original) with pytest.raises(tvm.TVMError): - mod = tvm.tir.transform.LowerMatchBuffer()(mod) + mod = tvm.s_tir.transform.LowerMatchBuffer()(mod) @T.prim_func diff --git a/tests/python/tir-transform/test_tir_transform_lower_opaque_block.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py similarity index 98% rename from tests/python/tir-transform/test_tir_transform_lower_opaque_block.py rename to tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py index 294cdc42d2e3..94cf5542005e 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_opaque_block.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py @@ -16,6 +16,7 @@ # under the License. import tvm import tvm.testing +import tvm.s_tir from tvm import te from tvm.script import tir as T @@ -23,7 +24,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.Simplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True @@ -351,7 +352,7 @@ def test_symbolic_strided_buffer(): def test_annotated_loops(): mod = tvm.IRModule.from_expr(annotated_loops.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) attr1 = mod["main"].body attr2 = attr1.body attr3 = attr2.body @@ -370,7 +371,7 @@ def annotated_block() -> None: T.evaluate(0) mod = tvm.IRModule.from_expr(annotated_block.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) attr1 = mod["main"].body attr2 = attr1.body attr3 = attr2.body @@ -395,7 +396,7 @@ def after(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")): B[i] = A[i] + 1.0 mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) diff --git a/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py rename to tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py index f6fe611c4e8c..b07bd97a57b5 100644 --- a/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py @@ -16,6 +16,7 @@ # under the License. import tvm import tvm.testing +from tvm import s_tir from tvm.script import tir as T @@ -122,7 +123,7 @@ def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float3 def _check(before, expected): - after = tvm.tir.transform.ManifestSharedMemoryLocalStage()(before) + after = tvm.s_tir.transform.ManifestSharedMemoryLocalStage()(before) tvm.ir.assert_structural_equal(after, expected) diff --git a/tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py rename to tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py index f454699d1c9a..d624677d37e4 100644 --- a/tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py @@ -16,7 +16,7 @@ # under the License. import tvm -from tvm import te +from tvm import te, s_tir from tvm.script import tir as T import sys import pytest @@ -1099,7 +1099,7 @@ def main(C: T.Buffer((1024, 1024), "float32")): def _check(original, transformed): - mod = tvm.tir.transform.LowerAutoCopy()(original) + mod = tvm.s_tir.transform.LowerAutoCopy()(original) tvm.ir.assert_structural_equal(mod, transformed, True) @@ -1156,7 +1156,7 @@ def prod(arr): def test_auto_padding(): - mod = tvm.tir.transform.LowerAutoCopy()(Transpose) + mod = tvm.s_tir.transform.LowerAutoCopy()(Transpose) mod = tvm.tir.transform.FlattenBuffer()(mod) verify_single_allocation(mod["main"].body, 16 * 130) diff --git a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py rename to tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py index 053c42e1fcd8..3cd6b5841037 100644 --- a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py @@ -20,14 +20,14 @@ import tvm.testing from tvm import te from tvm.script import tir as T -from tvm import tir +from tvm import tir, s_tir from tvm.s_tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tvm.s_tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) diff --git a/tests/python/tir-transform/test_tir_transform_unify_thread_binding.py b/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py similarity index 98% rename from tests/python/tir-transform/test_tir_transform_unify_thread_binding.py rename to tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py index 89b17719ac41..1e2b08244362 100644 --- a/tests/python/tir-transform/test_tir_transform_unify_thread_binding.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py @@ -19,13 +19,13 @@ import tvm import tvm.testing -from tvm import te +from tvm import te, s_tir from tvm.script import tir as T def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.UnifyThreadBinding()(mod) + mod = tvm.s_tir.transform.UnifyThreadBinding()(mod) mod = tvm.tir.transform.Simplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True @@ -35,7 +35,7 @@ def _check(original, transformed): def _check_fail(original): mod = tvm.IRModule.from_expr(original) with pytest.raises(ValueError): - tvm.tir.transform.UnifyThreadBinding()(mod) + tvm.s_tir.transform.UnifyThreadBinding()(mod) @T.prim_func diff --git a/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py b/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py index 845d157c37f9..64f6ca278d28 100644 --- a/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py +++ b/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py @@ -60,8 +60,8 @@ def test_scale_by(primFunc, size): sch.compute_at(cache_block, flat) mod = sch.mod - mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"]) assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"' sizes = sizes["main"] @@ -102,9 +102,9 @@ def matmul_mix_scope(a: T.handle, b: T.handle, c: T.handle) -> None: def test_matmul_mix_scope(scope, size): """Test calculate allocated bytes per scope""" mod = tvm.IRModule({"main": matmul_mix_scope}) - mod = tvm.tir.transform.LowerInitBlock()(mod) - mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerInitBlock()(mod) + mod = tvm.s_tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"]) assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"' sizes = sizes["main"] @@ -120,8 +120,8 @@ def apply_schedule(sch, func_name): sch = tvm.s_tir.Schedule(Module, debug_mask="all") apply_schedule(sch, "scale_by_two") apply_schedule(sch, "scale_by_two_three") - mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.ConvertBlocksToOpaque()(sch.mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) sizes = tvm.tir.analysis.calculate_allocated_bytes(mod) assert "scale_by_two" in sizes, "Values for scale_by_two not found" scale_by_two_sizes = sizes["scale_by_two"] diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index f5cf71613e39..5ab0cfe6f0d0 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -131,7 +131,7 @@ def test_inject_async_copy(): f = generate_global_to_shared_vectorized_copy(dtype, vec_size) mod = tvm.IRModule.from_expr(f) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) if vec_size > 1: mod = tvm.tir.transform.VectorizeLoop()(mod) @@ -159,7 +159,7 @@ def test_inject_async_copy_shared_dyn(): f = ptx_global_to_shared_dyn_copy_fp16x8 mod = tvm.IRModule.from_expr(f) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) mod = tvm.tir.transform.MergeSharedMemoryAllocations()(mod) @@ -221,7 +221,7 @@ def test_inject_async_copy_barrier(): f = ptx_global_to_shared_copy_fp32x1_barrier mod = tvm.IRModule.from_expr(f) - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index cc64db9cf3e7..3819e19edd94 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3598,7 +3598,7 @@ def func(A: T.Buffer(128, "float32"), C: T.Buffer(128, "float32")): for i in T.thread_binding(128, thread="threadIdx.x"): C[i] = B[i] + 2.0 - mod = tvm.tir.transform.LowerOpaqueBlock()( + mod = tvm.s_tir.transform.LowerOpaqueBlock()( tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) ) return mod["main"]