From ba0f4a232c9533c07076971cdd6c1eade89402dc Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sat, 7 Feb 2026 08:59:35 -0500
Subject: [PATCH] [REFACTOR][S-TIR] Migrate tir/schedule to s_tir

This PR lifts the previous tir/schedule implementation into the s_tir
namespace. Headers move from include/tvm/tir/schedule to
include/tvm/s_tir/schedule and sources from src/tir/schedule to
src/s_tir/schedule; the C++ code now lives in namespace tvm::s_tir, which
pulls in tvm::tir via a using-directive; and the FFI/object registry keys
are renamed from "tir.*" to "s_tir.*" (e.g. "tir.Schedule" ->
"s_tir.Schedule", "tir.schedule.HasIfThenElse" ->
"s_tir.schedule.HasIfThenElse"). For downstream users the rename is
mechanical; see the migration sketch after the diff excerpt below.
---
 include/tvm/meta_schedule/cost_model.h | 2 +-
 include/tvm/meta_schedule/database.h | 20 +-
 include/tvm/meta_schedule/measure_candidate.h | 6 +-
 include/tvm/meta_schedule/mutator.h | 16 +-
 include/tvm/meta_schedule/postproc.h | 8 +-
 .../meta_schedule/schedule/cuda/thread_bind.h | 17 +-
 .../meta_schedule/schedule/generic/winograd.h | 4 +-
 include/tvm/meta_schedule/schedule_rule.h | 11 +-
 include/tvm/meta_schedule/search_strategy.h | 8 +-
 include/tvm/meta_schedule/space_generator.h | 8 +-
 .../tvm/{tir => s_tir}/schedule/instruction.h | 19 ++-
 .../tvm/{tir => s_tir}/schedule/schedule.h | 21 +--
 include/tvm/{tir => s_tir}/schedule/state.h | 15 +-
 include/tvm/{tir => s_tir}/schedule/trace.h | 15 +-
 python/tvm/dlight/analysis/common_analysis.py | 8 +-
 python/tvm/s_tir/schedule/_ffi_api.py | 2 +-
 python/tvm/s_tir/schedule/analysis.py | 4 +-
 python/tvm/s_tir/schedule/instruction.py | 4 +-
 python/tvm/s_tir/schedule/schedule.py | 6 +-
 python/tvm/s_tir/schedule/state.py | 2 +-
 python/tvm/s_tir/schedule/trace.py | 2 +-
 python/tvm/tir/analysis/analysis.py | 2 +-
 src/meta_schedule/database/database.cc | 31 ++--
 .../database/schedule_fn_database.cc | 20 +--
 .../feature_extractor/per_store_feature.cc | 35 ++--
 .../mutator/mutate_compute_location.cc | 22 +--
 src/meta_schedule/mutator/mutate_parallel.cc | 39 ++---
 .../mutator/mutate_thread_binding.cc | 22 +--
 src/meta_schedule/mutator/mutate_tile_size.cc | 20 +--
 src/meta_schedule/mutator/mutate_unroll.cc | 21 +--
 src/meta_schedule/mutator/mutator.cc | 6 +-
 .../disallow_async_strided_mem_copy.cc | 9 +-
 .../postproc/disallow_dynamic_loop.cc | 9 +-
 src/meta_schedule/postproc/postproc.cc | 2 +-
 .../postproc/rewrite_cooperative_fetch.cc | 53 +++---
 src/meta_schedule/postproc/rewrite_layout.cc | 13 +-
 .../rewrite_parallel_vectorize_unroll.cc | 52 +++---
 .../postproc/rewrite_reduction_block.cc | 23 +--
 .../postproc/rewrite_tensorize.cc | 20 +--
 .../postproc/rewrite_unbound_block.cc | 19 ++-
 src/meta_schedule/postproc/verify_gpu_code.cc | 15 +-
 .../postproc/verify_vtcm_limit.cc | 2 +-
 src/meta_schedule/schedule/cpu/winograd.cc | 9 +-
 .../schedule/cuda/thread_bind.cc | 22 ++-
 src/meta_schedule/schedule/cuda/winograd.cc | 9 +-
 .../schedule/generic/winograd.cc | 4 +
 .../schedule_rule/add_rfactor.cc | 24 +--
 .../schedule_rule/apply_custom_rule.cc | 9 +-
 src/meta_schedule/schedule_rule/auto_bind.cc | 7 +-
 .../schedule_rule/auto_inline.cc | 38 +++--
 .../schedule_rule/cross_thread_reduction.cc | 74 +++++----
 .../schedule_rule/multi_level_tiling.cc | 36 ++--
 .../schedule_rule/multi_level_tiling.h | 32 ++--
 .../multi_level_tiling_tensor_core.cc | 96 +++++------
 .../multi_level_tiling_wide_vector.cc | 35 ++--
 .../multi_level_tiling_with_intrin.cc | 15 +-
 .../parallel_vectorize_unroll.cc | 11 +-
 .../schedule_rule/random_compute_location.cc | 29 ++--
 .../schedule_rule/schedule_rule.cc | 4 +-
 .../search_strategy/evolutionary_search.cc | 22 +--
 .../search_strategy/replay_func.cc | 8 +-
 .../search_strategy/replay_trace.cc | 20 +--
 .../search_strategy/search_strategy.cc | 9 +-
 .../space_generator/post_order_apply.cc | 22 +--
 .../space_generator/schedule_fn.cc | 12 +-
 .../space_generator/space_generator.cc | 2 +-
 .../space_generator/space_generator_union.cc | 6 +-
 .../task_scheduler/gradient_based.cc | 2 +-
 .../task_scheduler/task_scheduler.cc | 8 +-
 src/meta_schedule/trace_apply.cc | 14 +-
 src/meta_schedule/trace_apply.h | 6 +-
 src/meta_schedule/utils.h | 38 ++---
 .../transform/legalize_redistribute.cc | 2 +-
 .../distributed/transform/lower_distir.cc | 2 +-
 .../lower_global_view_to_local_view.cc | 3 +-
 src/relax/transform/meta_schedule.cc | 10 +-
 .../transform/split_call_tir_by_pattern.cc | 4 +-
 src/{tir => s_tir}/schedule/analysis.h | 29 ++--
 .../schedule/analysis/analysis.cc | 56 +++----
 .../schedule/analysis/layout.cc | 7 +-
 .../schedule/analysis/reducer.cc | 5 +-
 .../schedule/analysis/verify.cc | 5 +-
 .../schedule/concrete_schedule.cc | 156 +++++++++---------
 .../schedule/concrete_schedule.h | 11 +-
 src/{tir => s_tir}/schedule/error.cc | 5 +-
 src/{tir => s_tir}/schedule/error.h | 13 +-
 src/{tir => s_tir}/schedule/instruction.cc | 9 +-
 .../schedule/instruction_traits.h | 17 +-
 src/{tir => s_tir}/schedule/ir_comparator.cc | 5 +-
 src/{tir => s_tir}/schedule/ir_comparator.h | 11 +-
 src/{tir => s_tir}/schedule/primitive.h | 17 +-
 .../schedule/primitive/annotate.cc | 9 +-
 .../primitive/annotate_buffer_access.cc | 11 +-
 .../schedule/primitive/block_annotate.cc | 22 +--
 .../schedule/primitive/blockize_tensorize.cc | 25 +--
 .../schedule/primitive/cache_index.cc | 11 +-
 .../schedule/primitive/cache_read_write.cc | 27 +--
 .../schedule/primitive/compute_at.cc | 9 +-
 .../schedule/primitive/compute_inline.cc | 11 +-
 .../schedule/primitive/decompose_padding.cc | 11 +-
 .../schedule/primitive/for_kind.cc | 13 +-
 .../schedule/primitive/get_block_loop.cc | 23 +--
 .../schedule/primitive/hide_buffer_access.cc | 9 +-
 .../primitive/layout_transformation.cc | 11 +-
 .../schedule/primitive/loop_transformation.cc | 21 +--
 .../schedule/primitive/pad_einsum.cc | 7 +-
 .../schedule/primitive/read_write_at.cc | 9 +-
 .../schedule/primitive/reduction.cc | 11 +-
 .../primitive/reorder_block_iter_var.cc | 7 +-
 .../schedule/primitive/rolling_buffer.cc | 9 +-
 .../schedule/primitive/sampling.cc | 15 +-
 src/{tir => s_tir}/schedule/schedule.cc | 146 ++++++++--------
 src/{tir => s_tir}/schedule/state.cc | 15 +-
 src/{tir => s_tir}/schedule/trace.cc | 27 +--
 .../schedule/traced_schedule.cc | 17 +-
 src/{tir => s_tir}/schedule/traced_schedule.h | 11 +-
 src/{tir => s_tir}/schedule/transform.cc | 18 +-
 src/{tir => s_tir}/schedule/transform.h | 25 +--
 src/{tir => s_tir}/schedule/utils.h | 32 ++--
 src/s_tir/transform/compact_buffer_region.cc | 2 +-
 .../transform/inject_software_pipeline.cc | 2 +-
 .../transform/lower_cross_thread_reduction.cc | 4 +-
 .../manifest_shared_memory_local_stage.cc | 2 +-
 .../transform/memhammer_lower_auto_copy.cc | 2 +-
 src/s_tir/transform/memhammer_rewrite_rule.h | 2 +-
 src/tir/analysis/oob_checker.cc | 4 +-
 src/tir/transforms/default_gpu_schedule.cc | 22 +--
 tests/python/dlight/test_primitives.py | 2 +-
 128 files changed, 1184 insertions(+), 1043 deletions(-)
 rename include/tvm/{tir => s_tir}/schedule/instruction.h (94%)
 rename include/tvm/{tir => s_tir}/schedule/schedule.h (98%)
 rename include/tvm/{tir => s_tir}/schedule/state.h (96%)
 rename include/tvm/{tir => s_tir}/schedule/trace.h (95%)
 rename src/{tir => s_tir}/schedule/analysis.h (98%)
 rename src/{tir => s_tir}/schedule/analysis/analysis.cc (97%)
 rename src/{tir => s_tir}/schedule/analysis/layout.cc (98%)
 rename src/{tir => s_tir}/schedule/analysis/reducer.cc (99%)
 rename src/{tir => s_tir}/schedule/analysis/verify.cc (99%)
 rename src/{tir => s_tir}/schedule/concrete_schedule.cc (87%)
 rename src/{tir => s_tir}/schedule/concrete_schedule.h (98%)
 rename src/{tir => s_tir}/schedule/error.cc (96%)
 rename src/{tir => s_tir}/schedule/error.h (93%)
 rename src/{tir => s_tir}/schedule/instruction.cc (96%)
 rename src/{tir => s_tir}/schedule/instruction_traits.h (98%)
 rename src/{tir => s_tir}/schedule/ir_comparator.cc (99%)
 rename src/{tir => s_tir}/schedule/ir_comparator.h (97%)
 rename src/{tir => s_tir}/schedule/primitive.h (99%)
 rename src/{tir => s_tir}/schedule/primitive/annotate.cc (97%)
 rename src/{tir => s_tir}/schedule/primitive/annotate_buffer_access.cc (96%)
 rename src/{tir => s_tir}/schedule/primitive/block_annotate.cc (96%)
 rename src/{tir => s_tir}/schedule/primitive/blockize_tensorize.cc (98%)
 rename src/{tir => s_tir}/schedule/primitive/cache_index.cc (98%)
 rename src/{tir => s_tir}/schedule/primitive/cache_read_write.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/compute_at.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/compute_inline.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/decompose_padding.cc (98%)
 rename src/{tir => s_tir}/schedule/primitive/for_kind.cc (97%)
 rename src/{tir => s_tir}/schedule/primitive/get_block_loop.cc (93%)
 rename src/{tir => s_tir}/schedule/primitive/hide_buffer_access.cc (97%)
 rename src/{tir => s_tir}/schedule/primitive/layout_transformation.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/loop_transformation.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/pad_einsum.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/read_write_at.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/reduction.cc (99%)
 rename src/{tir => s_tir}/schedule/primitive/reorder_block_iter_var.cc (97%)
 rename src/{tir => s_tir}/schedule/primitive/rolling_buffer.cc (98%)
 rename src/{tir => s_tir}/schedule/primitive/sampling.cc (98%)
 rename src/{tir => s_tir}/schedule/schedule.cc (67%)
 rename src/{tir => s_tir}/schedule/state.cc (99%)
 rename src/{tir => s_tir}/schedule/trace.cc (96%)
 rename src/{tir => s_tir}/schedule/traced_schedule.cc (98%)
 rename src/{tir => s_tir}/schedule/traced_schedule.h (97%)
 rename src/{tir => s_tir}/schedule/transform.cc (97%)
 rename src/{tir => s_tir}/schedule/transform.h (93%)
 rename src/{tir => s_tir}/schedule/utils.h (95%)

diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h
index aaf4665c2729..30a398a01ccb 100644
--- a/include/tvm/meta_schedule/cost_model.h
+++ b/include/tvm/meta_schedule/cost_model.h
@@ -28,7 +28,7 @@
 #include
 #include
 #include
-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>
 #include

diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index 6f6b8bfca8d6..7ca377a0a797 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -27,9 +27,9 @@
 #include
 #include
 #include
+#include <tvm/s_tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/trace.h>
 #include
-#include <tvm/tir/schedule/schedule.h>
-#include <tvm/tir/schedule/trace.h>
 #include
 #include

@@ -114,7 +114,7 @@ class MeasureCandidate;
 class TuningRecordNode : public runtime::Object {
  public:
   /*! \brief The trace tuned. */
-  tir::Trace trace;
+  s_tir::Trace trace;
   /*! \brief The workload. */
   Workload workload{ffi::UnsafeInit()};
   /*! \brief The profiling result in seconds. */
@@ -166,7 +166,7 @@ class TuningRecord : public runtime::ObjectRef {
   * \param target The target of the tuning record.
   * \param args_info The argument information of the tuning record.
   */
-  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
+  TVM_DLL explicit TuningRecord(s_tir::Trace trace, Workload workload,
                                 ffi::Optional<ffi::Array<FloatImm>> run_secs,
                                 ffi::Optional<Target> target,
                                 ffi::Optional<ffi::Array<ArgInfo>> args_info);
@@ -251,8 +251,8 @@ class DatabaseNode : public runtime::Object {
   * \param workload_name The name of the workload to be searched for.
   * \return The schedule in the best schedule of the given workload; std::nullopt if not found.
   */
-  virtual ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
-                                                     const ffi::String& workload_name);
+  virtual ffi::Optional<s_tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
+                                                       const ffi::String& workload_name);
   /*!
    * \brief Query the best IRModule of the given workload from the database.
    * \param mod The IRModule to be searched for.
@@ -343,7 +343,7 @@ class PyDatabaseNode : public DatabaseNode {
   * \param workload_name The name of the workload to be searched for.
   * \return The schedule in the best schedule of the given workload; std::nullopt if not found.
   */
-  using FQuerySchedule = ffi::TypedFunction<ffi::Optional<tir::Schedule>(
+  using FQuerySchedule = ffi::TypedFunction<ffi::Optional<s_tir::Schedule>(
       const IRModule&, const Target&, const ffi::String&)>;
   /*!
    * \brief The function type of `QueryIRModule` method.
@@ -432,8 +432,8 @@
     }
   }

-  ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
-                                             const ffi::String& workload_name) final {
+  ffi::Optional<s_tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
+                                               const ffi::String& workload_name) final {
     if (f_query_schedule == nullptr) {
       return DatabaseNode::QuerySchedule(mod, target, workload_name);
     } else {
@@ -483,7 +483,7 @@
   * and returns a boolean indicating if the schedule is successful.
   * \param mod_eq_name A string to specify the module equality testing and hashing method.
   */
-  TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(tir::Schedule)> schedule_fn,
+  TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(s_tir::Schedule)> schedule_fn,
                                              ffi::String mod_eq_name = "structural");
   /*!
    * \brief Create a default database that uses JSON file for tuning records.

diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h
index 557e9a3139d2..dc87ad2e476a 100644
--- a/include/tvm/meta_schedule/measure_candidate.h
+++ b/include/tvm/meta_schedule/measure_candidate.h
@@ -24,7 +24,7 @@
 #include
 #include
 #include
-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>

 namespace tvm {
 namespace meta_schedule {
@@ -33,7 +33,7 @@ namespace meta_schedule {
 class MeasureCandidateNode : public runtime::Object {
  public:
   /*! \brief The schedule for measurement. */
-  tir::Schedule sch;
+  s_tir::Schedule sch;
   /*! \brief The argument information, e.g., (shape, dtype) for tensors. */
   ffi::Array<ArgInfo> args_info;
@@ -57,7 +57,7 @@
   * \param sch The schedule for measurement.
   * \param args_info The argument information, e.g., (shape, dtype) for tensors.
   */
-  TVM_DLL MeasureCandidate(tir::Schedule sch, ffi::Array<ArgInfo> args_info);
+  TVM_DLL MeasureCandidate(s_tir::Schedule sch, ffi::Array<ArgInfo> args_info);
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(MeasureCandidate, ObjectRef, MeasureCandidateNode);
 };

diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index 05489c755217..da129f8805ba 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -24,9 +24,9 @@
 #include
 #include
 #include
+#include <tvm/s_tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/trace.h>
 #include
-#include <tvm/tir/schedule/schedule.h>
-#include <tvm/tir/schedule/trace.h>

 namespace tvm {
 namespace meta_schedule {
@@ -58,8 +58,8 @@ class MutatorNode : public runtime::Object {
   * \param rand_state The random state for mutation.
   * \return None if mutator failed, otherwise return the mutated trace.
   */
-  virtual ffi::Optional<tir::Trace> Apply(
-      const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0;
+  virtual ffi::Optional<s_tir::Trace> Apply(
+      const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0;

   /*!
    * \brief Clone the mutator.
@@ -87,8 +87,8 @@ class Mutator : public runtime::ObjectRef {
   * \param trace The given trace for mutation.
   * \return None if mutator failed, otherwise return the mutated trace.
   */
-  using FApply = ffi::TypedFunction<ffi::Optional<tir::Trace>(
-      const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
+  using FApply = ffi::TypedFunction<ffi::Optional<s_tir::Trace>(
+      const s_tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
   /*!
    * \brief Clone the mutator.
    * \return The cloned mutator.
@@ -168,8 +168,8 @@ class PyMutatorNode : public MutatorNode {
   }

   void InitializeWithTuneContext(const TuneContext& context) final;
-  ffi::Optional<tir::Trace> Apply(const tir::Trace& trace,
-                                  support::LinearCongruentialEngine::TRandState* rand_state) final;
+  ffi::Optional<s_tir::Trace> Apply(
+      const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) final;
   Mutator Clone() const final;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyMutator", PyMutatorNode, MutatorNode);
 };

diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index 948f75210701..54dc8a23d5d5 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -23,7 +23,7 @@
 #include
 #include
 #include
-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>

 namespace tvm {
 namespace meta_schedule {
@@ -56,7 +56,7 @@ class PostprocNode : public runtime::Object {
   * \param sch The schedule to be post processed.
   * \return Whether the postprocessor was successfully applied.
   */
-  virtual bool Apply(const tir::Schedule& sch) = 0;
+  virtual bool Apply(const s_tir::Schedule& sch) = 0;

   /*!
    * \brief Clone the postprocessor.
@@ -84,7 +84,7 @@ class Postproc : public runtime::ObjectRef {
   * \param sch The schedule to be post processed.
   * \return Whether the postprocessor was successfully applied.
   */
-  using FApply = ffi::TypedFunction<bool(const tir::Schedule&)>;
+  using FApply = ffi::TypedFunction<bool(const s_tir::Schedule&)>;
   /*!
    * \brief Clone the postprocessor.
    * \return The cloned postprocessor.
@@ -205,7 +205,7 @@ class PyPostprocNode : public PostprocNode {
   }

   void InitializeWithTuneContext(const TuneContext& context) final;
-  bool Apply(const tir::Schedule& sch) final;
+  bool Apply(const s_tir::Schedule& sch) final;
   Postproc Clone() const final;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyPostproc", PyPostprocNode, PostprocNode);
 };

diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h
index 15ed73716873..b5781cb3739d 100644
--- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h
+++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h
@@ -19,7 +19,7 @@
 #ifndef TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
 #define TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_

-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>

 #include
 #include

@@ -35,8 +35,8 @@ namespace meta_schedule {
  * \param thread_extents The candidate thread extents.
  * \return A sampler that returns a random thread extent.
  */
-std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
-                                                      ffi::Array<Integer> thread_extents);
+std::function<s_tir::ExprRV(int64_t)> MakeFactorSampler(s_tir::Schedule sch,
+                                                        ffi::Array<Integer> thread_extents);

 /*!
  * \brief Bind blockIdx.x and threadIdx.x to the given loop
@@ -47,9 +47,10 @@ std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
  * \param get_factor A function that returns the tiling factor.
  * \return The binded loops in the order of blockIdx.x, threadIdx.x, and the rest.
  */
-ffi::Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop,  //
-                                        int64_t max_threadblocks, int64_t max_threads_per_block,
-                                        std::function<tir::ExprRV(int64_t)> get_factor = nullptr);
+ffi::Array<s_tir::LoopRV> BindSpatialLoop(
+    s_tir::Schedule sch, s_tir::LoopRV loop,  //
+    int64_t max_threadblocks, int64_t max_threads_per_block,
+    std::function<s_tir::ExprRV(int64_t)> get_factor = nullptr);

 /*!
  * \brief Bind the given block if it is not bound to blockIdx or threadIdx.
@@ -59,9 +60,9 @@
  * \param max_threads_per_block The maximum number of threads allowed.
  * \param get_factor A function that returns the tiling factor.
  */
-void BindBlockThreadIdx(tir::Schedule sch, tir::SBlockRV block,  //
+void BindBlockThreadIdx(s_tir::Schedule sch, s_tir::SBlockRV block,  //
                         int64_t max_threadblocks, int64_t max_threads_per_block,
-                        std::function<tir::ExprRV(int64_t)> get_factor = nullptr);
+                        std::function<s_tir::ExprRV(int64_t)> get_factor = nullptr);

 }  // namespace meta_schedule
 }  // namespace tvm

diff --git a/include/tvm/meta_schedule/schedule/generic/winograd.h b/include/tvm/meta_schedule/schedule/generic/winograd.h
index 4a891fbaf1fc..b010f52e3a0d 100644
--- a/include/tvm/meta_schedule/schedule/generic/winograd.h
+++ b/include/tvm/meta_schedule/schedule/generic/winograd.h
@@ -19,7 +19,7 @@
 #ifndef TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_
 #define TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_

-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>

 namespace tvm {
 namespace meta_schedule {
@@ -29,7 +29,7 @@ namespace meta_schedule {
  * If there is a constant winograd transform matrix, inline it.
  * \return The only producer block.
  */
-tir::SBlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::SBlockRV block);
+s_tir::SBlockRV GetWinogradProducerAndInlineConst(s_tir::Schedule sch, s_tir::SBlockRV block);

 }  // namespace meta_schedule
 }  // namespace tvm

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 259b6ac12483..db90adfbed79 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -28,7 +28,7 @@
 #include
 #include
 #include
-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>

 namespace tvm {
 namespace meta_schedule {
@@ -60,7 +60,8 @@ class ScheduleRuleNode : public runtime::Object {
   * \param block The specific block to apply the schedule rule.
   * \return The list of schedules generated by applying the schedule rule.
   */
-  virtual ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::SBlockRV& block) = 0;
+  virtual ffi::Array<s_tir::Schedule> Apply(const s_tir::Schedule& sch,
+                                            const s_tir::SBlockRV& block) = 0;

   /*!
    * \brief Deep clone the schedule rule.
@@ -89,8 +90,8 @@ class ScheduleRule : public runtime::ObjectRef {
   * \param block The specific block to apply the schedule rule.
   * \return The list of schedules generated by applying the schedule rule.
   */
-  using FApply =
-      ffi::TypedFunction<ffi::Array<tir::Schedule>(const tir::Schedule&, const tir::SBlockRV&)>;
+  using FApply = ffi::TypedFunction<ffi::Array<s_tir::Schedule>(const s_tir::Schedule&,
+                                                                const s_tir::SBlockRV&)>;
   /*!
    * \brief Get the schedule rule as string with name.
    * \return The string of the schedule rule.
@@ -343,7 +344,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
   }

   void InitializeWithTuneContext(const TuneContext& context) final;
-  ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::SBlockRV& block) final;
+  ffi::Array<s_tir::Schedule> Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block) final;
   ScheduleRule Clone() const final;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyScheduleRule", PyScheduleRuleNode,
                                     ScheduleRuleNode);

diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index 714c43470f05..1560d9fa6b89 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -29,7 +29,7 @@
 #include
 #include
 #include
-#include <tvm/tir/schedule/schedule.h>
+#include <tvm/s_tir/schedule/schedule.h>

 namespace tvm {
 namespace meta_schedule {
@@ -98,7 +98,7 @@ class SearchStrategyNode : public runtime::Object {
   * and reset the search strategy.
   */
   virtual void PreTuning(int max_trials, int num_trials_per_iter,
-                         const ffi::Array<tir::Schedule>& design_spaces,
+                         const ffi::Array<s_tir::Schedule>& design_spaces,
                          const ffi::Optional<Database>& database,
                          const ffi::Optional<CostModel>& cost_model) = 0;

@@ -148,7 +148,7 @@ class SearchStrategy : public runtime::ObjectRef {
   * \brief The function type of `PreTuning` method.
   */
   using FPreTuning = ffi::TypedFunction<void(
-      int max_trials, int num_trials_per_iter, const ffi::Array<tir::Schedule>&,
+      int max_trials, int num_trials_per_iter, const ffi::Array<s_tir::Schedule>&,
       const ffi::Optional<Database>&, const ffi::Optional<CostModel>&)>;
   /*! \brief The function type of `PostTuning` method.
   */
   using FPostTuning = ffi::TypedFunction<void()>;
@@ -255,7 +255,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
   void InitializeWithTuneContext(const TuneContext& context) final;
   void PreTuning(int max_trials, int num_trials_per_iter,
-                 const ffi::Array<tir::Schedule>& design_spaces,
+                 const ffi::Array<s_tir::Schedule>& design_spaces,
                  const ffi::Optional<Database>& database,
                  const ffi::Optional<CostModel>& cost_model) final;
   void PostTuning() final;

diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index 460a41e44a20..85e965acfda7 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -27,8 +27,8 @@
 #include
 #include
 #include
+#include <tvm/s_tir/schedule/schedule.h>
 #include
-#include <tvm/tir/schedule/schedule.h>

 namespace tvm {
 namespace meta_schedule {
@@ -105,7 +105,7 @@ class SpaceGeneratorNode : public runtime::Object {
   * \param mod The module used for design space generation.
   * \return The generated design spaces, i.e., schedules.
   */
-  virtual ffi::Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;
+  virtual ffi::Array<s_tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;

   /*!
    * \brief Clone the space generator.
@@ -140,7 +140,7 @@ class SpaceGenerator : public runtime::ObjectRef {
   * \param mod The module used for design space generation.
   * \return The generated design spaces, i.e., schedules.
   */
-  using FGenerateDesignSpace = ffi::TypedFunction<ffi::Array<tir::Schedule>(const IRModule&)>;
+  using FGenerateDesignSpace = ffi::TypedFunction<ffi::Array<s_tir::Schedule>(const IRModule&)>;
   /*!
    * \brief The function type of `Clone` method.
    * \return The cloned space generator.
@@ -232,7 +232,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
   }

   void InitializeWithTuneContext(const TuneContext& context) final;
-  ffi::Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
+  ffi::Array<s_tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
   SpaceGenerator Clone() const final;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PySpaceGenerator", PySpaceGeneratorNode,
                                     SpaceGeneratorNode);

diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/s_tir/schedule/instruction.h
similarity index 94%
rename from include/tvm/tir/schedule/instruction.h
rename to include/tvm/s_tir/schedule/instruction.h
index c4ee3ce03d15..d571b8f2866e 100644
--- a/include/tvm/tir/schedule/instruction.h
+++ b/include/tvm/s_tir/schedule/instruction.h
@@ -16,8 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_H_
-#define TVM_TIR_SCHEDULE_INSTRUCTION_H_
+#ifndef TVM_S_TIR_SCHEDULE_INSTRUCTION_H_
+#define TVM_S_TIR_SCHEDULE_INSTRUCTION_H_

 #include

@@ -29,7 +29,8 @@
 namespace tvm {

 template class AttrRegistry;

-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;

 // Forward declaration
 class Schedule;

@@ -121,7 +122,7 @@ class InstructionKindNode : public runtime::Object {
   /*! \brief Checks if the instruction kind is EnterPostproc */
   bool IsPostproc() const;

-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.InstructionKind", InstructionKindNode, runtime::Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.InstructionKind", InstructionKindNode, runtime::Object);
 };

 /*!
@@ -179,7 +180,7 @@ class InstructionNode : public runtime::Object {
         .def_ro("attrs", &InstructionNode::attrs)
         .def_ro("outputs", &InstructionNode::outputs);
   }
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Instruction", InstructionNode, runtime::Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Instruction", InstructionNode, runtime::Object);
 };

 /*!
@@ -207,7 +208,7 @@ class Instruction : public runtime::ObjectRef {
  * \sa TVM_REGISTER_INST_KIND
  */
 #define TVM_INST_KIND_REGISTER_VAR_DEF \
-  static DMLC_ATTRIBUTE_UNUSED ::tvm::tir::InstructionKindRegEntry& __make_##InstructionKind
+  static DMLC_ATTRIBUTE_UNUSED ::tvm::s_tir::InstructionKindRegEntry& __make_##InstructionKind

 /*!
  * \brief Register an InstructionKind
@@ -228,7 +229,7 @@
  */
 #define TVM_REGISTER_INST_KIND(InstructionKindName)                \
   TVM_STR_CONCAT(TVM_INST_KIND_REGISTER_VAR_DEF, __COUNTER__) =    \
-      ::tvm::tir::InstructionKindRegEntry::RegisterOrGet(InstructionKindName).set_name()
+      ::tvm::s_tir::InstructionKindRegEntry::RegisterOrGet(InstructionKindName).set_name()

 /*! \brief An entry in the registry of InstructionKind */
 class InstructionKindRegEntry {
@@ -282,7 +283,7 @@ class InstructionKindRegEntry {
   friend class InstructionKind;
 };

-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm

-#endif  // TVM_TIR_SCHEDULE_INSTRUCTION_H_
+#endif  // TVM_S_TIR_SCHEDULE_INSTRUCTION_H_

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/s_tir/schedule/schedule.h
similarity index 98%
rename from include/tvm/tir/schedule/schedule.h
rename to include/tvm/s_tir/schedule/schedule.h
index e346eb458b9f..4a1bbe207d4e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/s_tir/schedule/schedule.h
@@ -16,16 +16,17 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
-#define TVM_TIR_SCHEDULE_SCHEDULE_H_
+#ifndef TVM_S_TIR_SCHEDULE_SCHEDULE_H_
+#define TVM_S_TIR_SCHEDULE_SCHEDULE_H_

+#include <tvm/s_tir/schedule/instruction.h>
+#include <tvm/s_tir/schedule/state.h>
 #include
 #include
-#include <tvm/tir/schedule/instruction.h>
-#include <tvm/tir/schedule/state.h>

 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;

 /*! \brief The level of detailed error message rendering */
 enum class ScheduleErrorRenderLevel : int32_t {
@@ -54,7 +55,7 @@ class SBlockRVNode : public runtime::Object {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef();
   }
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockRV", SBlockRVNode, runtime::Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockRV", SBlockRVNode, runtime::Object);
 };

 /*!
@@ -77,7 +78,7 @@ class LoopRVNode : public runtime::Object {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef();
   }
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LoopRV", LoopRVNode, runtime::Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.LoopRV", LoopRVNode, runtime::Object);
 };

 /*!
@@ -110,7 +111,7 @@ class ScheduleNode : public runtime::Object {
   virtual ~ScheduleNode() = default;

   static constexpr const bool _type_mutable = true;
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Schedule", ScheduleNode, runtime::Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Schedule", ScheduleNode, runtime::Object);

  public:
   /*! \brief Get the IRModule associated with this schedule. */
@@ -931,7 +932,7 @@ class Schedule : public runtime::ObjectRef {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef, ScheduleNode);
 };

-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm

-#endif  // TVM_TIR_SCHEDULE_SCHEDULE_H_
+#endif  // TVM_S_TIR_SCHEDULE_SCHEDULE_H_

diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/s_tir/schedule/state.h
similarity index 96%
rename from include/tvm/tir/schedule/state.h
rename to include/tvm/s_tir/schedule/state.h
index fffa25e19fd6..821125037c5b 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/s_tir/schedule/state.h
@@ -17,11 +17,11 @@
  * under the License.
  */
 /*!
- * \file tvm/tir/schedule/state.h
+ * \file tvm/s_tir/schedule/state.h
  * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
  */
-#ifndef TVM_TIR_SCHEDULE_STATE_H_
-#define TVM_TIR_SCHEDULE_STATE_H_
+#ifndef TVM_S_TIR_SCHEDULE_STATE_H_
+#define TVM_S_TIR_SCHEDULE_STATE_H_

 #include
 #include
@@ -32,7 +32,8 @@
 #include

 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;

 /*!
  * \brief The information about a TensorIR block, it contains two categories of information
@@ -157,7 +158,7 @@ class ScheduleStateNode : public Object {
   TVM_DLL void DebugVerify() const;

   static constexpr const bool _type_mutable = true;
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ScheduleState", ScheduleStateNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.ScheduleState", ScheduleStateNode, Object);

   /******** Property of blocks ********/
   /*! \brief Returns the SBlockInfo correpsonding to the block sref */
@@ -221,7 +222,7 @@ class ScheduleState : public ObjectRef {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScheduleState, ObjectRef, ScheduleStateNode);
 };

-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm

-#endif  // TVM_TIR_SCHEDULE_STATE_H_
+#endif  // TVM_S_TIR_SCHEDULE_STATE_H_

diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/s_tir/schedule/trace.h
similarity index 95%
rename from include/tvm/tir/schedule/trace.h
rename to include/tvm/s_tir/schedule/trace.h
index f5aa7cb5ffd6..5640c4b7f50e 100644
--- a/include/tvm/tir/schedule/trace.h
+++ b/include/tvm/s_tir/schedule/trace.h
@@ -16,13 +16,14 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#ifndef TVM_TIR_SCHEDULE_TRACE_H_
-#define TVM_TIR_SCHEDULE_TRACE_H_
+#ifndef TVM_S_TIR_SCHEDULE_TRACE_H_
+#define TVM_S_TIR_SCHEDULE_TRACE_H_

-#include <tvm/tir/schedule/instruction.h>
+#include <tvm/s_tir/schedule/instruction.h>

 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;

 // Forward declaration
 class Trace;

@@ -70,7 +71,7 @@ class TraceNode : public runtime::Object {
   }

   static constexpr const bool _type_mutable = true;
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Trace", TraceNode, runtime::Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Trace", TraceNode, runtime::Object);

  public:
   /*!
@@ -160,7 +161,7 @@ class Trace : public runtime::ObjectRef {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Trace, runtime::ObjectRef, TraceNode);
 };

-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm

-#endif  // TVM_TIR_SCHEDULE_TRACE_H_
+#endif  // TVM_S_TIR_SCHEDULE_TRACE_H_

diff --git a/python/tvm/dlight/analysis/common_analysis.py b/python/tvm/dlight/analysis/common_analysis.py
index 938f69bfb4c6..1f81c3ccef41 100644
--- a/python/tvm/dlight/analysis/common_analysis.py
+++ b/python/tvm/dlight/analysis/common_analysis.py
@@ -63,7 +63,7 @@ def __repr__(self) -> str:
         return str(self)


-get_sblockrealize = get_global_func("tir.schedule.GetSBlockRealize")
+get_sblockrealize = get_global_func("s_tir.schedule.GetSBlockRealize")
 # BufferIndex Types
 Index = namedtuple("Index", ["sub"])  # c
 RemIndex = namedtuple("RemIndex", ["sub", "div"])  # c%len
@@ -242,7 +242,7 @@ def is_layout_transform(self, sch: s_tir.Schedule) -> bool:
             and len(self.write_bufs(sch)) == 1
             and len(self.read_bufs(sch)) == 1
             and not self.is_elementwise(sch)
-            and not get_global_func("tir.schedule.HasIfThenElse")(sch.get(self.block_rv))
+            and not get_global_func("s_tir.schedule.HasIfThenElse")(sch.get(self.block_rv))
         )

     def is_data_pad(self, sch: s_tir.Schedule) -> bool:
@@ -254,7 +254,7 @@ def is_data_pad(self, sch: s_tir.Schedule) -> bool:
             and not self.is_elementwise(sch)
             and len(self.write_bufs(sch)[0].buf_region.region)
             == len(self.read_bufs(sch)[0].buf_region.region)
-            and get_global_func("tir.schedule.HasIfThenElse")(sch.get(self.block_rv))
+            and get_global_func("s_tir.schedule.HasIfThenElse")(sch.get(self.block_rv))
         )

     def is_convolution(self) -> bool:
@@ -280,7 +280,7 @@ def __repr__(self) -> str:
         return str(self)


-_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
+_normalize_prim_func = get_global_func("s_tir.schedule.NormalizePrimFunc")


 def normalize_prim_func(sch: s_tir.Schedule) -> Optional[List[SBlockInfo]]:

diff --git a/python/tvm/s_tir/schedule/_ffi_api.py b/python/tvm/s_tir/schedule/_ffi_api.py
index 2295d58d787e..910b9c483ce6 100644
--- a/python/tvm/s_tir/schedule/_ffi_api.py
+++ b/python/tvm/s_tir/schedule/_ffi_api.py
@@ -17,4 +17,4 @@
 """FFI APIs for tvm.s_tir.schedule"""
 import tvm_ffi

-tvm_ffi.init_ffi_api("tir.schedule", __name__)  # pylint: disable=protected-access
+tvm_ffi.init_ffi_api("s_tir.schedule", __name__)  # pylint: disable=protected-access

diff --git a/python/tvm/s_tir/schedule/analysis.py b/python/tvm/s_tir/schedule/analysis.py
index b5648e3e158a..eb2ba53e9ca3 100644
--- a/python/tvm/s_tir/schedule/analysis.py
+++ b/python/tvm/s_tir/schedule/analysis.py
@@ -62,7 +62,7 @@ def suggest_index_map(
     )


-@tvm_ffi.register_object("tir.schedule.TensorizeInfo")
+@tvm_ffi.register_object("s_tir.schedule.TensorizeInfo")
 class TensorizeInfo(Object):
     """Necessary information used for tensorization."""

@@ -90,7 +90,7 @@ def get_tensorize_loop_mapping(
     return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding)  # type: ignore


-@tvm_ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
+@tvm_ffi.register_object("s_tir.schedule.AutoTensorizeMappingInfo")
 class AutoTensorizeMappingInfo(Object):
     """Necessary information used to perform transformations for tensorization."""

diff --git a/python/tvm/s_tir/schedule/instruction.py b/python/tvm/s_tir/schedule/instruction.py
index 26428d320dde..6971941e93b7 100644
--- a/python/tvm/s_tir/schedule/instruction.py
+++ b/python/tvm/s_tir/schedule/instruction.py
@@ -32,7 +32,7 @@
 INPUT_RV_TYPE = OUTPUT_RV_TYPE = ATTR_TYPE = Any

-@_register_object("tir.InstructionKind") +@_register_object("s_tir.InstructionKind") class InstructionKind(Object): """Kind of an instruction, e.g. Split, Reorder, etc. Besides the name, every kind of instruction has its own properties, including: @@ -88,7 +88,7 @@ def get(name: str) -> "InstructionKind": return _ffi_api.InstructionKindGet(name) # type: ignore # pylint: disable=no-member -@_register_object("tir.Instruction") +@_register_object("s_tir.Instruction") class Instruction(Object): """Schedule instructions each corresponds to a schedule primitive diff --git a/python/tvm/s_tir/schedule/schedule.py b/python/tvm/s_tir/schedule/schedule.py index f1ffd2c9c6ff..c1e8a61efa13 100644 --- a/python/tvm/s_tir/schedule/schedule.py +++ b/python/tvm/s_tir/schedule/schedule.py @@ -36,7 +36,7 @@ class ScheduleError(TVMError): """Error that happens during TensorIR scheduling.""" -@_register_object("tir.LoopRV") +@_register_object("s_tir.LoopRV") class LoopRV(Object): """A random variable that refers to a loop""" @@ -47,7 +47,7 @@ def __init__(self) -> None: ) -@_register_object("tir.SBlockRV") +@_register_object("s_tir.SBlockRV") class SBlockRV(Object): """A random variable that refers to a block""" @@ -107,7 +107,7 @@ def _get_sblock_default_dtype(block: SBlock) -> str: return "int64" -@_register_object("tir.Schedule") +@_register_object("s_tir.Schedule") class Schedule(Object): """The user-facing schedule class diff --git a/python/tvm/s_tir/schedule/state.py b/python/tvm/s_tir/schedule/state.py index fe980b0848f1..e9f0f4309081 100644 --- a/python/tvm/s_tir/schedule/state.py +++ b/python/tvm/s_tir/schedule/state.py @@ -77,7 +77,7 @@ def _parse_enable_checks(enable_checks: bool) -> bool: return enable_checks -@register_object("tir.ScheduleState") +@register_object("s_tir.ScheduleState") class ScheduleState(Object): """The state of scheduling, which exposes a `Replace` method as the primary resort for all the scheduling primitives to manipulate the TensorIR. diff --git a/python/tvm/s_tir/schedule/trace.py b/python/tvm/s_tir/schedule/trace.py index 5796b7a225d5..121a7afbbbcf 100644 --- a/python/tvm/s_tir/schedule/trace.py +++ b/python/tvm/s_tir/schedule/trace.py @@ -54,7 +54,7 @@ def _json_from_tvm(obj): raise TypeError("Not supported type: " + str(type(obj))) -@_register_object("tir.Trace") +@_register_object("s_tir.Trace") class Trace(Object): """An execution trace of a scheduling program. 
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 78b8d6b804cb..f99da9dedb44 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -302,7 +302,7 @@ def find_anchor_sblock(mod: IRModule) -> SBlock:


 def has_if_then_else(stmt: Stmt) -> bool:
-    return tvm.ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt)
+    return tvm.ffi.get_global_func("s_tir.schedule.HasIfThenElse")(stmt)


 def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:

diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index a7548c95b6cb..fffafa4da586 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -73,7 +73,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) {

 /******** TuningRecord ********/

-TuningRecord::TuningRecord(tir::Trace trace, Workload workload,
+TuningRecord::TuningRecord(s_tir::Trace trace, Workload workload,
                            ffi::Optional<ffi::Array<FloatImm>> run_secs,
                            ffi::Optional<Target> target,
                            ffi::Optional<ffi::Array<ArgInfo>> args_info) {
@@ -91,8 +91,8 @@ bool WorkloadEqual::operator()(const Workload& a, const Workload& b) const {
 }

 MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
-  tir::Schedule sch =
-      tir::Schedule::Traced(workload->mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail);
+  s_tir::Schedule sch =
+      s_tir::Schedule::Traced(workload->mod, -1, 0, s_tir::ScheduleErrorRenderLevel::kDetail);
   trace->ApplyToSchedule(sch, false, nullptr);
   return MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true));
 }
@@ -133,7 +133,7 @@ bool TuningRecordNode::IsValid() const {
 }

 TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
-  tir::Trace trace{ffi::UnsafeInit()};
+  s_tir::Trace trace{ffi::UnsafeInit()};
   ffi::Optional<ffi::Array<FloatImm>> run_secs;
   ffi::Optional<Target> target;
   ffi::Optional<ffi::Array<ArgInfo>> args_info;
@@ -161,10 +161,10 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
     // Load json[0] => trace
     {
       auto json_trace = json_array->at(0).cast();
-      tir::Schedule sch =
-          tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0,
-                                /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
-      tir::Trace::ApplyJSONToSchedule(json_trace, sch);
+      s_tir::Schedule sch =
+          s_tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0,
+                                  /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone);
+      s_tir::Trace::ApplyJSONToSchedule(json_trace, sch);
       trace = sch->trace().value();
     }
   } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
@@ -194,14 +194,15 @@ ffi::Optional DatabaseNode::QueryTuningRecord(const IRModule& mod,
   return records[0];
 }

-ffi::Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target,
-                                                         const ffi::String& workload_name) {
+ffi::Optional<s_tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod,
+                                                           const Target& target,
+                                                           const ffi::String& workload_name) {
   if (ffi::Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name)) {
     TuningRecord record = opt_record.value();
-    tir::Schedule sch =
-        tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
-                              /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+    s_tir::Schedule sch =
+        s_tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
+                                /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kDetail);
     record->trace->ApplyToSchedule(sch, false);
     return sch;
   } else {
@@ -211,7 +212,7 @@ ffi::Optional DatabaseNode::QuerySchedule(const IRModule& mod, co
DatabaseNode::QuerySchedule(const IRModule& mod, co ffi::Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, const ffi::String& workload_name) { - if (ffi::Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { + if (ffi::Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { return opt_sch.value()->mod(); } else { return std::nullopt; @@ -299,7 +300,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("meta_schedule.WorkloadAsJSON", &WorkloadNode::AsJSON) .def("meta_schedule.WorkloadFromJSON", &Workload::FromJSON) .def("meta_schedule.TuningRecord", - [](tir::Trace trace, Workload workload, ffi::Optional> run_secs, + [](s_tir::Trace trace, Workload workload, ffi::Optional> run_secs, ffi::Optional target, ffi::Optional> args_info) { return TuningRecord(trace, workload, run_secs, target, args_info); }) diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index 5825b6834b8f..63b7b347b4ce 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -28,7 +28,7 @@ class ScheduleFnDatabaseNode : public DatabaseNode { explicit ScheduleFnDatabaseNode(ffi::String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} - ffi::TypedFunction schedule_fn; + ffi::TypedFunction schedule_fn; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -41,7 +41,7 @@ class ScheduleFnDatabaseNode : public DatabaseNode { public: ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, const ffi::String& workload_name) final { - if (ffi::Optional sch = this->QuerySchedule(mod, target, workload_name)) { + if (ffi::Optional sch = this->QuerySchedule(mod, target, workload_name)) { return TuningRecord(sch.value()->trace().value(), /*workload=*/Workload(mod, 0), // /*run_secs=*/std::nullopt, // @@ -51,13 +51,13 @@ class ScheduleFnDatabaseNode : public DatabaseNode { return std::nullopt; } - ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, - const ffi::String& workload_name) final { - tir::Schedule sch = - tir::Schedule::Traced(WithAttr(mod, "task_name", workload_name), - /*rand_state=*/-1, - /*debug_mode=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { + s_tir::Schedule sch = + s_tir::Schedule::Traced(WithAttr(mod, "task_name", workload_name), + /*rand_state=*/-1, + /*debug_mode=*/0, + /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kDetail); if (!schedule_fn(sch)) { return std::nullopt; } @@ -95,7 +95,7 @@ class ScheduleFnDatabaseNode : public DatabaseNode { } }; -Database Database::ScheduleFnDatabase(ffi::TypedFunction schedule_fn, +Database Database::ScheduleFnDatabase(ffi::TypedFunction schedule_fn, ffi::String mod_eq_name) { ObjectPtr n = ffi::make_object(mod_eq_name); n->schedule_fn = std::move(schedule_fn); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 9df517588215..7c5631071e51 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -31,7 +31,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using support::NDIntSet; @@ -308,10 +309,10 @@ Pass SimplifyForFeatureExtraction() { * \brief Create a list of passes 
  * \return The list of passes created
  */
-Sequential PassListForPerStoreFeature() {
-  return Sequential({
+tvm::transform::Sequential PassListForPerStoreFeature() {
+  return tvm::transform::Sequential({
       tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_tensor_rewrite*/ true),
-      tir::transform::SimplifyForFeatureExtraction(),
+      s_tir::transform::SimplifyForFeatureExtraction(),
       s_tir::transform::LowerCrossThreadReduction(),
       s_tir::transform::LowerInitBlock(),
       s_tir::transform::PlanAndUpdateBufferAllocationLocation(),
@@ -1356,7 +1357,7 @@ class PerStoreFeatureCollector : private StmtVisitor {
   std::unordered_map buffer_features_ = {};
 };

-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm

 namespace tvm {
@@ -1382,14 +1383,14 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
   }

   void ExtractSingle(IRModule mod, bool is_gpu, std::vector<std::vector<double>>* results) {
-    static transform::Sequential passes = tir::transform::PassListForPerStoreFeature();
+    static transform::Sequential passes = s_tir::transform::PassListForPerStoreFeature();
     mod = passes(std::move(mod));
-    std::vector<tir::Feature> features = tir::PerStoreFeatureCollector::Collect(
+    std::vector<s_tir::Feature> features = s_tir::PerStoreFeatureCollector::Collect(
         is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod);
     int n_features = features.size();
     results->resize(n_features);
     for (int i = 0; i < n_features; ++i) {
-      const tir::Feature& feature = features[i];
+      const s_tir::Feature& feature = features[i];
       std::vector<double>& result = (*results)[i];
       result.reserve(feature_vector_length);
       feature.group1->Export(&result);
@@ -1406,9 +1407,9 @@
   bool is_gpu = std::find(target_keys.begin(), target_keys.end(), "gpu") != target_keys.end();
   std::vector results;
   results.resize(candidates.size());
-  std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;
+  std::unique_ptr<s_tir::group6::Feature> feature_group6 = nullptr;
   if (extract_workload) {
-    feature_group6 = std::make_unique<tir::group6::Feature>(tune_context->mod.value());
+    feature_group6 = std::make_unique<s_tir::group6::Feature>(tune_context->mod.value());
   }
   auto f = [this, is_gpu, &feature_group6, &candidates, &results](int, int task_id) -> void {
     const auto& candidate = candidates[task_id];
@@ -1419,7 +1420,7 @@
       feature_group6->Export(&feature);
       }
     }
-    results[task_id] = tir::utils::AsTensor(features, this->feature_vector_length);
+    results[task_id] = s_tir::utils::AsTensor(features, this->feature_vector_length);
   };
   support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
   return results;
@@ -1436,13 +1437,13 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
   n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
   n->cache_line_bytes = cache_line_bytes;
   n->extract_workload = extract_workload;
-  n->feature_vector_length = tir::group1::Feature::kCount +                                  //
-                             tir::group2::Feature::SubFeature::kCount * buffers_per_store +  //
-                             arith_intensity_curve_num_samples +                             //
-                             tir::group4::Feature::kCount +                                  //
-                             tir::group5::Feature::kCount;
+  n->feature_vector_length = s_tir::group1::Feature::kCount +                                  //
+                             s_tir::group2::Feature::SubFeature::kCount * buffers_per_store +  //
+                             arith_intensity_curve_num_samples +                               //
+                             s_tir::group4::Feature::kCount +                                  //
+                             s_tir::group5::Feature::kCount;
   if (extract_workload) {
-    n->feature_vector_length += tir::group6::Feature::kCount;
+    n->feature_vector_length += s_tir::group6::Feature::kCount;
   }
   return FeatureExtractor(n);
 }

diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc
index 02aa6d898e27..6ec41ebaba8d 100644
--- a/src/meta_schedule/mutator/mutate_compute_location.cc
+++ b/src/meta_schedule/mutator/mutate_compute_location.cc
@@ -23,9 +23,9 @@
 namespace tvm {
 namespace meta_schedule {

-using tir::Instruction;
-using tir::InstructionKind;
-using tir::Trace;
+using s_tir::Instruction;
+using s_tir::InstructionKind;
+using s_tir::Trace;

 /*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */
 class MutateComputeLocationNode : public MutatorNode {
@@ -75,24 +75,24 @@ class MutateComputeLocationNode : public MutatorNode {
  */
 std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates(
     const Trace& trace, TRandState* rand_state) {
-  tir::Schedule sch = tir::Schedule::Traced(                       //
+  s_tir::Schedule sch = s_tir::Schedule::Traced(                   //
       /*mod=*/LoadJSON(this->json_mod_).cast<IRModule>(),          //
       /*rand_state=*/ForkSeed(rand_state),                         //
       /*debug_mode=*/0,                                            //
-      /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+      /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone);

   static InstructionKind inst_sample_compute_location =
       InstructionKind::Get("SampleComputeLocation");
   std::vector<Candidate> candidates;

-  auto f_decision_provider = [&](const tir::Instruction& inst,    //
-                                 const ffi::Array& inputs,        //
-                                 const ffi::Array& attrs,         //
+  auto f_decision_provider = [&](const s_tir::Instruction& inst,  //
+                                 const ffi::Array& inputs,        //
+                                 const ffi::Array& attrs,         //
                                  const Any& decision) -> Any {
     if (inst->kind.same_as(inst_sample_compute_location)) {
       // Step 1. Extract the instruction input and the old decision.
       ICHECK_EQ(inputs.size(), 1);
-      tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::SBlockRV>(inputs[0]));
+      tir::StmtSRef block_sref = sch->GetSRef(Downcast<s_tir::SBlockRV>(inputs[0]));
       int old_decision = Downcast<Integer>(decision)->value;

       // Step 2. Collect all the compute_at locations.
@@ -122,8 +122,8 @@ ffi::Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandS
   if (candidates.empty()) {
     return std::nullopt;
   }
-  const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())];
-  int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())];
+  const Candidate& candidate = candidates[s_tir::SampleInt(rand_state, 0, candidates.size())];
+  int loc = candidate.locs[s_tir::SampleInt(rand_state, 0, candidate.locs.size())];
   return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true);
 }

diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc
index fa056b27444c..eba03ac249d5 100644
--- a/src/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -24,7 +24,8 @@
 #include "../utils.h"

 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;

 /*!
  * \brief Check if the instruction is annotation with `meta_schedule_parallel`
@@ -38,7 +39,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) {
   }
   ICHECK_EQ(inst->attrs.size(), 1);
   ffi::String ann_key = Downcast<ffi::String>(inst->attrs[0]);
-  return ann_key == attr::meta_schedule_parallel;
+  return ann_key == tir::attr::meta_schedule_parallel;
 }

 /*!
@@ -82,7 +83,7 @@ std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
                                                   const ffi::String& block_name,
                                                   const ffi::String& func_name, int64_t limit) {
   ffi::Array<StmtSRef> block_srefs =
-      tir::GetSBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
+      GetSBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
   ICHECK_EQ(block_srefs.size(), 1);
   const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_srefs[0]);
   ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(ffi::GetRef<SBlock>(block));
@@ -148,14 +149,14 @@ std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_
   return results;
 }

-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm

 namespace tvm {
 namespace meta_schedule {

-using tir::Instruction;
-using tir::Trace;
+using s_tir::Instruction;
+using s_tir::Trace;

 /*! \brief Create a Mutator that mutates the parallel extent */
 class MutateParallelNode : public MutatorNode {
@@ -217,17 +218,17 @@ struct MutateParallelNode::Candidate {
  */
 bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
                           MutateParallelNode::Candidate* candidate) {
-  using tir::InstructionNode;
-  using tir::SBlockRVNode;
+  using s_tir::InstructionNode;
+  using s_tir::SBlockRVNode;
   std::unordered_map<const SBlockRVNode*, const InstructionNode*> get_sblock_insts;
   std::vector<const InstructionNode*> ann_insts;
   get_sblock_insts.reserve(trace->insts.size());
   ann_insts.reserve(trace->insts.size());
   for (const Instruction& inst : trace->insts) {
-    if (tir::IsAnnotateWithParallel(inst)) {
+    if (s_tir::IsAnnotateWithParallel(inst)) {
       ann_insts.push_back(inst.get());
     }
-    if (const SBlockRVNode* block_rv = tir::GetInstGetSBlockOutput(inst)) {
+    if (const SBlockRVNode* block_rv = s_tir::GetInstGetSBlockOutput(inst)) {
       get_sblock_insts[block_rv] = inst.get();
     }
   }
@@ -235,10 +236,10 @@
   if (n_ann_insts == 0) {
     return false;
   }
-  const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
+  const InstructionNode* ann_inst = ann_insts[s_tir::SampleInt(rand_state, 0, n_ann_insts)];
   ICHECK_EQ(ann_inst->inputs.size(), 2);
   const InstructionNode* get_sblock_inst =
-      get_sblock_insts.at(Downcast<tir::SBlockRV>(ann_inst->inputs[0]).get());
+      get_sblock_insts.at(Downcast<s_tir::SBlockRV>(ann_inst->inputs[0]).get());
   ICHECK_EQ(get_sblock_inst->attrs.size(), 2);
   candidate->inst = ffi::GetRef<Instruction>(ann_inst);
   candidate->parallel_extent = Downcast<Integer>(ann_inst->inputs[1])->value;
@@ -254,21 +255,21 @@ ffi::Optional MutateParallelNode::Apply(const Trace& trace, TRandState* r
     return std::nullopt;
   }
   // Step 2. Replay the instructions to recover loop extents
-  tir::Schedule sch = tir::Schedule::Traced(                       //
+  s_tir::Schedule sch = s_tir::Schedule::Traced(                   //
      /*mod=*/LoadJSON(this->json_mod_).cast<IRModule>(),           //
      /*rand_state=*/ForkSeed(rand_state),                          //
      /*debug_mode=*/0,
-     /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+     /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone);
   trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
   // Step 3. Find all possible parallel plans
-  std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
+  std::vector<std::vector<int64_t>> loop_extent_prods = s_tir::AnalyzeParallel(
       sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
   std::unordered_map<int64_t, std::vector<int>> limit2plan;
   std::map<std::vector<int>, int64_t> plan2limit;
   for (const std::vector<int64_t>& prods : loop_extent_prods) {
     for (int64_t limit : prods) {
       if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
-        std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
+        std::vector<int> plan = s_tir::GetNumFusedLoops(loop_extent_prods, limit);
         limit2plan[limit] = plan;
         plan2limit[plan] = limit;
       }
     }
   }
   // Step 4. Remove the original plan and remove it
   std::vector<int> original_plan =
-      tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
+      s_tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
   auto it = plan2limit.find(original_plan);
   if (it != plan2limit.end()) {
     plan2limit.erase(it);
@@ -287,7 +288,7 @@
     return std::nullopt;
   }
   it = plan2limit.begin();
-  for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
+  for (int i = 0, n = s_tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
     ++it;
   }
   int64_t limit = it->second;
@@ -296,7 +297,7 @@
   insts.reserve(trace->insts.size());
   for (const Instruction& inst : trace->insts) {
     if (inst.same_as(candidate.inst)) {
-      insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
+      insts.push_back(s_tir::ReplaceAnnValue(candidate.inst, limit));
     } else if (inst->kind->IsPostproc()) {
       break;
     } else {

diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc
index ef9c30729485..aac21491c171 100644
--- a/src/meta_schedule/mutator/mutate_thread_binding.cc
+++ b/src/meta_schedule/mutator/mutate_thread_binding.cc
@@ -23,9 +23,9 @@
 namespace tvm {
 namespace meta_schedule {

-using tir::Instruction;
-using tir::InstructionKind;
-using tir::Trace;
+using s_tir::Instruction;
+using s_tir::InstructionKind;
+using s_tir::Trace;

 /*! \brief A mutator that mutates the thread binding factor decision of SampleCategorical */
 class MutateThreadBindingNode : public MutatorNode {
@@ -82,15 +82,15 @@ class MutateThreadBindingNode : public MutatorNode {
  */
 std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCandidates(
     const Trace& trace, TRandState* rand_state) {
-  using tir::InstructionNode;
+  using s_tir::InstructionNode;

   static InstructionKind inst_sample_categorical = InstructionKind::Get("SampleCategorical");
   static InstructionKind inst_split = InstructionKind::Get("Split");
   static InstructionKind inst_bind = InstructionKind::Get("Bind");

   std::vector<Candidate> candidates;
-  std::unordered_map<const PrimExprNode*, const InstructionNode*> sample_insts;
-  std::unordered_map<const tir::LoopRVNode*, const InstructionNode*> sampled_split_insts;
+  std::unordered_map<const PrimExprNode*, const InstructionNode*> sample_insts;
+  std::unordered_map<const s_tir::LoopRVNode*, const InstructionNode*> sampled_split_insts;
   std::vector<const InstructionNode*> bind_insts;

   auto is_split_by_sample = [&sample_insts](const Instruction& inst) -> bool {
@@ -112,7 +112,7 @@
     ICHECK_EQ(inst->attrs.size(), 1);
     if (Downcast<ffi::String>(inst->attrs[0]) != "threadIdx.x") return false;

-    return sampled_split_insts.find(Downcast<tir::LoopRV>(inst->inputs[0]).get()) !=
+    return sampled_split_insts.find(Downcast<s_tir::LoopRV>(inst->inputs[0]).get()) !=
            sampled_split_insts.end();
   };

@@ -124,7 +124,7 @@
     } else if (is_split_by_sample(inst)) {
       CHECK_EQ(inst->outputs.size(), 2);
       // Only consider the inner loop, which can be bound to threadIdx.x
-      const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode);
+      const s_tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], s_tir::LoopRVNode);
       sampled_split_insts[var_rv] = inst.get();
     } else if (is_thread_binding_by_sample(inst)) {
       bind_insts.push_back(inst.get());
     }
   }

   for (const InstructionNode* bind_inst : bind_insts) {
-    const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode);
+    const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], s_tir::LoopRVNode);
     auto split_it = sampled_split_insts.find(loop_rv);
     ICHECK(split_it != sampled_split_insts.end());
     const InstructionNode* split_inst = split_it->second;
@@ -157,10 +157,10 @@ ffi::Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandSta
   if (candidates.empty()) {
     return std::nullopt;
   }
-  Candidate candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())];
+  Candidate candidate = candidates[s_tir::SampleInt(rand_state, 0, candidates.size())];
   // Remove the current decision
   candidate.probs.erase(candidate.probs.begin() + candidate.decision);
-  int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)();
+  int result = s_tir::MakeMultinomialSampler(rand_state, candidate.probs)();
   if (result >= candidate.decision) {
     result += 1;
   }

diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
index e2f3689d2854..c806772165a9 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -26,9 +26,9 @@
 namespace tvm {
 namespace meta_schedule {

-using tir::Instruction;
-using tir::InstructionKind;
-using tir::Trace;
+using s_tir::Instruction;
+using s_tir::InstructionKind;
+using s_tir::Trace;

 /*!
* \brief Downcast the decision of Sample-Perfect-Tile to an array of integers @@ -119,7 +119,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, ICHECK_EQ(inst->attrs.size(), 1); ICHECK_EQ(inst->inputs.size(), 2); if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { - const auto* ann_val = inst->inputs[1].as(); + const auto* ann_val = inst->inputs[1].as(); ICHECK(ann_val); annotated.insert(ann_val); } @@ -197,11 +197,11 @@ ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, int x, y; // select source while (true) { - x = tir::SampleInt(rand_state, 0, n_splits); + x = s_tir::SampleInt(rand_state, 0, n_splits); if (tiles[x] <= 1) { continue; } - y = tir::SampleInt(rand_state, 0, n_splits - 1); + y = s_tir::SampleInt(rand_state, 0, n_splits - 1); if (y >= x) { ++y; } @@ -209,7 +209,7 @@ ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, // Step 2. Choose the divide factor int64_t divide_factor; if (y != n_splits - 1) { - divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())]; + divide_factor = factors[s_tir::SampleInt(rand_state, 1, factors.size())]; } else { int64_t limit = Downcast(inst->attrs[1])->value; int max_factor_index = static_cast(factors.size()) - 1; @@ -225,7 +225,7 @@ ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, // Failed on this dst_idx, try next one. continue; } - divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)]; + divide_factor = factors[s_tir::SampleInt(rand_state, 1, max_factor_index + 1)]; } tiles[x] /= divide_factor; tiles[y] *= divide_factor; @@ -240,7 +240,7 @@ ffi::Optional MutateSampleVectorize(const Trace& trace, Instruction inst, std::vector probs = support::AsVector(Downcast>(inst->attrs[1])); probs.erase(probs.begin() + original_decision); - int result = tir::MakeMultinomialSampler(rand_state, probs)(); + int result = s_tir::MakeMultinomialSampler(rand_state, probs)(); if (result >= original_decision) { result += 1; } @@ -259,7 +259,7 @@ ffi::Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* r if (size_a == 0 && size_b == 0) { return std::nullopt; } - int n = tir::SampleInt(rand_state, 0, size_a + size_b); + int n = s_tir::SampleInt(rand_state, 0, size_a + size_b); if (n < size_a) { return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n], rand_state); diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index dab987708238..87e7dc43b716 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Check if an instruction is annotated with @@ -36,18 +37,18 @@ bool IsAnnotateWithUnroll(const Instruction& inst) { } ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); - return ann_key == attr::meta_schedule_unroll_explicit || - ann_key == attr::meta_schedule_unroll_implicit; + return ann_key == tir::attr::meta_schedule_unroll_explicit || + ann_key == tir::attr::meta_schedule_unroll_implicit; } -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { namespace meta_schedule { -using tir::Instruction; -using tir::Trace; +using s_tir::Instruction; +using s_tir::Trace; /*!
\brief Create a Mutator that mutates auto unroll step */ class MutateUnrollNode : public MutatorNode { @@ -90,8 +91,8 @@ struct MutateUnrollNode::Candidate { */ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, MutateUnrollNode::Candidate* candidate) { - using tir::InstructionKind; - using tir::InstructionNode; + using s_tir::InstructionKind; + using s_tir::InstructionNode; static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); std::unordered_map sample_insts; std::vector ann_insts; @@ -110,7 +111,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, if (n_ann_insts == 0) { return false; } - const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + const InstructionNode* ann_inst = ann_insts[s_tir::SampleInt(rand_state, 0, n_ann_insts)]; ICHECK_EQ(ann_inst->inputs.size(), 2); const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode); ICHECK(sample_insts.count(var_rv)); @@ -133,7 +134,7 @@ ffi::Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* ran return std::nullopt; } candidate.probs.erase(candidate.probs.begin() + candidate.decision); - int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); + int result = s_tir::MakeMultinomialSampler(rand_state, candidate.probs)(); if (result >= candidate.decision) { result += 1; } diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index fd8fe45bf185..7ddd8ac34ad6 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -29,8 +29,8 @@ void PyMutatorNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -ffi::Optional PyMutatorNode::Apply( - const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) { +ffi::Optional PyMutatorNode::Apply( + const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) { ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; return f_apply(trace, *rand_state); } @@ -98,7 +98,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("meta_schedule.MutatorInitializeWithTuneContext", &MutatorNode::InitializeWithTuneContext) .def("meta_schedule.MutatorApply", - [](Mutator self, tir::Trace trace, TRandState seed) -> ffi::Optional { + [](Mutator self, s_tir::Trace trace, TRandState seed) -> ffi::Optional { TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 956e5ddcb5a6..5863ad2bf896 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! \brief Check if an IRModule has any async strided mem copies. 
*/ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { @@ -114,7 +115,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { ffi::Map input_iters = ffi::Map(); }; -} // namespace tir +} // namespace s_tir namespace meta_schedule { @@ -128,7 +129,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { this->target = context->target.value(); } // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final { + bool Apply(const s_tir::Schedule& sch) final { IRModule mod = sch->mod(); for (const auto& kv : mod->functions) { const GlobalVar& g_var = kv.first; @@ -161,7 +162,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { } catch (const dmlc::Error& e) { return false; } - if (tir::AsyncStridedMemCopyFinder::Find(lowered)) { + if (s_tir::AsyncStridedMemCopyFinder::Find(lowered)) { return false; } } diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index df7344455e6d..60602bf87d4c 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! \brief Check if an IRModule has any dynamic loop. */ struct DynamicExtentFinder : private StmtVisitor { @@ -58,7 +59,7 @@ struct DynamicExtentFinder : private StmtVisitor { bool found_ = false; }; -} // namespace tir +} // namespace s_tir namespace meta_schedule { @@ -68,7 +69,9 @@ class DisallowDynamicLoopNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + bool Apply(const s_tir::Schedule& sch) final { + return !s_tir::DynamicExtentFinder::Find(sch->mod()); + } // Inherited from PostprocNode Postproc Clone() const { ObjectPtr n = ffi::make_object(*this); diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 41557830afb6..b265178114a7 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -29,7 +29,7 @@ void PyPostprocNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -bool PyPostprocNode::Apply(const tir::Schedule& sch) { +bool PyPostprocNode::Apply(const s_tir::Schedule& sch) { ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; return f_apply(sch); } diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index 6b2fa17bd20b..80fefec43ae6 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
* \brief Parse instruction: sch.bind(..., axis) @@ -61,7 +62,7 @@ ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& in ICHECK_EQ(inst->inputs.size(), 2); ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); - if (ann_key != attr::meta_schedule_cooperative_fetch) { + if (ann_key != tir::attr::meta_schedule_cooperative_fetch) { return std::nullopt; } *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; @@ -82,7 +83,7 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { ICHECK_EQ(inst->inputs.size(), 2); ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); - return ann_key == attr::warp_execution; + return ann_key == tir::attr::warp_execution; } size_t GetMaxUsedDtypeBytes(SBlock block) { @@ -108,7 +109,7 @@ size_t GetMaxUsedDtypeBytes(SBlock block) { return max_bytes; } -} // namespace tir +} // namespace s_tir namespace meta_schedule { @@ -133,7 +134,7 @@ class RewriteCooperativeFetchNode : public PostprocNode { } // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final; + bool Apply(const s_tir::Schedule& sch) final; Postproc Clone() const { ObjectPtr n = ffi::make_object(*this); @@ -147,37 +148,37 @@ class RewriteCooperativeFetchNode : public PostprocNode { int thread_warp_size_ = -1; }; -bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { - tir::Trace trace = sch->trace().value(); +bool RewriteCooperativeFetchNode::Apply(const s_tir::Schedule& sch) { + s_tir::Trace trace = sch->trace().value(); int64_t thread_extent_x = -1; int64_t thread_extent_y = -1; int64_t vector_lane = 1; std::vector> tasks; - for (const tir::Instruction& inst : trace->insts) { + for (const s_tir::Instruction& inst : trace->insts) { if (ffi::Optional new_thread_extent = - tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { + s_tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { thread_extent_x = new_thread_extent.value()->value; continue; } if (ffi::Optional new_thread_extent = - tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { + s_tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { thread_extent_y = new_thread_extent.value()->value; continue; } - if (tir::ParseWarpExecutionAnn(sch, inst)) { + if (s_tir::ParseWarpExecutionAnn(sch, inst)) { thread_extent_x = thread_warp_size_; continue; } - ffi::Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); + ffi::Optional opt_block_rv = s_tir::ParseAnnotate(sch, inst, &vector_lane); if (!opt_block_rv.defined()) { continue; } auto task = [thread_extent_x, thread_extent_y, vector_lane, sch, block = opt_block_rv.value()]() mutable -> void { sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch); - tir::LoopRV fused = sch->GetLoops(block).back(); + s_tir::LoopRV fused = sch->GetLoops(block).back(); int64_t fused_extent = -1; - if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) { + if (const int64_t* extent = s_tir::GetLoopIntExtent(sch->Get(fused).get())) { fused_extent = *extent; } else { return; @@ -189,34 +190,34 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { // vectorization of 64 bit values does not work well on CUDA. // TODO(masahi, vinx13): Decouple epilogue fusion computation and shared to global store, so // that we can always vectorize the latter. 
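// (Illustrative aside, not part of this patch: after the namespace move, the
// task constructed below is still the plain split/bind/vectorize recipe, e.g.
// for a hypothetical extent pair thread_extent_x=128, vector_lane=4:
//   ffi::Array<s_tir::LoopRV> split =
//       sch->Split(fused, {std::nullopt, Integer(128), Integer(4)});
//   sch->Vectorize(split[2]);            // innermost lanes become a vector access
//   sch->Bind(split[1], "threadIdx.x");  // middle loop carries the thread index
// The literal extents above are assumptions for exposition only.)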
- if (tir::GetMaxUsedDtypeBytes(sch->Get(block)) > 4) { + if (s_tir::GetMaxUsedDtypeBytes(sch->Get(block)) > 4) { vector_lane = 1; } if (thread_extent_y != -1) { if (vector_lane > 1) { - ffi::Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x), // - Integer(vector_lane)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_y), // + Integer(thread_extent_x), // + Integer(vector_lane)}); sch->Vectorize(split[3]); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } else { - ffi::Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_y), // + Integer(thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } } else { if (vector_lane > 1) { - ffi::Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_x), // - Integer(vector_lane)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_x), // + Integer(vector_lane)}); sch->Vectorize(split[2]); sch->Bind(split[1], "threadIdx.x"); } else { - ffi::Array split = + ffi::Array split = sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); sch->Bind(split[1], "threadIdx.x"); } diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 88fe0419b5ca..7f5890594c69 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -24,7 +24,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Collect the block and index where the buffer is read. @@ -120,7 +121,7 @@ class LayoutFreeBufferCollector : public StmtVisitor { ffi::Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { // Only rewrite PrimFuncs with attr "layout_free_buffers" ffi::Array layout_free_buffer_index = - func->GetAttr(attr::layout_free_buffers, ffi::Array()).value(); + func->GetAttr(tir::attr::layout_free_buffers, ffi::Array()).value(); ffi::Array layout_free_buffers; for (const Integer& index : layout_free_buffer_index) { @@ -185,7 +186,7 @@ bool RewriteLayout(const Schedule& sch) { std::vector> results; auto add_layout_rewrite_block = [&sch](SBlockRV consumer_block_rv, int buffer_index) { SBlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); - sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, true); + sch->Annotate(rewrite_block_rv, tir::attr::meta_schedule_layout_rewrite_preproc, true); }; for (const auto& [g_var, base_func] : sch->mod()->functions) { @@ -242,7 +243,7 @@ bool RewriteLayout(const Schedule& sch) { return true; } -} // namespace tir +} // namespace s_tir namespace meta_schedule { /*! \brief Layout Rewrite. 
*/ @@ -252,9 +253,9 @@ class RewriteLayoutNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final { + bool Apply(const s_tir::Schedule& sch) final { try { - return tir::RewriteLayout(sch); + return s_tir::RewriteLayout(sch); } catch (const std::runtime_error& e) { return false; } diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index b2cbf2701043..981af265bb98 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Check whether the loop has any annotation @@ -106,22 +107,22 @@ bool ParseAnnotation(const SBlock& block, ParsedAnnotation* parsed) { bool found = false; *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; for (const auto& ann : block->annotations) { - if (ann.first == attr::meta_schedule_parallel) { + if (ann.first == tir::attr::meta_schedule_parallel) { found = true; if (auto opt_int_imm = ann.second.try_cast()) { parsed->max_parallel_extent = (*opt_int_imm)->value; } - } else if (ann.first == attr::meta_schedule_vectorize) { + } else if (ann.first == tir::attr::meta_schedule_vectorize) { found = true; if (auto opt_int_imm = ann.second.try_cast()) { parsed->max_vectorize_extent = (*opt_int_imm)->value; } - } else if (ann.first == attr::meta_schedule_unroll_explicit) { + } else if (ann.first == tir::attr::meta_schedule_unroll_explicit) { found = true; if (auto opt_int_imm = ann.second.try_cast()) { parsed->unroll_explicit = (*opt_int_imm)->value; } - } else if (ann.first == attr::meta_schedule_unroll_implicit) { + } else if (ann.first == tir::attr::meta_schedule_unroll_implicit) { found = true; if (auto opt_int_imm = ann.second.try_cast()) { parsed->unroll_implicit = (*opt_int_imm)->value; @@ -134,16 +135,16 @@ bool ParseAnnotation(const SBlock& block, ParsedAnnotation* parsed) { void RemoveParsedAnn(const Schedule& sch, const SBlockRV& block_rv, const ParsedAnnotation& parsed) { if (parsed.max_parallel_extent != -1) { - sch->Unannotate(block_rv, attr::meta_schedule_parallel); + sch->Unannotate(block_rv, tir::attr::meta_schedule_parallel); } if (parsed.max_vectorize_extent != -1) { - sch->Unannotate(block_rv, attr::meta_schedule_vectorize); + sch->Unannotate(block_rv, tir::attr::meta_schedule_vectorize); } if (parsed.unroll_explicit != -1) { - sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); + sch->Unannotate(block_rv, tir::attr::meta_schedule_unroll_explicit); } if (parsed.unroll_implicit != -1) { - sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); + sch->Unannotate(block_rv, tir::attr::meta_schedule_unroll_implicit); } } @@ -400,44 +401,45 @@ void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const return; } - sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); - sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); + sch->Annotate(loop, tir::attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, tir::attr::pragma_unroll_explicit, + IntImm(DataType::Int(32), unroll_explicit)); } -} // namespace tir +} // namespace s_tir namespace meta_schedule { -using tir::Schedule; +using 
s_tir::Schedule; class RewriteParallelVectorizeUnrollNode : public PostprocNode { public: void InitializeWithTuneContext(const TuneContext& context) final {} bool Apply(const Schedule& sch) final { - tir::ParsedAnnotation parsed_root; - tir::SBlockRV root_rv{ffi::UnsafeInit()}; - while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { - for (tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { - ffi::Array loop_rvs = sch->GetLoops(block_rv); + s_tir::ParsedAnnotation parsed_root; + s_tir::SBlockRV root_rv{ffi::UnsafeInit()}; + while (s_tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { + for (s_tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { + ffi::Array loop_rvs = sch->GetLoops(block_rv); if (loop_rvs.empty()) { continue; } - tir::ParsedAnnotation parsed = parsed_root; - tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + s_tir::ParsedAnnotation parsed = parsed_root; + s_tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); const int loops_num = loop_rvs.size(); try { if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { // Fuse, split, vectorize and parallelize - tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); + s_tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); } else { // Parallel if (parsed.num_parallel_loops > 0) { - tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + s_tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); } // Vectorize if (parsed.num_vectorize_loops > 0) { - tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + s_tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); } } // AutoUnroll @@ -445,9 +447,9 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); int unroll_explicit = parsed.unroll_explicit != -1; int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; - tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); + s_tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); } - } catch (const tir::ScheduleError& e) { + } catch (const s_tir::ScheduleError& e) { DLOG(WARNING) << "Failed to apply parallelization/vectorization: " << e.what(); return false; } diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index f65ca90e7783..a7a0e6ca2b04 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
\brief The visitor that finds all the reduction blocks to be decomposed */ struct ReductionBlockFinder : private StmtVisitor { @@ -102,7 +103,7 @@ int FindDecomposePoint(const StmtSRef& block_sref) { return -1; } -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { @@ -119,7 +120,7 @@ class RewriteReductionBlockNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final; + bool Apply(const s_tir::Schedule& sch) final; Postproc Clone() const { ObjectPtr n = ffi::make_object(*this); @@ -130,29 +131,29 @@ RewriteReductionBlockNode, PostprocNode); }; -bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { +bool RewriteReductionBlockNode::Apply(const s_tir::Schedule& sch) { for (;;) { std::vector> results = - tir::ReductionBlockFinder::Find(sch->state()); + s_tir::ReductionBlockFinder::Find(sch->state()); int rewritten = 0; for (const auto& kv : results) { const tir::StmtSRef& block_sref = kv.first; const ffi::String& global_var_name = kv.second; - int decompose_point = tir::FindDecomposePoint(block_sref); + int decompose_point = s_tir::FindDecomposePoint(block_sref); if (decompose_point == -1) { continue; } - tir::SBlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); - ffi::Array loop_rvs = sch->GetLoops(block_rv); - tir::SBlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + s_tir::SBlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + ffi::Array loop_rvs = sch->GetLoops(block_rv); + s_tir::SBlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); // Rewrite auto tensorization related annotations - if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize) + if (s_tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize) .has_value()) { // Remove tensorization annotation as it shouldn't be propagated to the init block. sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); ffi::Optional tensorize_init = - tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); + s_tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); // The annotation of tensorization of the init statement should be moved to the init block // after 'DecomposeReduction'. // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is std::nullopt.
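// A minimal sketch of the decompose-and-reannotate flow above, assuming a
// reduction block annotated with tir::attr::meta_schedule_auto_tensorize; the
// helper name ExampleDecompose and the bare index argument are illustrative,
// not part of this patch:
void ExampleDecompose(tvm::s_tir::Schedule sch, tvm::s_tir::SBlockRV block_rv,
                      int decompose_point) {
  using namespace tvm;
  ffi::Array<s_tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
  // Split the reduction into a separate init block at the chosen loop depth.
  s_tir::SBlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);
  // The tensorize hint must not leak onto the init block.
  sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
}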
diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index e3490af29072..bc0b07e590d3 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -26,21 +26,21 @@ namespace tvm { namespace meta_schedule { -using tir::LoopRV; -using tir::SBlockRV; +using s_tir::LoopRV; +using s_tir::SBlockRV; void CollectTensorizationJobs( - const tir::Schedule& sch, const ffi::String& func_name, const tir::PrimFuncNode* func, + const s_tir::Schedule& sch, const ffi::String& func_name, const tir::PrimFuncNode* func, bool vectorize_init_loop, - std::vector>>* jobs) { + std::vector>>* jobs) { tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { if (const auto* block = obj.as()) { tir::StmtSRef block_sref = sch->GetSRef(block); std::string block_name = block_sref->StmtAs()->name_hint; if (ffi::Optional intrin_name = - tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { + s_tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { if (intrin_name.value() != "") { - jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::SBlockRV block) { + jobs->emplace_back(block_name, func_name, [sch, intrin_name](s_tir::SBlockRV block) { try { sch->Tensorize(block, intrin_name.value()); } catch (const std::exception& e) { @@ -48,7 +48,7 @@ void CollectTensorizationJobs( } }); } else if (block_name.find("init") && vectorize_init_loop) { - jobs->emplace_back(block_name, func_name, [sch](tir::SBlockRV block) { + jobs->emplace_back(block_name, func_name, [sch](s_tir::SBlockRV block) { ffi::Array child_blocks = sch->GetChildBlocks(block); ICHECK(child_blocks.size() == 1); ffi::Array init_loops = sch->GetLoops(child_blocks[0]); @@ -70,7 +70,7 @@ class RewriteTensorizeNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} - bool Apply(const tir::Schedule& sch) final; + bool Apply(const s_tir::Schedule& sch) final; Postproc Clone() const { ObjectPtr n = ffi::make_object(*this); @@ -83,9 +83,9 @@ class RewriteTensorizeNode : public PostprocNode { PostprocNode); }; -bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { +bool RewriteTensorizeNode::Apply(const s_tir::Schedule& sch) { // The rewriting jobs, 3-tuple (block_name, func_name, job_func) - std::vector>> jobs; + std::vector>> jobs; for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 08580830965b..3bb11e6f7d85 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
\brief Find all the blocks that are not bound */ class UnboundBlockFinder : private StmtVisitor { @@ -77,7 +78,7 @@ class UnboundBlockFinder : private StmtVisitor { ffi::String global_var_name_; }; -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { @@ -97,7 +98,7 @@ class RewriteUnboundBlockNode : public PostprocNode { } // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final; + bool Apply(const s_tir::Schedule& sch) final; Postproc Clone() const { ObjectPtr n = ffi::make_object(*this); @@ -118,17 +119,17 @@ class RewriteUnboundBlockNode : public PostprocNode { PostprocNode); }; -bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { - using tir::ExprRV; - using tir::LoopRV; - using tir::SBlockRV; - using tir::Schedule; +bool RewriteUnboundBlockNode::Apply(const s_tir::Schedule& sch) { + using s_tir::ExprRV; + using s_tir::LoopRV; + using s_tir::SBlockRV; + using s_tir::Schedule; ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { return Integer(std::min(t, max_extent)); }; std::vector> unbound_blocks = - tir::UnboundBlockFinder::Find(sch->state()); + s_tir::UnboundBlockFinder::Find(sch->state()); for (const auto& kv : unbound_blocks) { tir::StmtSRef block_sref = kv.first; ffi::String global_var_name = kv.second; diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index f1ff28b071ff..1532fb0b156e 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -23,7 +23,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class ThreadExtentChecker : private StmtVisitor { public: @@ -71,13 +72,13 @@ class ThreadExtentChecker : private StmtVisitor { void VisitStmt_(const SBlockNode* block) { int old_thread_idx_x = thread_idx_x; - if (block->annotations.count(attr::warp_execution)) { + if (block->annotations.count(tir::attr::warp_execution)) { thread_idx_x = thread_warp_size_; } if (ffi::Optional low_inclusive = - GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { + GetAnn(block, tir::attr::meta_schedule_thread_extent_low_inclusive)) { if (ffi::Optional high_inclusive = - GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { + GetAnn(block, tir::attr::meta_schedule_thread_extent_high_inclusive)) { int64_t low = low_inclusive.value()->value; int64_t high = high_inclusive.value()->value; int64_t thread_extent_product = thread_idx_x * thread_idx_y * thread_idx_z; @@ -96,7 +97,7 @@ class ThreadExtentChecker : private StmtVisitor { int thread_warp_size_ = -1; }; -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { @@ -142,13 +143,13 @@ class VerifyGPUCodeNode : public PostprocNode { return true; } - bool Apply(const tir::Schedule& sch) final { + bool Apply(const s_tir::Schedule& sch) final { IRModule mod = sch->mod(); for (const auto& kv : mod->functions) { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { - if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { + if (!s_tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { return false; } IRModule lowered{ffi::UnsafeInit()}; diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index f0fe8be1c1c9..b337452682e1 100644 --- 
a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -43,7 +43,7 @@ class VerifyVTCMLimitNode : public PostprocNode { return true; } - bool Apply(const tir::Schedule& sch) final { + bool Apply(const s_tir::Schedule& sch) final { IRModule mod = sch->mod(); IRModule lowered{nullptr}; auto pass_list = tir::GetVTCMCompactionPasses(); diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index dfa5a3969118..eb0fede257dd 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -25,9 +25,14 @@ namespace tvm { namespace meta_schedule { using namespace tvm::tir; +using s_tir::ExprRV; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; -static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::SBlockRV block, - std::vector tiled, std::vector unrolled) { +static ffi::Array ScheduleDataPack(s_tir::Schedule sch, s_tir::SBlockRV block, + std::vector tiled, + std::vector unrolled) { using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index d80fefc6cc5d..6ef9016a5d01 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -17,8 +17,8 @@ * under the License. */ #include +#include #include -#include #include #include @@ -30,6 +30,17 @@ namespace tvm { namespace meta_schedule { using namespace tvm::tir; +using s_tir::ExprRV; +using s_tir::GetLoopIterType; +using s_tir::GetLoops; +using s_tir::GetThreadScope; +using s_tir::HasBeenMultiLevelTiled; +using s_tir::IsBlockIdx; +using s_tir::IsSingleStmt; +using s_tir::IsThreadIdx; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; std::function MakeFactorSampler(Schedule sch, ffi::Array thread_extents) { return [sch = std::move(sch), @@ -84,18 +95,17 @@ ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_thread } } -void BindBlockThreadIdx(tir::Schedule sch, tir::SBlockRV block_rv, // +void BindBlockThreadIdx(Schedule sch, SBlockRV block_rv, // int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor) { - using namespace tvm::tir; + std::function get_factor) { StmtSRef block_sref = sch->GetSRef(block_rv); if (block_sref->parent == nullptr) { return; } - if (tir::HasBeenMultiLevelTiled(block_sref)) { + if (HasBeenMultiLevelTiled(block_sref)) { return; } - ffi::Array loops = tir::GetLoops(block_sref); + ffi::Array loops = GetLoops(block_sref); int n = loops.size(); int i_block_idx = -1; int i_thread_idx = -1; diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 62d8c767e293..c78a9e3bfb33 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -28,9 +28,14 @@ namespace tvm { namespace meta_schedule { using namespace tvm::tir; +using s_tir::ExprRV; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; -static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::SBlockRV block, - std::vector tiled, std::vector unrolled) { +static ffi::Array ScheduleDataPack(s_tir::Schedule sch, s_tir::SBlockRV block, + std::vector tiled, + std::vector unrolled) { // This method is used for NHWC layout only. 
Will likely be refactored into a more schedule using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); diff --git a/src/meta_schedule/schedule/generic/winograd.cc b/src/meta_schedule/schedule/generic/winograd.cc index a3c75f33cb53..5d4c36387acc 100644 --- a/src/meta_schedule/schedule/generic/winograd.cc +++ b/src/meta_schedule/schedule/generic/winograd.cc @@ -22,6 +22,10 @@ namespace tvm { namespace meta_schedule { using namespace tvm::tir; +using s_tir::ExprRV; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; /*! * \brief Get the producer block of a given block. diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 2b730b0138a2..18a5d274de32 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -36,7 +36,7 @@ class AddRFactorNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv); + ffi::Array Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv); // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { @@ -77,8 +77,8 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, return ScheduleRule(n); } -ffi::Array AddRFactorNode::Apply(const tir::Schedule& sch, - const tir::SBlockRV& block_rv) { +ffi::Array AddRFactorNode::Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) { tir::StmtSRef block_sref = sch->GetSRef(block_rv); if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, max_parallel_basic_)) { @@ -86,28 +86,28 @@ ffi::Array AddRFactorNode::Apply(const tir::Schedule& sch, } // Make a copy of the original schedule. - tir::Schedule ori_sch = sch->Copy(); + s_tir::Schedule ori_sch = sch->Copy(); ori_sch->Seed(sch->ForkSeed()); // Reorder the loop axes if reduction loops are not innermost. // After the reordering, fuse all the reduction loops. size_t num_spatial_loops; - tir::LoopRV fused_reduce_loop; + s_tir::LoopRV fused_reduce_loop; ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Split the fused reduction loop. 
- ffi::Array factors = + ffi::Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); - ffi::Array split_loops = + ffi::Array split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); - ffi::Array res; - for (const tir::LoopRV& split_loop : split_loops) { - tir::Schedule sch_tmp = sch->Copy(); + ffi::Array res; + for (const s_tir::LoopRV& split_loop : split_loops) { + s_tir::Schedule sch_tmp = sch->Copy(); sch_tmp->Seed(sch->ForkSeed()); try { - const tir::SBlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); - ffi::Array axes = sch_tmp->GetLoops(block_rf); + const s_tir::SBlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); + ffi::Array axes = sch_tmp->GetLoops(block_rf); ICHECK_GT(axes.size(), num_spatial_loops); // Annotate that the rfactor block, which is now the producer of the original block, needs to diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index bdfd9b525690..57290cb98803 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -36,18 +36,19 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final { CHECK(this->target_.defined()) << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target."; ffi::Array keys = this->target_.value()->keys; if (ffi::Optional ann = - tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { + s_tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { if (ann.value() != "None") { for (const ffi::String& key : keys) { if (const auto custom_schedule_fn = tvm::ffi::Function::GetGlobal(GetCustomRuleName(ann.value(), key))) { - ffi::Array result = - (*custom_schedule_fn)(sch, block_rv).cast>(); + ffi::Array result = + (*custom_schedule_fn)(sch, block_rv).cast>(); return result; } } diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 2fbf013e82da..6f14b8a5d474 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -40,7 +40,8 @@ class AutoBindNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final; + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { @@ -63,8 +64,8 @@ class AutoBindNode : public ScheduleRuleNode { TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.AutoBind", AutoBindNode, ScheduleRuleNode); }; -ffi::Array AutoBindNode::Apply(const tir::Schedule& sch, - const tir::SBlockRV& block_rv) { +ffi::Array AutoBindNode::Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) { ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = MakeFactorSampler(sch, this->thread_extents_); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 5c065e6b4738..65aabbf6023b 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -23,6 +23,19 @@ namespace tvm { namespace 
meta_schedule { +using s_tir::AnalyzeReadWritePattern; +using s_tir::CanComputeInline; +using s_tir::CanReverseComputeInline; +using s_tir::GetAnn; +using s_tir::GetConsumers; +using s_tir::GetProducers; +using s_tir::GetRootPrimFunc; +using s_tir::GetSBlockRealize; +using s_tir::HasIfThenElse; +using s_tir::HasOp; +using s_tir::IsSpatialPrimFunc; +using s_tir::ScheduleState; + /*! \brief The type of inline to be performed on a specific block */ enum class InlineType : int32_t { /*! \brief No inline opportunity */ @@ -33,7 +46,7 @@ enum class InlineType : int32_t { kInlineIntoProducer = 2, }; -bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sref) { +bool IsInSpatialPrimFunc(const s_tir::Schedule& sch, const tir::StmtSRef& block_sref) { using namespace tvm::tir; const StmtSRefNode* sref = block_sref.get(); for (; sref->parent != nullptr; sref = sref->parent) { @@ -46,13 +59,14 @@ bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sr class AutoInlineNode : public ScheduleRuleNode { public: /*! \brief Checks if the specific block should be inlined */ - inline InlineType CheckInline(const tir::Schedule& sch, const tir::SBlockRV& block_rv); + inline InlineType CheckInline(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv); // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final { InlineType inline_type = CheckInline(sch, block_rv); if (inline_type == InlineType::kInlineIntoConsumer) { sch->ComputeInline(block_rv); @@ -98,8 +112,8 @@ class AutoInlineNode : public ScheduleRuleNode { TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.AutoInline", AutoInlineNode, ScheduleRuleNode); }; -inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, - const tir::SBlockRV& block_rv) { +inline InlineType AutoInlineNode::CheckInline(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) { using namespace tvm::tir; StmtSRef block_sref = sch->GetSRef(block_rv); bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref); @@ -143,12 +157,12 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, } // Cond 6. 
The block is disallowed for auto inline if (ffi::Optional ann = - tir::GetAnn(block_sref, tir::attr::meta_schedule_inline_rule)) { + s_tir::GetAnn(block_sref, tir::attr::meta_schedule_inline_rule)) { if (ann.value() == "disable") return InlineType::kNoInline; } // Last cond: Check inline into the consumers or the spatial producer - tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, - /*require_stage_pipeline=*/false); + tir::StmtSRef scope_block = s_tir::GetScopeRoot(sch->state(), block_sref, + /*require_stage_pipeline=*/false); if (into_consumer) { ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { @@ -158,7 +172,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, if (into_producer) { ffi::Array producer_srefs = GetProducers(state, block_sref); if (producer_srefs.size() == 1 && - tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && + s_tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && CanReverseComputeInline(state, block_sref) && !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize) .has_value()) { @@ -205,7 +219,8 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { public: void InitializeWithTuneContext(const TuneContext& context) final {} - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final { // Look for a block of the form // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) { // reads([]) @@ -216,7 +231,8 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { if (block->reads.size() == 0 && block->writes.size() == 1 && block->writes[0]->buffer->shape.size() == 0) { auto sref = sch->GetSRef(block_rv); - if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) { + if (!s_tir::IsOutputBlock(sch->state(), sref, + s_tir::GetScopeRoot(sch->state(), sref, true))) { sch->ComputeInline(block_rv); } } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 1d70f21199a4..43dfafd6f53a 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -49,7 +49,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final { // Step 0. Check the conditions of this rule. if (max_threads_per_block == -1 || warp_size == -1) { return {sch}; @@ -61,7 +62,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Step 1. Make a copy of the original schedule. The new copy is used for scheduling. - tir::Schedule tmp_sch = sch->Copy(); + s_tir::Schedule tmp_sch = sch->Copy(); tmp_sch->Seed(sch->ForkSeed()); // Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the @@ -77,7 +78,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. 
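// (Sketch for orientation, not in this patch: with the s_tir types, steps 3-5
// below reduce to sampling a categorical thread extent and splitting/binding
// with it:
//   s_tir::ExprRV te = tmp_sch->SampleCategorical(thread_extents, probs);
//   ffi::Array<s_tir::LoopRV> sr = tmp_sch->Split(loop, {std::nullopt, te});
//   tmp_sch->Bind(sr[1], "threadIdx.x");
// where `loop` stands for the innermost loop being rewritten.)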
int n_candidate = static_cast(thread_extents.size()); ffi::Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); - tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); + s_tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_sblock.defined()); ICHECK(target_loop.defined()); @@ -89,7 +90,7 @@ // the loop before binding. // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. if (!InThreadScope(tmp_sch, target_sblock)) { - const ffi::Array& split_res = + const ffi::Array& split_res = tmp_sch->Split(tgt_block_innermost_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); if (tgt_block_innermost_loop.same_as(target_loop)) { @@ -107,10 +108,10 @@ // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering, // fuse all the reduction loops. size_t num_spatial_loops; - tir::LoopRV fused_reduce_loop; + s_tir::LoopRV fused_reduce_loop; ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Step 5. Split the fused reduction loop and bind the inner one to threadIdx. - const ffi::Array& split_res = + const ffi::Array& split_res = tmp_sch->Split(fused_reduce_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); @@ -131,12 +132,12 @@ * \param block The block to be checked * \return A boolean indicating whether the block is in thread scope. */ - bool InThreadScope(const tir::Schedule& sch, const tir::SBlockRV& block) { - const ffi::Array& axes = sch->GetLoops(block); - for (const tir::LoopRV& loop_rv : axes) { + bool InThreadScope(const s_tir::Schedule& sch, const s_tir::SBlockRV& block) { + const ffi::Array& axes = sch->GetLoops(block); + for (const s_tir::LoopRV& loop_rv : axes) { const tir::For& loop = sch->Get(loop_rv); - runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get()); - if (tir::IsThreadIdx(thread_scope)) { + runtime::ThreadScope thread_scope = s_tir::GetThreadScope(loop.get()); + if (s_tir::IsThreadIdx(thread_scope)) { return true; } } @@ -150,16 +151,16 @@ * \param extent The finding result * \return Whether the search is successful.
*/ - bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop, - tir::ExprRV* extent) { - for (const tir::Instruction& inst : trace->insts) { + bool GetLoopRVExtentSource(const s_tir::Trace& trace, const s_tir::LoopRV& loop, + s_tir::ExprRV* extent) { + for (const s_tir::Instruction& inst : trace->insts) { if (inst->kind->name == "Split") { auto fcheck = [&](const Any& a) -> bool { return a.as() == loop.get(); }; int i = std::find_if(inst->outputs.begin(), inst->outputs.end(), fcheck) - inst->outputs.begin(); CHECK(inst->inputs[1 + i] != nullptr) << "ValueError: Extracting an extent which needs inference is not supported so far"; - *extent = Downcast(inst->inputs[1 + i]); + *extent = Downcast(inst->inputs[1 + i]); return true; } } @@ -171,11 +172,11 @@ * \param trace The trace of the schedule, where the extent is to be found * \return The extent of "threadIdx.x" in the input schedule */ - tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { - tir::ExprRV extent{ffi::UnsafeInit()}; - for (const tir::Instruction& inst : trace->insts) { + s_tir::ExprRV GetThreadIdxExtentFromTrace(const s_tir::Trace& trace) { + s_tir::ExprRV extent{ffi::UnsafeInit()}; + for (const s_tir::Instruction& inst : trace->insts) { if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { - if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { + if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { return extent; } } @@ -194,23 +195,24 @@ * 3. the first block under the target loop when fusible, or a null block random variable; * 4. the innermost loop outside the target block when fusible, or a null block random variable. */ - std::tuple GetComputeTargetLoopAndBlock( - const tir::Schedule& sch, const tir::SBlockRV& block_rv) { + std::tuple GetComputeTargetLoopAndBlock( + const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) { // Step 0. Due to technical reasons of some primitives (e.g., compute-at), if the block is doing // a tuple reduction, fusion is temporarily not supported. if (sch->Get(block_rv)->writes.size() != 1) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, - tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); + return std::make_tuple(false, s_tir::LoopRV{ffi::UnsafeInit()}, - s_tir::SBlockRV{ffi::UnsafeInit()}, s_tir::LoopRV{ffi::UnsafeInit()}); + s_tir::SBlockRV{ffi::UnsafeInit()}, s_tir::LoopRV{ffi::UnsafeInit()}); } // Step 1. Get all the consumers of the input block. - ffi::Array consumers = sch->GetConsumers(block_rv); + ffi::Array consumers = sch->GetConsumers(block_rv); // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is // not fusible. - if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, - tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); + if (consumers.empty() || + s_tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { + return std::make_tuple(false, s_tir::LoopRV{ffi::UnsafeInit()}, + s_tir::SBlockRV{ffi::UnsafeInit()}, s_tir::LoopRV{ffi::UnsafeInit()}); } // Step 3. Calculate the lowest common ancestor of all the consumers. @@ -220,20 +222,20 @@ // fusible; // - If the lowest common ancestor is a loop, the target block is also the first consumer.
const tir::StmtSRef& lca_sref = - tir::GetSRefLowestCommonAncestor(tir::SBlockRVs2StmtSRefs(sch, consumers)); + s_tir::GetSRefLowestCommonAncestor(s_tir::SBlockRVs2StmtSRefs(sch, consumers)); if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, - tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); + return std::make_tuple(false, s_tir::LoopRV{ffi::UnsafeInit()}, + s_tir::SBlockRV{ffi::UnsafeInit()}, s_tir::LoopRV{ffi::UnsafeInit()}); } // Step 4. Get the outer loops of the target block, and get the compute-at position index. - ffi::Array tgt_block_loops = sch->GetLoops(consumers[0]); + ffi::Array tgt_block_loops = sch->GetLoops(consumers[0]); int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref); // Step 5. A negative position index means not fusible, and vice-versa. if (pos < 0) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, - tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); + return std::make_tuple(false, s_tir::LoopRV{ffi::UnsafeInit()}, + s_tir::SBlockRV{ffi::UnsafeInit()}, s_tir::LoopRV{ffi::UnsafeInit()}); } else { return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); } @@ -250,14 +252,14 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \param lca_sref The lowest common ancestor of all the consumers of the input block * \return The compute-at position index of the input block */ - int GetComputePosition(const tir::Schedule& sch, const ffi::Array& block_loops, - const ffi::Array& tgt_block_loops, + int GetComputePosition(const s_tir::Schedule& sch, const ffi::Array& block_loops, + const ffi::Array& tgt_block_loops, const tir::StmtSRef& lca_sref) { int n_block_loop = static_cast(block_loops.size()); int n_tgt_block_loop = static_cast(tgt_block_loops.size()); for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) { - if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) { + if (s_tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) { return i - 1; } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) { // If the lowest common ancestor is a loop, the compute location of the input block should diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index c1002b0ce2c0..62bdb4b09bb8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -28,7 +28,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; std::vector GetReadBufferNDims(const StmtSRef& block_sref) { const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); @@ -44,20 +45,23 @@ std::vector GetReadBufferNDims(const StmtSRef& block_sref) { return results; } -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { namespace meta_schedule { +using s_tir::GetSBlockVarTypes; +using s_tir::IsWriteCache; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; using tir::IterVarType; -using tir::LoopRV; -using tir::SBlockRV; -using tir::Schedule; TVM_FFI_STATIC_INIT_BLOCK() { MultiLevelTilingNode::RegisterReflection(); } -State::State(tir::Schedule sch, tir::SBlockRV block_rv, ffi::Array> tiles) { +State::State(s_tir::Schedule sch, s_tir::SBlockRV block_rv, + ffi::Array> tiles) { ObjectPtr node = ffi::make_object(); node->sch = 
std::move(sch); node->block_rv = std::move(block_rv); @@ -139,7 +143,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { } std::vector levels = config.levels; ReuseType req = config.req; - if (ffi::Optional> ann = tir::GetAnn>( + if (ffi::Optional> ann = s_tir::GetAnn>( state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) { req = ReuseType::kMustReuse; levels.clear(); @@ -181,14 +185,14 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -std::pair, ffi::Array> MultiLevelTilingNode::SplitLoop( +std::pair, ffi::Array> MultiLevelTilingNode::SplitLoop( const Schedule& sch, SBlockRV block, LoopRV loop, int n_tiles) const { - ffi::Array factors = sch->SamplePerfectTile( + ffi::Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); - ffi::Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } @@ -214,7 +218,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, ffi::Array skipped_outer_spatial_loops; std::vector> tiles(s_indices_.size() + r_indices_.size()); state->tile_factors.resize(tiles.size()); - std::vector> tile_factors; + std::vector> tile_factors; tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; @@ -228,7 +232,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, } idx = &s_indices_; if (spatial_loop_product != -1) { - if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + if (const int64_t* extent = s_tir::GetLoopIntExtent(sch->Get(loop).get())) { spatial_loop_product *= *extent; } else { spatial_loop_product = -1; @@ -298,7 +302,7 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { Schedule& sch = new_state->sch; const LoopRV& loop_rv = state->tiles[level - 1].back(); // Enumerate all buffers that are read but not written - std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + std::vector read_buffer_ndims = s_tir::GetReadBufferNDims(sch->GetSRef(block_rv)); for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { int buffer_ndim = read_buffer_ndims[i]; if (buffer_ndim == -1) { @@ -358,7 +362,7 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { } void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, - const tir::SBlockRV& block) const { + const s_tir::SBlockRV& block) const { // Filter out invalid vector lanes according to the data type. 
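// (Illustrative summary, not part of this patch: the sampling below ends with
//   s_tir::ExprRV lanes = (*sch)->SampleCategorical(valid_lens, probs);
//   (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, lanes);
// recording the chosen vector length for the RewriteCooperativeFetch
// postprocessor to consume; `valid_lens` and `probs` name the filtered
// candidate lengths and their uniform probabilities.)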
const tir::SBlockNode* block_node = (*sch)->GetSRef(block)->StmtAs(); ICHECK_EQ(block_node->writes.size(), 1); @@ -385,7 +389,7 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, if (!valid_vector_lens.empty()) { int n = valid_vector_lens.size(); double prob = 1.0 / n; - tir::ExprRV vector_load_len = + s_tir::ExprRV vector_load_len = (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), ffi::Array(n, FloatImm(DataType::Float(32), prob))); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 19bfbd51c187..f23872232745 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include @@ -30,7 +30,8 @@ #include "../../support/array.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction * buffers' dimensions as -1 @@ -41,7 +42,7 @@ namespace tir { */ std::vector GetReadBufferNDims(const StmtSRef& block_sref); -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { @@ -105,17 +106,17 @@ class State; class StateNode : public Object { public: /*! \brief The schedule to date */ - tir::Schedule sch; + s_tir::Schedule sch; /*! \brief The block to be tiled */ - tir::SBlockRV block_rv; + s_tir::SBlockRV block_rv; /*! \brief The loop tiles */ - ffi::Array> tiles; + ffi::Array> tiles; /*! \brief The factors of the loop tiles. */ - ffi::Array> tile_factors; + ffi::Array> tile_factors; /*! \brief The mapping from buffer index to read cache block. */ - std::unordered_map read_reuse; + std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. */ - std::unordered_map write_reuse; + std::unordered_map write_reuse; /*! * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that @@ -131,8 +132,8 @@ class StateNode : public Object { class State : public ObjectRef { public: /*! 
\brief Default constructor */ - explicit State(tir::Schedule sch, tir::SBlockRV block_rv, - ffi::Array> tiles = {}); + explicit State(s_tir::Schedule sch, s_tir::SBlockRV block_rv, + ffi::Array> tiles = {}); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(State, ObjectRef, StateNode); }; @@ -174,7 +175,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final; // Entry of the mega rule; Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) override; + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) override; // Inherited from ScheduleRuleNode ScheduleRule Clone() const override; @@ -182,11 +184,11 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual std::pair, ffi::Array> SplitLoop( - const tir::Schedule& sch, tir::SBlockRV block, tir::LoopRV loop, int n_tiles) const; + virtual std::pair, ffi::Array> SplitLoop( + const s_tir::Schedule& sch, s_tir::SBlockRV block, s_tir::LoopRV loop, int n_tiles) const; // Annotate a block to use cooperative fetching - void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::SBlockRV& block) const; + void AnnotateCooperativeFetching(s_tir::Schedule* sch, const s_tir::SBlockRV& block) const; public: /*! diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 85705ea99876..5860d0a3f4ba 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -30,10 +30,11 @@ namespace tvm { namespace meta_schedule { +using s_tir::GetSBlockVarTypes; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; using tir::IterVarType; -using tir::LoopRV; -using tir::SBlockRV; -using tir::Schedule; struct TensorCoreIntrinGroup { ffi::String init_intrin; @@ -77,13 +78,13 @@ class TensorCoreStateNode : public StateNode { /*! \brief The tensor core intrinsic group. */ TensorCoreIntrinGroup intrin_group; /*! \brief The auto tensorization mapping info. */ - tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()}; + s_tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()}; /*! \brief The Tensor Core reindex block A for Tensor Core computation */ - tir::SBlockRV tensor_core_reindex_A; + s_tir::SBlockRV tensor_core_reindex_A; /*! \brief The Tensor Core reindex block B for Tensor Core computation */ - tir::SBlockRV tensor_core_reindex_B; + s_tir::SBlockRV tensor_core_reindex_B; /*! \brief The Tensor Core reindex store block for Tensor Core computation */ - tir::SBlockRV tensor_core_reindex_store; + s_tir::SBlockRV tensor_core_reindex_store; /*! \brief Flag to indicate whether it's a WMMA or MMA intrin group */ bool is_mma; /*!
\brief Flag to indicate whether to use async software pipeline */ @@ -103,15 +104,15 @@ class TensorCoreStateNode : public StateNode { class TensorCoreState : public State { public: explicit TensorCoreState(TensorCoreIntrinGroup intrin_group, - tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, + s_tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, SBlockRV block_rv, bool use_async, - ffi::Array> tiles = {}); + ffi::Array> tiles = {}); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorCoreState, State, TensorCoreStateNode); }; TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, - tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, + s_tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, SBlockRV block_rv, bool use_async, ffi::Array> tiles) { ObjectPtr node = ffi::make_object(); @@ -153,7 +154,7 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { // Subrule: Add software pipeline inline std::vector AddSoftwarePipeline(TensorCoreState state) const; // Subrule: split loop for mma using sample partitioned tile - inline std::pair, ffi::Array> MMASplitLoop( + inline std::pair, ffi::Array> MMASplitLoop( const Schedule& sch, SBlockRV block, LoopRV loop, int n_tiles, int partition_pos, int innerpart_factor) const; // Subrule: tile loop nest for mma @@ -216,12 +217,13 @@ ffi::Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, return {sch}; } - std::unordered_map intrin_group_to_mapping_info; + std::unordered_map intrin_group_to_mapping_info; for (int i = 0, n = intrin_groups.size(); i < n; ++i) { TensorCoreIntrinGroup intrin_group = intrin_groups[i]; - ffi::Optional mapping_info = tir::GetAutoTensorizeMappingInfo( - sch->state(), sch->GetSRef(block_rv), - tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); + ffi::Optional mapping_info = + s_tir::GetAutoTensorizeMappingInfo( + sch->state(), sch->GetSRef(block_rv), + tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); if (mapping_info.defined()) { intrin_group_to_mapping_info.emplace(i, mapping_info.value()); } @@ -239,7 +241,7 @@ ffi::Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, std::vector initial_states; for (const auto& kv : intrin_group_to_mapping_info) { const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first]; - const tir::AutoTensorizeMappingInfo& mapping_info = kv.second; + const s_tir::AutoTensorizeMappingInfo& mapping_info = kv.second; Schedule new_sch = sch->Copy(); new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv, true)); @@ -288,7 +290,7 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); + ffi::Optional loop = s_tir::TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); ICHECK(loop.defined()); SBlockRV blockized_outer = (*sch)->Blockize(loop.value()); (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); @@ -311,7 +313,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta Schedule& sch = new_state->sch; const LoopRV& loop_rv = state->tiles[level - 1].back(); // Enumerate all buffers that are read but not written - std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + std::vector read_buffer_ndims = s_tir::GetReadBufferNDims(sch->GetSRef(block_rv)); for (int i = 0, n_reads = 
read_buffer_ndims.size(); i < n_reads; ++i) { int buffer_ndim = read_buffer_ndims[i]; if (buffer_ndim == -1) { @@ -331,17 +333,17 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta return results; } -std::pair, ffi::Array> +std::pair, ffi::Array> MultiLevelTilingTensorCoreNode::MMASplitLoop(const Schedule& sch, SBlockRV block, LoopRV loop, int n_tiles, int partition_pos, int innerpart_factor) const { - ffi::Array factors = sch->SamplePartitionedTile( + ffi::Array factors = sch->SamplePartitionedTile( /*loop=*/loop, /*n=*/n_tiles, /*partition_pos=*/partition_pos, /*innerpart_factor=*/innerpart_factor); - ffi::Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } @@ -360,7 +362,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta int64_t spatial_loop_product = 1; std::vector> tiles(s_indices_.size() + r_indices_.size()); state->tile_factors.resize(tiles.size()); - std::vector> tile_factors; + std::vector> tile_factors; tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; @@ -369,7 +371,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta if (iter_types[i] == IterVarType::kDataPar) { idx = &s_indices_; if (spatial_loop_product != -1) { - if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + if (const int64_t* extent = s_tir::GetLoopIntExtent(sch->Get(loop).get())) { spatial_loop_product *= *extent; } else { spatial_loop_product = -1; @@ -460,7 +462,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. // `loop_idx` can be negative, in which case it is counted from the end. 
auto f_get_inner_tile_product = [&](int loop_idx) { - ffi::Array factors; + ffi::Array factors; for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { auto s_factors = state->tile_factors[s_indices_[i]]; if (loop_idx < 0) { @@ -515,7 +517,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa result.push_back(accum_n); return result; }); - sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map, + sch->TransformLayout(state->block_rv, 0, s_tir::BufferIndexType::kWrite, index_map, /*pad_value=*/std::nullopt, /*assume_injective_transform=*/true); return {state}; @@ -613,12 +615,13 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( state->intrin_group.load_b_intrin); for (int i = 0; i < 2; ++i) { - const tir::SBlockRV cache_read = state->read_reuse.at(i); + const s_tir::SBlockRV cache_read = state->read_reuse.at(i); // Inline the reindex / padding block sch->ComputeInline(sch->GetProducers(cache_read)[0]); const tir::SBlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); - tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer( - sch->state(), ffi::GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); + tir::Buffer cache_read_buffer = + s_tir::GetNthAccessBuffer(sch->state(), ffi::GetRef(cache_read_block), 0, + s_tir::BufferIndexType::kWrite); const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { sch->StorageAlign(cache_read, 0, -2, 32, 8); @@ -658,7 +661,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( } for (int i = 0; i < 2; ++i) { - const tir::SBlockRV cache_read = state->read_reuse.at(i); + const s_tir::SBlockRV cache_read = state->read_reuse.at(i); if (state->is_mma) { // Add vector bytes for memhammer sch->Annotate(cache_read, tir::attr::vector_bytes, Integer(16)); @@ -764,7 +767,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( TensorCoreStateNode* state, const ffi::String& intrin_name) const { SBlockRV block_rv = state->block_rv; - const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; + const s_tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); // Add reindex stages @@ -776,11 +779,11 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( return std::nullopt; } state->tensor_core_reindex_store = - state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite); + state->sch->ReIndex(state->block_rv, 0, s_tir::BufferIndexType::kWrite); state->tensor_core_reindex_A = - state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead); + state->sch->ReIndex(state->block_rv, 0, s_tir::BufferIndexType::kRead); state->tensor_core_reindex_B = - state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead); + state->sch->ReIndex(state->block_rv, 1, s_tir::BufferIndexType::kRead); // Transform the layout of reindex buffers accordingly. // The index map defines the mapping for the computation block. 
We need to extract the sub index @@ -840,8 +843,8 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( ffi::Map buffer_sub_index_map; // cache of the sub index map // associated with each buffer - auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) { - const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer( + auto f_transform_buffer_layout = [&](s_tir::BufferIndexType index_type, int buffer_index) { + const tir::Buffer& lhs_buffer = s_tir::GetNthAccessBuffer( state->sch->state(), block_before_reindex, buffer_index, index_type); if (visited_buffers.count(lhs_buffer)) { return; @@ -849,7 +852,7 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( visited_buffers.insert(lhs_buffer); // Refresh block pointer (block sref is not invalidated) block = TVM_SREF_TO_SBLOCK(block_sref); - const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( + const tir::BufferRegion& reindexed_buffer_region = s_tir::GetNthAccessBufferRegion( state->sch->state(), ffi::GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); @@ -858,26 +861,27 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( }; for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { - f_transform_buffer_layout(tir::BufferIndexType::kRead, i); + f_transform_buffer_layout(s_tir::BufferIndexType::kRead, i); } for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) { - f_transform_buffer_layout(tir::BufferIndexType::kWrite, i); + f_transform_buffer_layout(s_tir::BufferIndexType::kWrite, i); } // Transform the layout of current block and reindex blocks auto f_transform_reindex_block_layout = [&](const SBlockRV& block_rv, - tir::BufferIndexType buffer_type) { + s_tir::BufferIndexType buffer_type) { tir::Buffer buffer = - tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type); + s_tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type); const auto& sub_index_map = buffer_sub_index_map.at(buffer); state->sch->TransformBlockLayout(block_rv, sub_index_map); }; - f_transform_reindex_block_layout(state->tensor_core_reindex_store, tir::BufferIndexType::kWrite); - f_transform_reindex_block_layout(state->tensor_core_reindex_A, tir::BufferIndexType::kRead); - f_transform_reindex_block_layout(state->tensor_core_reindex_B, tir::BufferIndexType::kRead); + f_transform_reindex_block_layout(state->tensor_core_reindex_store, + s_tir::BufferIndexType::kWrite); + f_transform_reindex_block_layout(state->tensor_core_reindex_A, s_tir::BufferIndexType::kRead); + f_transform_reindex_block_layout(state->tensor_core_reindex_B, s_tir::BufferIndexType::kRead); state->sch->TransformBlockLayout(state->block_rv, index_map); - return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name, - /*allow_padding=*/true); + return s_tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name, + /*allow_padding=*/true); } inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorization( diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 8a1ac2bae8d4..1bcbeea5c757 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ 
-19,17 +19,17 @@ #include -#include "../../tir/schedule/analysis.h" -#include "../../tir/schedule/transform.h" +#include "../../s_tir/schedule/analysis.h" +#include "../../s_tir/schedule/transform.h" #include "../utils.h" #include "multi_level_tiling.h" namespace tvm { namespace meta_schedule { -using tir::LoopRV; -using tir::SBlockRV; -using tir::Schedule; +using s_tir::LoopRV; +using s_tir::SBlockRV; +using s_tir::Schedule; /*! * \brief Extension of MultiLevelTiling for backends with wide vectors. @@ -55,18 +55,19 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { return ScheduleRule(n); } - std::pair, ffi::Array> SplitLoop(const Schedule& sch, - SBlockRV block, LoopRV loop, - int n_tiles) const; + std::pair, ffi::Array> SplitLoop(const Schedule& sch, + SBlockRV block, + LoopRV loop, + int n_tiles) const; }; -std::pair, ffi::Array> +std::pair, ffi::Array> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv, LoopRV loop_rv, int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv); const tir::SBlockNode* block_node = block_sref->StmtAs(); - const tir::SBlockRealize block_realize = tir::GetSBlockRealize(sch->state(), block_sref); + const tir::SBlockRealize block_realize = s_tir::GetSBlockRealize(sch->state(), block_sref); ICHECK(block_node && block_node->writes.size() == 1); const auto out_dtype = block_node->writes[0]->buffer->dtype; @@ -98,25 +99,25 @@ MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles); } else { // We split the innermost spatial loop in a way that always uses the maximum vector length. - const int64_t* extent_int = tir::GetLoopIntExtent(loop); + const int64_t* extent_int = s_tir::GetLoopIntExtent(loop); if (extent_int && *extent_int > vec_len) { - ffi::Array inner_splits = + ffi::Array inner_splits = sch->Split(/*loop=*/loop_rv, /*factors=*/{std::nullopt, PrimExpr(vec_len)}); - ffi::Array outer_factors = sch->SamplePerfectTile( + ffi::Array outer_factors = sch->SamplePerfectTile( /*loop=*/inner_splits[0], /*n=*/n_tiles - 1, /*max_innermost_factor=*/max_innermost_factor); - ffi::Array outer_splits = sch->Split( + ffi::Array outer_splits = sch->Split( /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); outer_splits.push_back(inner_splits[1]); outer_factors.push_back(PrimExpr(vec_len)); return {outer_factors, outer_splits}; } else { - ffi::Array factors(n_tiles - 1, PrimExpr(1)); + ffi::Array factors(n_tiles - 1, PrimExpr(1)); factors.push_back(loop->extent); - ffi::Array splits = sch->Split(/*loop=*/loop_rv, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 8167a6f8974b..23e5d583bc3c 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -19,8 +19,8 @@ #include -#include "../../tir/schedule/analysis.h" -#include "../../tir/schedule/transform.h" +#include "../../s_tir/schedule/analysis.h" +#include "../../s_tir/schedule/transform.h" #include "../utils.h" #include "multi_level_tiling.h" @@ -31,14 +31,14 @@ namespace meta_schedule { * 
\brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate * the tiled block for tensorization by postproc rewrite. */ -ffi::Optional TileForIntrin(tir::Schedule sch, tir::SBlockRV block, - const std::string& intrin_name) { - ffi::Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); +ffi::Optional TileForIntrin(s_tir::Schedule sch, s_tir::SBlockRV block, + const std::string& intrin_name) { + ffi::Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); if (!tiled_loop_rv) { return std::nullopt; } ICHECK(tiled_loop_rv.defined()); - tir::SBlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); + s_tir::SBlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, ffi::String(intrin_name)); return outer_block; } @@ -48,7 +48,8 @@ ffi::Optional TileForIntrin(tir::Schedule sch, tir::SBlockRV bloc */ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { protected: - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final { auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index d1e931e42434..1ee5c58f8089 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; bool IsRootBlock(const Schedule& sch, const SBlockRV& block_rv) { StmtSRef block_sref = sch->GetSRef(block_rv); @@ -33,7 +34,7 @@ bool CheckSpatialPrimFunc(const Schedule& sch, const SBlockRV& root_block_rv) { ffi::GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); } -} // namespace tir +} // namespace s_tir } // namespace tvm namespace tvm { @@ -51,9 +52,9 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& root_rv) { + ffi::Array Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& root_rv) { // Currently only mark the root block with annotations. 
- if (!tir::IsRootBlock(sch, root_rv)) { + if (!s_tir::IsRootBlock(sch, root_rv)) { return {sch}; } @@ -67,7 +68,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent)); } // Unroll - if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { + if (!unroll_max_steps.empty() && !s_tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; ffi::Array probs(n, FloatImm(DataType::Float(32), prob)); diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 89a9f722a816..35f879d1ea5f 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -29,7 +29,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { + ffi::Array Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv) final { if (!CheckConditions(sch, block_rv)) { return {sch}; } @@ -40,16 +41,16 @@ class RandomComputeLocationNode : public ScheduleRuleNode { // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer // access the input block. Hence we collect its producer ahead of time. // - Note that only a single producer is allowed in this case. - ffi::Array producers{nullptr}; - if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, - true)) { + ffi::Array producers{nullptr}; + if (s_tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, + true)) { producers = sch->GetProducers(block_rv); sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); ICHECK_EQ(producers.size(), 1); } // Step 2. Transform the input block. - tir::Schedule res = RandomlyComputeAt(sch, block_rv); + s_tir::Schedule res = RandomlyComputeAt(sch, block_rv); // Step 3. Transform the producer block if compute-location sampling is needed. if (producers.defined()) { @@ -66,7 +67,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { } private: - bool CheckConditions(const tir::Schedule sch, const tir::SBlockRV& block_rv) const { + bool CheckConditions(const s_tir::Schedule sch, const s_tir::SBlockRV& block_rv) const { tir::StmtSRef block_sref = sch->GetSRef(block_rv); TVM_SREF_TO_SBLOCK(block_sref); @@ -75,26 +76,26 @@ class RandomComputeLocationNode : public ScheduleRuleNode { return false; } // Cond 2. The block should be the direct child block of the root block. - if (GetScopeRoot(sch->state(), block_sref, - /*require_stage_pipeline=*/false) + if (s_tir::GetScopeRoot(sch->state(), block_sref, + /*require_stage_pipeline=*/false) ->parent != nullptr) { return false; } // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child // block. - ffi::Array loop_srefs = tir::GetLoops(block_sref); + ffi::Array loop_srefs = s_tir::GetLoops(block_sref); if (loop_srefs.empty()) { return false; } - if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { + if (s_tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { return false; } // Cond 5. The block is not tiled. We check this condition by examining the block's annotation.
- if (tir::HasBeenMultiLevelTiled(block_sref)) { + if (s_tir::HasBeenMultiLevelTiled(block_sref)) { return false; } // Cond 6. The block has at least one consumer. - if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { + if (s_tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { return false; } return true; } @@ -106,8 +107,8 @@ * \param block_rv The block whose compute-at location is to be sampled * \return The TIR schedule after transformation */ - tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::SBlockRV& block_rv) { - tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); + s_tir::Schedule RandomlyComputeAt(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) { + s_tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); sch->ComputeAt(block_rv, compute_at_loc, true); return sch; } diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index f5ef5da48c2d..72c767ac75ff 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -30,8 +30,8 @@ void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -ffi::Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, - const tir::SBlockRV& block) { +ffi::Array PyScheduleRuleNode::Apply(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block) { ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; return f_apply(sch, block); } diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 8aa5aca45059..480e2b48a683 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -29,7 +29,7 @@ namespace tvm { namespace meta_schedule { -using tir::Schedule; +using s_tir::Schedule; /**************** Data Structure ****************/ @@ -125,7 +125,7 @@ struct PerThreadData { */ void Set(const std::vector& scores, double genetic_mutate_prob, const ffi::Map& mutator_probs) { - trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + trace_sampler = s_tir::MakeMultinomialSampler(&rand_state, scores); mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); } @@ -164,7 +164,7 @@ struct PerThreadData { masses[i] /= total_mass_mutator; } } - return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + return [idx_sampler = s_tir::MakeMultinomialSampler(rand_state, masses), mutators = std::move(mutators)]() -> ffi::Optional { int i = idx_sampler(); return mutators[i]; @@ -261,7 +261,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The counter of returning empty results. */ int num_empty_iters; /*! \brief The design spaces. Decisions are not used, so traces only. */ - ffi::Array design_spaces; + ffi::Array design_spaces; /*! \brief Per thread data including module to be tuned and random state. */ std::vector per_thread_data_; /*!
@@ -471,7 +471,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase"); - std::vector measured_traces; + std::vector measured_traces; measured_traces.reserve(num); ffi::Array top_records = this->database_->GetTopK(this->token_, num); for (TuningRecord record : top_records) { @@ -485,7 +485,7 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu PerThreadData& data = this->per_thread_data_.at(thread_id); TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; - tir::Trace trace = measured_traces.at(trace_id); + s_tir::Trace trace = measured_traces.at(trace_id); Schedule& result = results.at(trace_id); ICHECK(!result.defined()); if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { @@ -513,8 +513,8 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu const IRModule& mod = data.mod; Schedule& result = results.at(trace_id); ICHECK(!result.defined()); - int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); - tir::Trace trace(design_spaces[design_space_index]->insts, {}); + int design_space_index = s_tir::SampleInt(rand_state, 0, design_spaces.size()); + s_tir::Trace trace(design_spaces[design_space_index]->insts, {}); if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); } @@ -591,11 +591,11 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { sampled_trace_id = trace_sampler(); sampled_trace_id = sampled_trace_id % self->population_size; - tir::Trace trace = population.at(sampled_trace_id)->trace().value(); + s_tir::Trace trace = population.at(sampled_trace_id)->trace().value(); if (ffi::Optional opt_mutator = mutator_sampler()) { // Decision: mutate Mutator mutator = opt_mutator.value(); - if (ffi::Optional new_trace = mutator->Apply(trace, rand_state)) { + if (ffi::Optional new_trace = mutator->Apply(trace, rand_state)) { if (ffi::Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { // note that sch's trace is different from new_trace // because it contains post-processing information @@ -657,7 +657,7 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( int num_rands = num * self->eps_greedy; int num_bests = num - num_rands; std::vector rands = - tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); + s_tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); std::vector results; results.reserve(num); IRModuleSet& measured_workloads = this->measured_workloads_; diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 9082c6c3a90f..67600960e663 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -85,7 +85,7 @@ class ReplayFuncNode : public SearchStrategyNode { } void PreTuning(int max_trials, int num_trials_per_iter, - const ffi::Array& design_spaces, + const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) final { CHECK(this->state_ == nullptr) @@ -131,9 +131,9 @@ ReplayFuncNode::State::GenerateMeasureCandidates() { ffi::Array postprocs = self->space_generator_.value()->postprocs.value_or({}); for (int i = st; i < ed; i++) { for (;;) { - ffi::Array schs = 
self->space_generator_.value()->GenerateDesignSpace(mod); - int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); - tir::Schedule sch = schs[design_space_index]; + ffi::Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); + int design_space_index = s_tir::SampleInt(&self->rand_state_, 0, schs.size()); + s_tir::Schedule sch = schs[design_space_index]; sch->EnterPostproc(); bool failed = false; for (const Postproc& proc : postprocs) { diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 7898b171d357..b5f116732ceb 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -31,7 +31,7 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The search strategy itself */ ReplayTraceNode* self; /*! \brief The design spaces. */ - ffi::Array design_spaces; + ffi::Array design_spaces; /*! \brief The number of total trials. */ int max_trials; /*! \brief The number of trials per iteration. */ @@ -44,7 +44,7 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The module to be tuned. */ ffi::Array per_thread_mod_{nullptr}; - explicit State(ReplayTraceNode* self, ffi::Array design_spaces, int max_trials, + explicit State(ReplayTraceNode* self, ffi::Array design_spaces, int max_trials, int num_trials_per_iter) : self(self), design_spaces(design_spaces), @@ -102,15 +102,15 @@ class ReplayTraceNode : public SearchStrategyNode { } void PreTuning(int max_trials, int num_trials_per_iter, - const ffi::Array& design_spaces, + const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) final { ICHECK(!design_spaces.empty()); CHECK(this->state_ == nullptr) << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; - ffi::Array design_space_traces; + ffi::Array design_space_traces; design_space_traces.reserve(design_spaces.size()); - for (const tir::Schedule& space : design_spaces) { + for (const s_tir::Schedule& space : design_spaces) { design_space_traces.push_back(space->trace().value()->Simplified(true)); } this->state_ = @@ -158,11 +158,11 @@ ReplayTraceNode::State::GenerateMeasureCandidates() { IRModule mod = this->per_thread_mod_[thread_id]; for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) { - int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); - tir::Trace trace = design_spaces[design_space_index]; - tir::Trace new_trace = tir::Trace(trace->insts, {}); - if (ffi::Optional opt_sch = pp.Apply(mod, new_trace, &rand_state)) { - tir::Schedule sch = opt_sch.value(); + int design_space_index = s_tir::SampleInt(&rand_state, 0, design_spaces.size()); + s_tir::Trace trace = design_spaces[design_space_index]; + s_tir::Trace new_trace = s_tir::Trace(trace->insts, {}); + if (ffi::Optional opt_sch = pp.Apply(mod, new_trace, &rand_state)) { + s_tir::Schedule sch = opt_sch.value(); ffi::Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); per_task_result.Set(task_id, MeasureCandidate(sch, args_info)); break; diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 3273e70ac1b8..919e78810978 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -23,7 +23,7 @@ namespace tvm { namespace meta_schedule { -MeasureCandidate::MeasureCandidate(tir::Schedule 
sch, ffi::Array args_info) { +MeasureCandidate::MeasureCandidate(s_tir::Schedule sch, ffi::Array args_info) { ObjectPtr n = ffi::make_object(); n->sch = sch; n->args_info = args_info; @@ -37,7 +37,7 @@ void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) } void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter, - const ffi::Array& design_spaces, + const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) { ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; @@ -94,9 +94,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.MeasureCandidate", - [](tir::Schedule sch, ffi::Optional> args_info) -> MeasureCandidate { - return MeasureCandidate(sch, args_info.value_or({})); - }) + [](s_tir::Schedule sch, ffi::Optional> args_info) + -> MeasureCandidate { return MeasureCandidate(sch, args_info.value_or({})); }) .def("meta_schedule.SearchStrategyPySearchStrategy", SearchStrategy::PySearchStrategy) .def_method("meta_schedule.SearchStrategyInitializeWithTuneContext", &SearchStrategyNode::InitializeWithTuneContext) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 44a365031894..d4e0dcb7251b 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -46,21 +46,21 @@ class PostOrderApplyNode : public SpaceGeneratorNode { this->rand_state_ = ForkSeed(&context->rand_state); } - ffi::Array GenerateDesignSpace(const IRModule& mod) final { - using ScheduleAndUnvisitedBlocks = std::pair>; + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + using ScheduleAndUnvisitedBlocks = std::pair>; CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply"; - tir::Schedule sch = tir::Schedule::Traced( + s_tir::Schedule sch = s_tir::Schedule::Traced( /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), /*debug_mode=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; - ffi::Array result{sch}; - ffi::Array all_blocks = SBlockCollector::Collect(sch, f_block_filter_); + ffi::Array result{sch}; + ffi::Array all_blocks = SBlockCollector::Collect(sch, f_block_filter_); for (ScheduleRule sch_rule : sch_rules.value()) { - for (const tir::Schedule& sch : result) { + for (const s_tir::Schedule& sch : result) { stack.emplace_back(sch, all_blocks); } result.clear(); @@ -74,20 +74,20 @@ class PostOrderApplyNode : public SpaceGeneratorNode { continue; } // otherwise, get the last block that is not visited - tir::SBlockRV block_rv = blocks.back(); + s_tir::SBlockRV block_rv = blocks.back(); blocks.pop_back(); if (!sch->HasBlock(block_rv)) { stack.emplace_back(sch, blocks); continue; } if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { - if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { + if (s_tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { stack.emplace_back(sch, blocks); continue; } } - ffi::Array applied = sch_rule->Apply(sch, /*block=*/block_rv); - for (const tir::Schedule& sch : applied) { + ffi::Array applied = sch_rule->Apply(sch, /*block=*/block_rv); + for (const s_tir::Schedule& sch : applied) { stack.emplace_back(sch, blocks); } } diff --git a/src/meta_schedule/space_generator/schedule_fn.cc 
b/src/meta_schedule/space_generator/schedule_fn.cc index 7d22635b76f2..33bb75e97288 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -42,26 +42,26 @@ class ScheduleFnNode : public SpaceGeneratorNode { this->rand_state_ = ForkSeed(&context->rand_state); } - ffi::Array GenerateDesignSpace(const IRModule& mod) final { - tir::Schedule sch = tir::Schedule::Traced( + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + s_tir::Schedule sch = s_tir::Schedule::Traced( /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), /*debug_mode=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kDetail); ffi::Any rv; rv = this->schedule_fn_(sch); if (rv == nullptr) { return {sch}; } ObjectRef obj = rv.cast(); - if (auto sch = obj.as()) { + if (auto sch = obj.as()) { return {sch.value()}; } if (const auto* arr = obj.as()) { - ffi::Array result; + ffi::Array result; result.reserve(arr->size()); for (Any val : *arr) { - if (auto sch = val.as()) { + if (auto sch = val.as()) { result.push_back(sch.value()); } else { LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 9e458a3ad7cf..26137e19a35e 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -174,7 +174,7 @@ void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) f_initialize_with_tune_context(context); } -ffi::Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { +ffi::Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { ICHECK(f_generate_design_space != nullptr) << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 026daa68a762..3f26b8d02afe 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -42,11 +42,11 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { } } - ffi::Array GenerateDesignSpace(const IRModule& mod) final { - ffi::Array design_spaces; + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + ffi::Array design_spaces; for (const SpaceGenerator& space_generator : space_generators) { // Generate partial design spaces from each design space generator. - ffi::Array partial = space_generator->GenerateDesignSpace(mod); + ffi::Array partial = space_generator->GenerateDesignSpace(mod); // Merge the partial design spaces. 
design_spaces.insert(design_spaces.end(), partial.begin(), partial.end()); } diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index babf521c280c..686aeaa275a1 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -111,7 +111,7 @@ class GradientBasedNode final : public TaskSchedulerNode { auto min_grad = std::min_element(grad.begin(), grad.end()); int task_id = -1; if (*max_grad == *min_grad) { - task_id = tasks_alive[tir::SampleInt(&this->rand_state, 0, tasks_alive.size())]; + task_id = tasks_alive[s_tir::SampleInt(&this->rand_state, 0, tasks_alive.size())]; } else { task_id = tasks_alive[std::distance(grad.begin(), max_grad)]; } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 85c6d71b4307..3576c5f40467 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -124,7 +124,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const ffi::Arraylatency_ms.push_back(run_ms); if (error_msg) { - const tir::Schedule& sch = candidate->sch; + const s_tir::Schedule& sch = candidate->sch; std::string err = error_msg.value(); TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) // << "[Task #" << task_id << ": " << name << "] Trial #" << trials @@ -168,13 +168,13 @@ void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name; TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name; this->tasks_.push_back(TaskRecord(ctx, weight)); - ffi::Array design_spaces = + ffi::Array design_spaces = ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value()); TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size() << " design space(s) generated"; for (int i = 0, n = design_spaces.size(); i < n; ++i) { - tir::Schedule sch = design_spaces[i]; - tir::Trace trace = sch->trace().value(); + s_tir::Schedule sch = design_spaces[i]; + s_tir::Trace trace = sch->trace().value(); trace = trace->Simplified(true); TVM_PY_LOG(INFO, ctx->logger) << "Design space #" << i << ":\n" << sch->mod() << "\n" diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index aef04f7da19c..0b13e4d07641 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -29,13 +29,25 @@ #include #include -#include "../tir/schedule/analysis.h" +#include "../s_tir/schedule/analysis.h" #include "utils.h" namespace tvm { namespace meta_schedule { using namespace tir; +using s_tir::GetSBlockNames; +using s_tir::Instruction; +using s_tir::InstructionKind; +using s_tir::IsSpatial; +using s_tir::LoopRV; +using s_tir::LoopRVNode; +using s_tir::SBlockRV; +using s_tir::SBlockRVNode; +using s_tir::Schedule; +using s_tir::Trace; +using s_tir::TranslateAddOutputRVs; +using s_tir::TranslateInputRVs; // Returns true if b1 is an ancestor of b2 bool IsAncestor(SBlockRV b1, SBlockRV b2, Schedule sch) { diff --git a/src/meta_schedule/trace_apply.h b/src/meta_schedule/trace_apply.h index 9a9068ab914f..6e95856a4a7a 100644 --- a/src/meta_schedule/trace_apply.h +++ b/src/meta_schedule/trace_apply.h @@ -20,9 +20,9 @@ #define TVM_META_SCHEDULE_TRACE_APPLY_H_ #include +#include +#include #include -#include -#include #include @@ -39,7 +39,7 @@ namespace meta_schedule { * \param anchor_trace The trace tuned on 
other subgraph with the same anchor-block workload. * \param target The target information needed for inlining and parallelization. */ -void ScheduleUsingAnchorTrace(tir::Schedule sch, const tir::Trace& anchor_trace, +void ScheduleUsingAnchorTrace(s_tir::Schedule sch, const s_tir::Trace& anchor_trace, const tvm::Target& target); } // namespace meta_schedule diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 0ed65afa2fb2..9fd803b87237 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -38,8 +38,8 @@ #include #include #include +#include #include -#include #include #include @@ -49,13 +49,13 @@ #include #include +#include "../s_tir/schedule/primitive.h" +#include "../s_tir/schedule/utils.h" #include "../support/array.h" #include "../support/base64.h" #include "../support/nd_int_set.h" #include "../support/table_printer.h" #include "../support/utils.h" -#include "../tir/schedule/primitive.h" -#include "../tir/schedule/utils.h" #define TVM_PY_LOG(logging_level, logger) \ ::tvm::meta_schedule::PyLogMessage(__FILE__, __LINE__, logger, \ @@ -291,8 +291,8 @@ inline std::string Concat(const ffi::Array& strs, const std::string * \param global_var_name The global variable name * \return The SBlockRV */ -inline tir::SBlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, - const ffi::String& global_var_name) { +inline s_tir::SBlockRV GetRVFromSRef(const s_tir::Schedule& sch, const tir::StmtSRef& block_sref, + const ffi::String& global_var_name) { const tir::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); return sch->GetSBlock(block->name_hint, global_var_name); } @@ -321,13 +321,13 @@ struct ThreadedTraceApply { * \param rand_state The random seed * \return The schedule created, or std::nullopt if any postprocessor fails */ - ffi::Optional Apply(const IRModule& mod, const tir::Trace& trace, - TRandState* rand_state) { - tir::Schedule sch = - tir::Schedule::Traced(mod, - /*rand_state=*/ForkSeed(rand_state), - /*debug_mode=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + ffi::Optional Apply(const IRModule& mod, const s_tir::Trace& trace, + TRandState* rand_state) { + s_tir::Schedule sch = + s_tir::Schedule::Traced(mod, + /*rand_state=*/ForkSeed(rand_state), + /*debug_mode=*/0, + /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone); trace->ApplyToSchedule(sch, /*remove_postproc=*/true); sch->EnterPostproc(); @@ -339,7 +339,7 @@ struct ThreadedTraceApply { if (!item.postproc->Apply(sch)) { success = false; } - } catch (const tir::ScheduleError& e) { + } catch (const s_tir::ScheduleError& e) { DLOG(WARNING) << "Postproc #" << i << " failed with ScheduleError: " << e.what(); success = false; } catch (const std::exception& e) { @@ -580,15 +580,15 @@ inline double Sum(const ffi::Array& arr) { /*! \brief Collecting all the blocks */ class SBlockCollector : public tir::StmtVisitor { public: - static ffi::Array Collect(const tir::Schedule& sch, - const ffi::Function f_block_filter = nullptr) { // + static ffi::Array Collect(const s_tir::Schedule& sch, + const ffi::Function f_block_filter = nullptr) { // return SBlockCollector(sch, f_block_filter).Run(); } private: /*! \brief Entry point */ - ffi::Array Run() { - std::vector results; + ffi::Array Run() { + std::vector results; auto f_collect = [this, &results](tir::PrimFunc func, ffi::String func_name) { func_name_ = func_name; block_names_.clear(); @@ -615,7 +615,7 @@ class SBlockCollector : public tir::StmtVisitor { return results; } /*! 
\brief Constructor */ - explicit SBlockCollector(const tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) + explicit SBlockCollector(const s_tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::SBlockNode* block) override { @@ -637,7 +637,7 @@ class SBlockCollector : public tir::StmtVisitor { } /*! \brief The schedule to be collected */ - const tir::Schedule& sch_; + const s_tir::Schedule& sch_; /*! \brief An optional packed func that allows only certain blocks to be collected. */ const ffi::Function f_block_filter_; /*! \brief The set of func name and block name pair */ diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index aaac39c61b20..e9ff5e72f2b9 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -30,7 +30,7 @@ #include #include -#include "../../../tir/schedule/transform.h" +#include "../../../s_tir/schedule/transform.h" #include "../../op/ccl/ccl.h" #include "../../op/distributed/distributed.h" diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 7930e2dfe7fc..676fce094a5b 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -32,7 +32,7 @@ #include #include -#include "../../../tir/schedule/transform.h" +#include "../../../s_tir/schedule/transform.h" #include "../../op/ccl/ccl.h" #include "../../op/tensor/manipulate.h" #include "utils.h" diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index a21304b90152..7e805a4f512f 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -28,11 +28,12 @@ #include #include -#include "../../../tir/schedule/transform.h" +#include "../../../s_tir/schedule/transform.h" #include "utils.h" namespace tvm { namespace tir { using namespace tvm::relax::distributed; +using s_tir::ReplaceBuffer; class DistBufferReplacer : public StmtExprMutator { public: diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index dd5b93267476..221d902a0d44 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -106,19 +106,19 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ if (ffi::Optional opt_record = database->QueryTuningRecord(tir_mod, target, gv->name_hint)) { meta_schedule::TuningRecord record = opt_record.value(); - tir::Schedule sch{nullptr}; + s_tir::Schedule sch{nullptr}; if (!mod_eq_structural->Equal(tir_mod, record->workload->mod)) { // When the database lookup succeeds while structural equality check fails, // it implies that the anchor block based equality has been used during tuning. // The trace in the record cannot directly be applied to this query module. 
- sch = tir::Schedule::Traced( + sch = s_tir::Schedule::Traced( tir_mod, /*seed=*/-1, /*debug_mask=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kDetail); meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, target); } else { - sch = tir::Schedule::Traced( + sch = s_tir::Schedule::Traced( record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kDetail); record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); } IRModule new_mod = sch->mod(); diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 376759984816..ce694a9fd7b4 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -31,7 +31,7 @@ #include #include -#include "../../tir/schedule/ir_comparator.h" +#include "../../s_tir/schedule/ir_comparator.h" namespace tvm { @@ -45,6 +45,8 @@ namespace tir { using relax::FCodegen; using relax::MatchResult; using relax::TIRPattern; +using s_tir::ExprComparator; +using s_tir::TensorizeComparator; /*! \brief helper to match a for stmt to a pattern*/ class ForMatcher : public TensorizeComparator { diff --git a/src/tir/schedule/analysis.h b/src/s_tir/schedule/analysis.h similarity index 98% rename from src/tir/schedule/analysis.h rename to src/s_tir/schedule/analysis.h index a68da6d5b883..e059a912489d 100644 --- a/src/tir/schedule/analysis.h +++ b/src/s_tir/schedule/analysis.h @@ -16,14 +16,14 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_ -#define TVM_TIR_SCHEDULE_ANALYSIS_H_ +#ifndef TVM_S_TIR_SCHEDULE_ANALYSIS_H_ +#define TVM_S_TIR_SCHEDULE_ANALYSIS_H_ #include #include +#include +#include #include -#include -#include #include #include @@ -34,7 +34,8 @@ #include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Verification ********/ /*! @@ -687,9 +688,9 @@ bool IsSpatialPrimFunc(const PrimFunc& func); * \param max_parallel_basic The maximum cores on the target. * \return A boolean indicating whether the operation is beneficial. */ -bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // - const tir::StmtSRef& block_sref, // - int64_t max_parallel_extent, // +bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // int64_t max_parallel_basic); /*! 
@@ -754,7 +755,7 @@ class TensorizeInfoNode : public Object { .def_ro("desc_loop_indexer", &TensorizeInfoNode::desc_loop_indexer) .def_ro("block_iter_paddings", &TensorizeInfoNode::block_iter_paddings); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.schedule.TensorizeInfo", TensorizeInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.schedule.TensorizeInfo", TensorizeInfoNode, Object); }; class TensorizeInfo : public ObjectRef { @@ -773,7 +774,7 @@ class TensorizeInfo : public ObjectRef { * \param allow_padding Whether to allow padding the block iters to match the intrinsic description * \return TensorizeInfo structure if a valid mapping is found, std::nullopt otherwise */ -ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, +ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& self, const tir::StmtSRef& block_sref, const tir::PrimFunc& desc_func, bool allow_padding); @@ -804,7 +805,7 @@ class AutoTensorizeMappingInfoNode : public Object { .def_ro("lhs_iters", &AutoTensorizeMappingInfoNode::lhs_iters) .def_ro("rhs_iters", &AutoTensorizeMappingInfoNode::rhs_iters); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.schedule.AutoTensorizeMappingInfo", + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.schedule.AutoTensorizeMappingInfo", AutoTensorizeMappingInfoNode, Object); }; @@ -841,9 +842,9 @@ ffi::Optional GetAutoTensorizeMappingInfo(const Schedu * \param desc_func The prim func describing the computation to be tensorized * \return true if basic conditions are met. */ -bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::SBlockRV& block_rv, +bool CheckAutoTensorizeApplicable(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv, const tir::PrimFunc& desc_func); -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_ANALYSIS_H_ +#endif // TVM_S_TIR_SCHEDULE_ANALYSIS_H_ diff --git a/src/tir/schedule/analysis/analysis.cc b/src/s_tir/schedule/analysis/analysis.cc similarity index 97% rename from src/tir/schedule/analysis/analysis.cc rename to src/s_tir/schedule/analysis/analysis.cc index ee193b80b3cb..a00958e5246e 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/s_tir/schedule/analysis/analysis.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; TVM_FFI_STATIC_INIT_BLOCK() { TensorizeInfoNode::RegisterReflection(); @@ -337,8 +338,8 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.IsReductionBlock", [](Schedule sch, SBlockRV block_rv, - SBlockRV scope_block_rv) { + refl::GlobalDef().def("s_tir.schedule.IsReductionBlock", [](Schedule sch, SBlockRV block_rv, + SBlockRV scope_block_rv) { return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); }); } @@ -879,7 +880,7 @@ SBlockRealize GetSBlockRealize(const ScheduleState& self, const StmtSRef& block_ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.GetSBlockRealize", [](Schedule sch, SBlockRV block_rv) { + refl::GlobalDef().def("s_tir.schedule.GetSBlockRealize", [](Schedule sch, SBlockRV block_rv) { return GetSBlockRealize(sch->state(), sch->GetSRef(block_rv)); }); } @@ -952,8 +953,7 @@ StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs) { } bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { - return tir::GetAnn(block_sref, 
tir::attr::meta_schedule_tiling_structure) - .has_value(); + return GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).has_value(); } std::pair, std::vector> CollectComputeLocation( @@ -1502,7 +1502,7 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.IsTrivialBinding", [](Schedule sch, SBlockRV block_rv) { + refl::GlobalDef().def("s_tir.schedule.IsTrivialBinding", [](Schedule sch, SBlockRV block_rv) { return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); }); } @@ -1587,9 +1587,9 @@ bool IsSpatialPrimFunc(const PrimFunc& func) { return result; } -std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, +std::pair GetCumulativeSpaceAndReductionLength(const s_tir::ScheduleState& self, const tir::StmtSRef& block_sref) { - ffi::Array loops = tir::GetLoops(block_sref); + ffi::Array loops = GetLoops(block_sref); int64_t cum_space_len = 1, cum_reduce_len = 1; /* * Return (-1, -1) if @@ -1619,12 +1619,12 @@ std::pair GetCumulativeSpaceAndReductionLength(const tir::Sche return std::make_pair(cum_space_len, cum_reduce_len); } -bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // - const tir::StmtSRef& block_sref, // - int64_t max_parallel_extent, // +bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // int64_t max_parallel_basic) { const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); - ffi::Array loops = tir::GetLoops(block_sref); + ffi::Array loops = GetLoops(block_sref); // Cond 1. The block must have at lease one write buffer if (block->writes.size() == 0) { @@ -1747,12 +1747,12 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, return info; } -ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, +ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& self, const tir::StmtSRef& block_sref, const tir::PrimFunc& desc_func, bool allow_padding) { arith::Analyzer analyzer; - const tir::SBlockRealize& block = tir::GetSBlockRealize(self, block_sref); + const tir::SBlockRealize& block = GetSBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. Collect loops from block_sref @@ -1911,9 +1911,9 @@ ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& s TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.IsSpatialPrimFunc", IsSpatialPrimFunc) - .def("tir.schedule.GetTensorizeLoopMapping", [](Schedule sch, SBlockRV block, - PrimFunc desc_func, bool allow_padding) { + .def("s_tir.schedule.IsSpatialPrimFunc", IsSpatialPrimFunc) + .def("s_tir.schedule.GetTensorizeLoopMapping", [](Schedule sch, SBlockRV block, + PrimFunc desc_func, bool allow_padding) { return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); }); } @@ -2106,21 +2106,21 @@ bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRe // Step 1. Analyze desc_func, extract its block, loops and loop vars // Step 2. 
Check if `desc_block` matches `block` // Ignore the scope of buffers when comparing, since we can do cache_read/write - const SBlockRealize& block = tir::GetSBlockRealize(state, block_sref); + const SBlockRealize& block = GetSBlockRealize(state, block_sref); arith::Analyzer analyzer; - auto desc_info = tir::ExtractTensorIntrinDescInfo(&analyzer, desc_func); + auto desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); return extractor->VisitStmt(block->block, desc_info.desc_block->block); } -bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::SBlockRV& block_rv, +bool CheckAutoTensorizeApplicable(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv, const tir::PrimFunc& desc_func) { AutoTensorizeComparator extractor(sch->state()->mod); return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor); } ffi::Optional GetAutoTensorizeMappingInfo( - const tir::ScheduleState& self, const tir::StmtSRef& block_sref, + const s_tir::ScheduleState& self, const tir::StmtSRef& block_sref, const tir::PrimFunc& desc_func) { AutoTensorizeComparator extractor(self->mod); if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) { @@ -2144,18 +2144,18 @@ ffi::Optional GetAutoTensorizeMappingInfo( TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.GetAutoTensorizeMappingInfo", + .def("s_tir.schedule.GetAutoTensorizeMappingInfo", [](Schedule sch, SBlockRV block, PrimFunc desc_func) { return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }) - .def("tir.schedule.HasBlock", HasBlock) - .def("tir.schedule.IsOutputBlock", + .def("s_tir.schedule.HasBlock", HasBlock) + .def("s_tir.schedule.IsOutputBlock", [](Schedule sch, SBlockRV block) { auto state = sch->state(); auto block_sref = sch->GetSRef(block); return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); }) - .def("tir.schedule.GetLoopIterType", + .def("s_tir.schedule.GetLoopIterType", [](Schedule sch, LoopRV loop) -> ffi::String { IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); if (kind == kDataPar) { @@ -2166,9 +2166,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { return "O"; } }) - .def("tir.schedule.HasIfThenElse", + .def("s_tir.schedule.HasIfThenElse", [](const Stmt& stmt) -> bool { return HasIfThenElse(stmt); }); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/analysis/layout.cc b/src/s_tir/schedule/analysis/layout.cc similarity index 98% rename from src/tir/schedule/analysis/layout.cc rename to src/s_tir/schedule/analysis/layout.cc index ddc15ab5e592..6c1feb10f706 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/s_tir/schedule/analysis/layout.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Calculate the strides of the buffer @@ -243,12 +244,12 @@ ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array
indices, ffi::Array loops, PrimExpr predicate) { arith::Analyzer analyzer; return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); }); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc similarity index 99% rename from src/tir/schedule/analysis/reducer.cc rename to src/s_tir/schedule/analysis/reducer.cc index 17e668d13ac8..547dc7d8b89b 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Pattern Matcher ********/ @@ -703,5 +704,5 @@ bool FromIdentityCombiner(const ffi::Array& identities, return false; } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/analysis/verify.cc b/src/s_tir/schedule/analysis/verify.cc similarity index 99% rename from src/tir/schedule/analysis/verify.cc rename to src/s_tir/schedule/analysis/verify.cc index 77c6bb605c8b..91b10d7d5f95 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/s_tir/schedule/analysis/verify.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class SRefTreeVerifier : public StmtVisitor { public: @@ -241,5 +242,5 @@ void VerifyCachedFlags(const ScheduleState& self) { throw; } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc similarity index 87% rename from src/tir/schedule/concrete_schedule.cc rename to src/s_tir/schedule/concrete_schedule.cc index e4a236e2ce0e..266c7ff46425 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -21,7 +21,8 @@ #include namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, @@ -237,7 +238,7 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const ffi::Array& candid const ffi::Array& probs, ffi::Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); + return CreateRV(s_tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); throw; } @@ -247,8 +248,8 @@ ffi::Array ConcreteScheduleNode::SamplePerfectTile( ffi::Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); // use None RV object to denotes auto-infer tile factors. 
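// (Added note, not in the patch:) the `convert_negone_to_none=true` flag on
// the CreateRV call below turns any -1 factor produced by the sampler into a
// None RV, which is how a later Split call knows to infer that factor from
// the loop extent.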
- return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, - max_innermost_factor, &decision), + return CreateRV(s_tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision), /*convert_negone_to_none=*/true); TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); throw; @@ -258,8 +259,8 @@ ffi::Array ConcreteScheduleNode::SamplePartitionedTile( const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, ffi::Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SamplePartitionedTile(&this->rand_state_, this->GetSRef(loop_rv), n, - partition_pos, innerpart_factor, &decision)); + return CreateRV(s_tir::SamplePartitionedTile(&this->rand_state_, this->GetSRef(loop_rv), n, + partition_pos, innerpart_factor, &decision)); TVM_TIR_SCHEDULE_END("sample-partitioned-tile", this->error_render_level_); throw; } @@ -268,7 +269,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV( - tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); + s_tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); throw; } @@ -323,7 +324,7 @@ SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name, "specify the function name explicitly, or call `work_on` to specify the function " "before using `get_sblock`."; } - ffi::Array blocks = tir::GetSBlocks(this->state_, name, gv); + ffi::Array blocks = s_tir::GetSBlocks(this->state_, name, gv); if (blocks.size() != 1) { TVM_TIR_SCHEDULE_BEGIN(); throw NotSingleResult(name, this->state_->mod, blocks); @@ -333,13 +334,13 @@ SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name, } ffi::Array ConcreteScheduleNode::GetLoops(const SBlockRV& block_rv) { - return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); + return CreateRV(s_tir::GetLoops(this->GetSRef(block_rv))); } ffi::Array ConcreteScheduleNode::GetChildBlocks(const SBlockRV& block_rv) { ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); - result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv))); + result = CreateRV(s_tir::GetChildBlocks(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); this->state_->DebugVerify(); return result; @@ -348,7 +349,7 @@ ffi::Array ConcreteScheduleNode::GetChildBlocks(const SBlockRV& block_ ffi::Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); - result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); + result = CreateRV(s_tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); this->state_->DebugVerify(); return result; @@ -356,21 +357,21 @@ ffi::Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) ffi::Array ConcreteScheduleNode::GetProducers(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv))); + return CreateRV(s_tir::GetProducers(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_); throw; } ffi::Array ConcreteScheduleNode::GetConsumers(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv))); + return 
CreateRV(s_tir::GetConsumers(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_); throw; } ffi::Array ConcreteScheduleNode::GetOutputBlocks(const SBlockRV& scope_block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); + return CreateRV(s_tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); throw; } @@ -382,7 +383,7 @@ LoopRV ConcreteScheduleNode::Merge(const ffi::Array& loop_rvs) { ffi::Array loop_srefs = this->GetSRefs(loop_rvs); StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::Merge(state_, loop_srefs); + result = s_tir::Merge(state_, loop_srefs); TVM_TIR_SCHEDULE_END("merge", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -393,7 +394,7 @@ LoopRV ConcreteScheduleNode::Fuse(const ffi::Array& loop_rvs, bool prese ffi::Array loop_srefs = this->GetSRefs(loop_rvs); StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::Fuse(state_, loop_srefs, preserve_unit_iters); + result = s_tir::Fuse(state_, loop_srefs, preserve_unit_iters); TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -507,7 +508,7 @@ ffi::Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { throw WrongFactorError(state_->mod, ffi::GetRef(loop), true); } - results = tir::Split(state_, loop_sref, factors, preserve_unit_iters, disable_predication); + results = s_tir::Split(state_, loop_sref, factors, preserve_unit_iters, disable_predication); TVM_TIR_SCHEDULE_END("split", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(results); @@ -582,7 +583,7 @@ ffi::Array ConcreteScheduleNode::LoopPartition( if (infer_index == -1) { factors.push_back(loop->extent); } - results = tir::LoopPartition(state_, loop_sref, factors, preserve_unit_iters); + results = s_tir::LoopPartition(state_, loop_sref, factors, preserve_unit_iters); TVM_TIR_SCHEDULE_END("loop_partition", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(results); @@ -590,7 +591,7 @@ ffi::Array ConcreteScheduleNode::LoopPartition( void ConcreteScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); + s_tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_); this->state_->DebugVerify(); } @@ -598,7 +599,7 @@ void ConcreteScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { void ConcreteScheduleNode::ReorderBlockIterVar(const SBlockRV& block_rv, const ffi::Array new_order) { TVM_TIR_SCHEDULE_BEGIN(); - tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order); + s_tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order); TVM_TIR_SCHEDULE_END("reorder_block_iter_var", this->error_render_level_); this->state_->DebugVerify(); } @@ -606,7 +607,7 @@ void ConcreteScheduleNode::ReorderBlockIterVar(const SBlockRV& block_rv, LoopRV ConcreteScheduleNode::AddUnitLoop(const SBlockRV& block_rv) { LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); - result = CreateRV(tir::AddUnitLoop(state_, GetSRef(block_rv))); + result = CreateRV(s_tir::AddUnitLoop(state_, GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); this->state_->DebugVerify(); return result; @@ -615,7 +616,7 @@ 
LoopRV ConcreteScheduleNode::AddUnitLoop(const SBlockRV& block_rv) { LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); - result = CreateRV(tir::AddUnitLoop(state_, GetSRef(loop_rv))); + result = CreateRV(s_tir::AddUnitLoop(state_, GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); this->state_->DebugVerify(); return result; @@ -625,14 +626,14 @@ LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Parallel(state_, this->GetSRef(loop_rv)); + s_tir::Parallel(state_, this->GetSRef(loop_rv)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("parallel", this->error_render_level_); } void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Vectorize(state_, this->GetSRef(loop_rv)); + s_tir::Vectorize(state_, this->GetSRef(loop_rv)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_); } @@ -643,14 +644,14 @@ void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const ffi::String& thread "`vthread.x`, `vthread.y` and `vthread.z` instead"; } TVM_TIR_SCHEDULE_BEGIN(); - tir::Bind(state_, this->GetSRef(loop_rv), thread_axis); + s_tir::Bind(state_, this->GetSRef(loop_rv), thread_axis); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("bind", this->error_render_level_); } void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Unroll(state_, this->GetSRef(loop_rv)); + s_tir::Unroll(state_, this->GetSRef(loop_rv)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("unroll", this->error_render_level_); } @@ -667,8 +668,8 @@ SBlockRV ConcreteScheduleNode::CacheRead(const SBlockRV& block_rv, int read_buff consumer_block_refs.push_back(this->GetSRef(block)); } TVM_TIR_SCHEDULE_BEGIN(); - result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, - consumer_block_refs); + result = s_tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, + consumer_block_refs); TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -684,8 +685,8 @@ SBlockRV ConcreteScheduleNode::CacheWrite(const SBlockRV& block_rv, int write_bu consumer_block_refs.push_back(this->GetSRef(block)); } TVM_TIR_SCHEDULE_BEGIN(); - result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope, - consumer_block_refs); + result = s_tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope, + consumer_block_refs); TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -696,8 +697,8 @@ SBlockRV ConcreteScheduleNode::ReindexCacheRead(const SBlockRV& block_rv, int re const IndexMap& index_map) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, - index_map); + result = s_tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, + storage_scope, index_map); TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -708,8 +709,8 @@ SBlockRV ConcreteScheduleNode::ReindexCacheWrite(const SBlockRV& block_rv, int w const IndexMap& index_map) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = 
tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, - storage_scope, index_map); + result = s_tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, + storage_scope, index_map); TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -720,7 +721,7 @@ ffi::Array ConcreteScheduleNode::CacheInplace(const SBlockRV& block_rv const ffi::String& storage_scope) { ffi::Array results; TVM_TIR_SCHEDULE_BEGIN(); - results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); + results = s_tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_); this->state_->DebugVerify(); ffi::Array return_blocks; @@ -734,7 +735,7 @@ ffi::Array ConcreteScheduleNode::CacheIndex(const SBlockRV& block_rv, int cse_thresh) { ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh); + result = s_tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh); TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_); this->state_->DebugVerify(); ffi::Array return_blocks; @@ -748,7 +749,7 @@ SBlockRV ConcreteScheduleNode::ReIndex(const SBlockRV& block_rv, int buffer_inde BufferIndexType buffer_index_type) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); + result = s_tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -760,8 +761,8 @@ SBlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const SBlockRV& blo int read_buffer_index, const ffi::String& storage_scope) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, - storage_scope); + result = s_tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -771,8 +772,8 @@ SBlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const SBlockRV& bl int write_buffer_index, const ffi::String& storage_scope) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, - storage_scope); + result = s_tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), + write_buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -789,11 +790,11 @@ void ConcreteScheduleNode::ComputeAt(const SBlockRV& block_rv, const LoopRV& loo // do nothing } else if (loop_sref.same_as(inline_mark)) { TVM_TIR_SCHEDULE_BEGIN(); - tir::ComputeInline(state_, this->GetSRef(block_rv)); + s_tir::ComputeInline(state_, this->GetSRef(block_rv)); TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_); } else { TVM_TIR_SCHEDULE_BEGIN(); - tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); + s_tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_); } 
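// (Added note, not in the patch:) every primitive wrapper in this file
// follows the same shape: TVM_TIR_SCHEDULE_BEGIN()/TVM_TIR_SCHEDULE_END(...)
// wrap the primitive call in a try/catch that renders any ScheduleError at
// error_render_level_, and DebugVerify() re-checks the schedule-state
// invariants afterwards (a no-op unless the debug mask enables verification).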
this->state_->DebugVerify(); @@ -808,11 +809,11 @@ void ConcreteScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const Loop // do nothing } else if (loop_sref.same_as(inline_mark)) { TVM_TIR_SCHEDULE_BEGIN(); - tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); + s_tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_); } else { TVM_TIR_SCHEDULE_BEGIN(); - tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); + s_tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_); } this->state_->DebugVerify(); @@ -820,14 +821,14 @@ void ConcreteScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const Loop void ConcreteScheduleNode::ComputeInline(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - tir::ComputeInline(state_, this->GetSRef(block_rv)); + s_tir::ComputeInline(state_, this->GetSRef(block_rv)); TVM_TIR_SCHEDULE_END("compute-inline", this->error_render_level_); this->state_->DebugVerify(); } void ConcreteScheduleNode::ReverseComputeInline(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); + s_tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); TVM_TIR_SCHEDULE_END("reverse-compute-inline", this->error_render_level_); this->state_->DebugVerify(); } @@ -835,8 +836,8 @@ void ConcreteScheduleNode::ReverseComputeInline(const SBlockRV& block_rv) { void ConcreteScheduleNode::FuseReductionEpilogue(const SBlockRV& reduction_block_rv, const SBlockRV& epilogue_block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv), - this->GetSRef(epilogue_block_rv)); + s_tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv), + this->GetSRef(epilogue_block_rv)); TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_); this->state_->DebugVerify(); } @@ -846,7 +847,7 @@ void ConcreteScheduleNode::FuseReductionEpilogue(const SBlockRV& reduction_block void ConcreteScheduleNode::StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor, int offset) { TVM_TIR_SCHEDULE_BEGIN(); - tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); + s_tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_); this->state_->DebugVerify(); } @@ -854,7 +855,7 @@ void ConcreteScheduleNode::StorageAlign(const SBlockRV& block_rv, int buffer_ind void ConcreteScheduleNode::SetScope(const SBlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) { TVM_TIR_SCHEDULE_BEGIN(); - tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope); + s_tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_); this->state_->DebugVerify(); } @@ -862,7 +863,7 @@ void ConcreteScheduleNode::SetScope(const SBlockRV& block_rv, int buffer_index, void ConcreteScheduleNode::UnsafeSetDType(const SBlockRV& block_rv, int buffer_index, const ffi::String& dtype) { TVM_TIR_SCHEDULE_BEGIN(); - tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); + s_tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_); 
this->state_->DebugVerify(); } @@ -872,7 +873,7 @@ void ConcreteScheduleNode::UnsafeSetDType(const SBlockRV& block_rv, int buffer_i SBlockRV ConcreteScheduleNode::DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); + result = s_tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -881,7 +882,7 @@ SBlockRV ConcreteScheduleNode::DecomposeReduction(const SBlockRV& block_rv, cons SBlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); + result = s_tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); TVM_TIR_SCHEDULE_END("rfactor", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -891,7 +892,7 @@ SBlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { SBlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters); + result = s_tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); return CreateRV(result); @@ -901,7 +902,7 @@ SBlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks, bool preserve_unit_iters) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters); + result = s_tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); return CreateRV(result); @@ -910,8 +911,8 @@ SBlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks, void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(), - preserve_unit_iters); + s_tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(), + preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } @@ -919,8 +920,8 @@ void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& i void ConcreteScheduleNode::Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), - preserve_unit_iters); + s_tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), + preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } @@ -987,14 +988,15 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val)); + s_tir::Annotate(state_, this->GetSRef(loop_rv), 
ann_key, + this->CheckAndGetAnnotationValue(ann_val)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); + s_tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); } @@ -1002,15 +1004,15 @@ void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& void ConcreteScheduleNode::Annotate(const SBlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Annotate(state_, this->GetSRef(block_rv), ann_key, - this->CheckAndGetAnnotationValue(ann_val)); + s_tir::Annotate(state_, this->GetSRef(block_rv), ann_key, + this->CheckAndGetAnnotationValue(ann_val)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } void ConcreteScheduleNode::Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Unannotate(state_, this->GetSRef(block_rv), ann_key); + s_tir::Unannotate(state_, this->GetSRef(block_rv), ann_key); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); } @@ -1030,8 +1032,8 @@ void ConcreteScheduleNode::TransformLayout(const SBlockRV& block_rv, int buffer_ } }; auto new_index_map = Substitute(index_map, f_subst); - tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, - new_index_map, pad_value, assume_injective_transform); + s_tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, + new_index_map, pad_value, assume_injective_transform); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_); } @@ -1039,7 +1041,7 @@ void ConcreteScheduleNode::TransformLayout(const SBlockRV& block_rv, int buffer_ void ConcreteScheduleNode::TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) { TVM_TIR_SCHEDULE_BEGIN(); - tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map); + s_tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_); } @@ -1048,8 +1050,8 @@ void ConcreteScheduleNode::SetAxisSeparator(const SBlockRV& block_rv, int buffer BufferIndexType buffer_index_type, const ffi::Array& axis_separators) { TVM_TIR_SCHEDULE_BEGIN(); - tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, - axis_separators); + s_tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, + axis_separators); TVM_TIR_SCHEDULE_END("set-axis-separator", this->error_render_level_); this->state_->DebugVerify(); } @@ -1059,7 +1061,7 @@ void ConcreteScheduleNode::SetAxisSeparator(const SBlockRV& block_rv, int buffer SBlockRV ConcreteScheduleNode::DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); + result = s_tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); TVM_TIR_SCHEDULE_END("decompose-padding", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -1067,7 +1069,7 @@ SBlockRV 
ConcreteScheduleNode::DecomposePadding(const SBlockRV& block_rv, const void ConcreteScheduleNode::PadEinsum(const SBlockRV& block_rv, const ffi::Array& padding) { TVM_TIR_SCHEDULE_BEGIN(); - tir::PadEinsum(state_, this->GetSRef(block_rv), padding); + s_tir::PadEinsum(state_, this->GetSRef(block_rv), padding); TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_); this->state_->DebugVerify(); } @@ -1076,7 +1078,7 @@ void ConcreteScheduleNode::PadEinsum(const SBlockRV& block_rv, const ffi::Array< void ConcreteScheduleNode::RollingBuffer(const SBlockRV& block_rv, int write_buffer_index) { TVM_TIR_SCHEDULE_BEGIN(); - tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index); + s_tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index); TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_); this->state_->DebugVerify(); } @@ -1087,7 +1089,7 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type, const ffi::Array& buf_index_array) { TVM_TIR_SCHEDULE_BEGIN(); - tir::UnsafeHideBufferAccess(state_, this->GetSRef(block_rv), buf_type, buf_index_array); + s_tir::UnsafeHideBufferAccess(state_, this->GetSRef(block_rv), buf_type, buf_index_array); TVM_TIR_SCHEDULE_END("hide-buffer-access", this->error_render_level_); this->state_->DebugVerify(); } @@ -1096,11 +1098,11 @@ void ConcreteScheduleNode::AnnotateBufferAccess(const SBlockRV& block_rv, int bu BufferIndexType buffer_index_type, const IndexMap& index_map) { TVM_TIR_SCHEDULE_BEGIN(); - tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, - index_map); + s_tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, + index_map); TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_); this->state_->DebugVerify(); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/s_tir/schedule/concrete_schedule.h similarity index 98% rename from src/tir/schedule/concrete_schedule.h rename to src/s_tir/schedule/concrete_schedule.h index 52591fad4cb2..e475eb6aefc0 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/s_tir/schedule/concrete_schedule.h @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ -#ifndef TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ -#define TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ +#ifndef TVM_S_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ +#define TVM_S_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ #include #include @@ -26,7 +26,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class ConcreteScheduleNode : public ScheduleNode { friend class Schedule; @@ -395,7 +396,7 @@ inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { } } -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ +#endif // TVM_S_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ diff --git a/src/tir/schedule/error.cc b/src/s_tir/schedule/error.cc similarity index 96% rename from src/tir/schedule/error.cc rename to src/s_tir/schedule/error.cc index ce882ebbc9c7..1a8ceff888b9 100644 --- a/src/tir/schedule/error.cc +++ b/src/s_tir/schedule/error.cc @@ -19,7 +19,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; ffi::String ScheduleError::RenderReport(const ffi::String& primitive) const { IRModule mod = this->mod(); @@ -52,5 +53,5 @@ ffi::String ScheduleError::RenderReport(const ffi::String& primitive) const { return os.str(); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/error.h b/src/s_tir/schedule/error.h similarity index 93% rename from src/tir/schedule/error.h rename to src/s_tir/schedule/error.h index daea23518e77..8cd7b891af1a 100644 --- a/src/tir/schedule/error.h +++ b/src/s_tir/schedule/error.h @@ -16,15 +16,16 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_SCHEDULE_ERROR_H_ -#define TVM_TIR_SCHEDULE_ERROR_H_ -#include +#ifndef TVM_S_TIR_SCHEDULE_ERROR_H_ +#define TVM_S_TIR_SCHEDULE_ERROR_H_ +#include #include #include namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
\brief Error that happens during TensorIR scheduling */ class ScheduleError : public tvm::runtime::Error { @@ -83,7 +84,7 @@ class LoopPositionError : public ScheduleError { std::string primitive_; }; -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_ERROR_H_ +#endif // TVM_S_TIR_SCHEDULE_ERROR_H_ diff --git a/src/tir/schedule/instruction.cc b/src/s_tir/schedule/instruction.cc similarity index 96% rename from src/tir/schedule/instruction.cc rename to src/s_tir/schedule/instruction.cc index 5cf128b25201..7feb4b25ae3c 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/s_tir/schedule/instruction.cc @@ -21,7 +21,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; TVM_FFI_STATIC_INIT_BLOCK() { InstructionKindNode::RegisterReflection(); @@ -107,13 +108,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.InstructionKindGet", InstructionKind::Get) - .def("tir.schedule.Instruction", + .def("s_tir.schedule.InstructionKindGet", InstructionKind::Get) + .def("s_tir.schedule.Instruction", [](InstructionKind kind, ffi::Array inputs, ffi::Array attrs, ffi::Array outputs) -> Instruction { return Instruction(kind, inputs, attrs, outputs); }); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/instruction_traits.h b/src/s_tir/schedule/instruction_traits.h similarity index 98% rename from src/tir/schedule/instruction_traits.h rename to src/s_tir/schedule/instruction_traits.h index 93a1dd77ab64..395aab2d5ede 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/s_tir/schedule/instruction_traits.h @@ -16,11 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ -#define TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ +#ifndef TVM_S_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ +#define TVM_S_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ -#include -#include +#include +#include #include #include @@ -29,7 +29,8 @@ #include namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
* \brief Register an InstructionKind using a trait class @@ -45,7 +46,7 @@ namespace tir { * * // Convertible to `InstructionKindNode::FInstructionApply` * static ffi::Array ApplyToSchedule( - * const tir::Schedule& sch, + * const s_tir::Schedule& sch, * const ffi::Array& inputs, * const ffi::Array& attrs, * const ffi::Optional& decision); @@ -562,7 +563,7 @@ ffi::String PythonAPICall::Str() const { return os.str(); } -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ +#endif // TVM_S_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ diff --git a/src/tir/schedule/ir_comparator.cc b/src/s_tir/schedule/ir_comparator.cc similarity index 99% rename from src/tir/schedule/ir_comparator.cc rename to src/s_tir/schedule/ir_comparator.cc index ae476ae0a2b1..c2dc59acfec4 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/s_tir/schedule/ir_comparator.cc @@ -22,7 +22,8 @@ namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Tensorize Comparator ********/ @@ -742,5 +743,5 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { return true; } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/ir_comparator.h b/src/s_tir/schedule/ir_comparator.h similarity index 97% rename from src/tir/schedule/ir_comparator.h rename to src/s_tir/schedule/ir_comparator.h index dbf773922f48..611e77b4c65c 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/s_tir/schedule/ir_comparator.h @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ -#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ +#ifndef TVM_S_TIR_SCHEDULE_IR_COMPARATOR_H_ +#define TVM_S_TIR_SCHEDULE_IR_COMPARATOR_H_ #include #include @@ -27,7 +27,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using ExprComparator = ExprFunctor; using StmtComparator = StmtFunctor; @@ -165,7 +166,7 @@ class AutoTensorizeComparator : public TensorizeComparator { ffi::Map inner_iter_dom_map_; }; -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ +#endif // TVM_S_TIR_SCHEDULE_IR_COMPARATOR_H_ diff --git a/src/tir/schedule/primitive.h b/src/s_tir/schedule/primitive.h similarity index 99% rename from src/tir/schedule/primitive.h rename to src/s_tir/schedule/primitive.h index cc06f7f0d1b4..094d1f405fe8 100644 --- a/src/tir/schedule/primitive.h +++ b/src/s_tir/schedule/primitive.h @@ -16,16 +16,17 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ -#define TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#ifndef TVM_S_TIR_SCHEDULE_PRIMITIVE_H_ +#define TVM_S_TIR_SCHEDULE_PRIMITIVE_H_ +#include #include -#include #include namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Schedule: Sampling ********/ /*! 
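// ---------------------------------------------------------------------------
// Not part of the patch: a condensed sketch of the trait-class registration
// pattern documented in instruction_traits.h above. It mirrors the *Traits
// structs that appear later in this diff (e.g. UnrollTraits in for_kind.cc);
// "ExampleUnrollTraits" is an illustrative name, not code in the repository.
struct ExampleUnrollTraits : public UnpackedInstTraits<ExampleUnrollTraits> {
  static constexpr const char* kName = "ExampleUnroll";  // instruction kind name
  static constexpr bool kIsPure = false;

  // One loop RV input, no attributes, no sampling decision.
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  // Convertible to InstructionKindNode::FInstructionApply after unpacking.
  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
    sch->Unroll(loop_rv);  // delegate to the existing primitive
  }

  // Renders the instruction as a Python schedule call for trace printing.
  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String loop_rv) {
    PythonAPICall py("unroll");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::s_tir::UnpackedInstTraits;
};
TVM_REGISTER_INST_KIND_TRAITS(ExampleUnrollTraits);
// ---------------------------------------------------------------------------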
@@ -147,7 +148,7 @@ TVM_DLL std::vector SamplePartitionedTile( * \return The sampled loop where the input block is to be computed at */ TVM_DLL tir::StmtSRef SampleComputeLocation( - tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, + s_tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, const tir::StmtSRef& block_sref, ffi::Optional* decision); /******** Schedule: Get blocks & loops ********/ @@ -537,7 +538,7 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sr const StmtSRef& loop_sref); /*! * \brief Factor a reduction block by the specified loop - * \details See python/tvm/tir/schedule/schedule.py + * \details See python/tvm/s_tir/schedule/schedule.py * \param self The state of the schedule * \param loop_sref The loop outside block for which we want to do rfactor * \param factor_axis The position where the new dimension is placed in the new introduced rfactor @@ -741,7 +742,7 @@ TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sr */ TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map); -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#endif // TVM_S_TIR_SCHEDULE_PRIMITIVE_H_ diff --git a/src/tir/schedule/primitive/annotate.cc b/src/s_tir/schedule/primitive/annotate.cc similarity index 97% rename from src/tir/schedule/primitive/annotate.cc rename to src/s_tir/schedule/primitive/annotate.cc index 25c86431be4f..3584d703371c 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/s_tir/schedule/primitive/annotate.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key, const Any& ann_val) { @@ -118,7 +119,7 @@ struct AnnotateTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct UnannotateTraits : public UnpackedInstTraits { @@ -152,11 +153,11 @@ struct UnannotateTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits); TVM_REGISTER_INST_KIND_TRAITS(UnannotateTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/s_tir/schedule/primitive/annotate_buffer_access.cc similarity index 96% rename from src/tir/schedule/primitive/annotate_buffer_access.cc rename to src/s_tir/schedule/primitive/annotate_buffer_access.cc index c358fd84d6b2..5fc00e3c364d 100644 --- a/src/tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/s_tir/schedule/primitive/annotate_buffer_access.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class AnnotateRegionRewriter : public StmtExprMutator { public: @@ -49,8 +50,8 @@ class AnnotateRegionRewriter : public StmtExprMutator { // Annotate the block with explicit_read_region or explicit_write_region ffi::Map new_annotations = n->annotations; ffi::String annotation_key = buffer_index_type_ == BufferIndexType::kWrite - ? attr::explicit_write_region - : attr::explicit_read_region; + ? 
tir::attr::explicit_write_region + : tir::attr::explicit_read_region; if (new_annotations.count(annotation_key)) { ffi::Array buffer_indices = Downcast>(new_annotations[annotation_key]); @@ -161,10 +162,10 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/s_tir/schedule/primitive/block_annotate.cc similarity index 96% rename from src/tir/schedule/primitive/block_annotate.cc rename to src/s_tir/schedule/primitive/block_annotate.cc index 7810eb81b6dc..0d6b2a58d0eb 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/s_tir/schedule/primitive/block_annotate.cc @@ -19,11 +19,12 @@ #include #include -#include "../../transforms/ir_utils.h" +#include "../../../tir/transforms/ir_utils.h" #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class StorageAlignAxisOutOfRangeError : public ScheduleError { public: @@ -148,14 +149,14 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { std::ostringstream os; os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " "(buffer_index, axis, factor, offset). However, the block annotation with key " - << attr::buffer_dim_align << " of the block {0} is " - << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected."; + << tir::attr::buffer_dim_align << " of the block {0} is " + << block_->annotations.at(tir::attr::buffer_dim_align) << ", which is unexpected."; return os.str(); } static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const SBlock& block) { // Get existing annotation value. 
- auto it = block->annotations.find(attr::buffer_dim_align); + auto it = block->annotations.find(tir::attr::buffer_dim_align); if (it != block->annotations.end()) { if (!IsValidAnnotation(block, (*it).second)) { throw StorageAlignInvalidAnnotationError(mod, block); @@ -250,7 +251,8 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind } // Step 3: Replace the block with the new annotation - SBlock new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); + SBlock new_block = + WithAnnotation(block_ptr, tir::attr::buffer_dim_align, storage_align_annotation); self->Replace(block_sref, new_block, {{ffi::GetRef(block_ptr), new_block}}); } @@ -398,7 +400,7 @@ struct StorageAlignTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct SetScopeTraits : public UnpackedInstTraits { @@ -425,7 +427,7 @@ struct SetScopeTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct UnsafeSetDTypeTraits : public UnpackedInstTraits { @@ -452,12 +454,12 @@ struct UnsafeSetDTypeTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits); TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc similarity index 98% rename from src/tir/schedule/primitive/blockize_tensorize.cc rename to src/s_tir/schedule/primitive/blockize_tensorize.cc index 1ae2b8e7bfb4..fe96387486c0 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -20,16 +20,17 @@ #include -#include "../../transforms/simplify.h" +#include "../../../tir/transforms/simplify.h" #include "../ir_comparator.h" #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; template bool UsesVar(const T& x, const Var& var) { - return UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; }); + return tir::UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; }); } Range RangeFromExtent(const PrimExpr& extent) { @@ -103,7 +104,7 @@ ffi::Array> TrivialSubspaceDivision( var_set.insert(var.get()); } return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool { - return UsesVar(expr, [&var_set](const VarNode* var) { + return tir::UsesVar(expr, [&var_set](const VarNode* var) { return var_set.count(var); // }); }; @@ -562,9 +563,9 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_u BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters); self->Replace(loop_sref, blockized, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); - StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); + StmtSRef scope_root = GetScopeRoot(self, result, /*require_stage_pipeline=*/false); bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); - self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, scope_root)); + self->UpdateScopeSBlockInfo(GetSBlockRealize(self, scope_root)); self->block_info[scope_root].affine_binding 
= scope_block_affine_binding; return result; } @@ -738,8 +739,8 @@ StmtSRef Blockize(ScheduleState self, const ffi::Array& blocks, auto new_root = BlockizeRewriter::Rewrite(lca, blocks, blockized); self->Replace(lca, new_root, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); - StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); - self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, scope_root)); + StmtSRef scope_root = GetScopeRoot(self, result, /*require_stage_pipeline=*/false); + self->UpdateScopeSBlockInfo(GetSBlockRealize(self, scope_root)); return result; } @@ -853,7 +854,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } // Step 6: Update the cached flags. StmtSRef result = self->stmt2ref.at(block_realize->block.get()); - StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); + StmtSRef scope_root = GetScopeRoot(self, result, /*require_stage_pipeline=*/false); self->UpdateScopeSBlockInfo(scope_root->StmtAs()->body); } @@ -888,7 +889,7 @@ struct BlockizeTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct TensorizeTraits : public UnpackedInstTraits { @@ -922,11 +923,11 @@ struct TensorizeTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc similarity index 98% rename from src/tir/schedule/primitive/cache_index.cc rename to src/s_tir/schedule/primitive/cache_index.cc index 788f34e883dd..61e6ed83470d 100644 --- a/src/tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -18,12 +18,13 @@ */ #include -#include "../../transforms/common_subexpr_elim_tools.h" -#include "../../transforms/replace_selected_expr.h" +#include "../../../tir/transforms/common_subexpr_elim_tools.h" +#include "../../../tir/transforms/replace_selected_expr.h" #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Helper Functions/Classes ********/ @@ -520,10 +521,10 @@ struct CacheIndexTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(CacheIndexTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc similarity index 99% rename from src/tir/schedule/primitive/cache_read_write.cc rename to src/s_tir/schedule/primitive/cache_read_write.cc index 5cae3749c55f..481aec206023 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -19,12 +19,13 @@ #include -#include "../../analysis/var_use_def_analysis.h" -#include "../../transforms/ir_utils.h" +#include "../../../tir/analysis/var_use_def_analysis.h" +#include "../../../tir/transforms/ir_utils.h" #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Error Classes ********/ @@ -1658,7 +1659,7 @@ The region cover property require to 
hold for every of its child blocks SBlock block_; }; - for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) { + for (const auto& child_block_sref : GetChildBlocks(self, scope_root)) { const SBlockNode* child_block = TVM_SREF_TO_SBLOCK(child_block_sref); for (const BufferRegion& region : child_block->reads) { if (region->buffer.same_as(read_buffer)) { @@ -1705,7 +1706,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // info.consumer_blocks indicates which buffers should consume the cache. for (auto consumer : consumer_blocks) { info.consumer_blocks.insert(consumer); - for (auto child : tir::GetChildBlocks(self, consumer)) { + for (auto child : GetChildBlocks(self, consumer)) { info.consumer_blocks.insert(child); } } @@ -1797,7 +1798,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // info.consumer_blocks indicates which buffers should consume the cache. for (auto consumer : consumer_blocks) { info.consumer_blocks.insert(consumer); - for (auto child : tir::GetChildBlocks(self, consumer)) { + for (auto child : GetChildBlocks(self, consumer)) { info.consumer_blocks.insert(child); } } @@ -2346,7 +2347,7 @@ struct CacheReadTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct CacheWriteTraits : public UnpackedInstTraits { @@ -2380,7 +2381,7 @@ struct CacheWriteTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct CacheInplaceTraits : public UnpackedInstTraits { @@ -2409,7 +2410,7 @@ struct CacheInplaceTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct ReIndexTraits : public UnpackedInstTraits { @@ -2440,7 +2441,7 @@ struct ReIndexTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct ReindexCacheReadTraits : public UnpackedInstTraits { @@ -2470,7 +2471,7 @@ struct ReindexCacheReadTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct ReindexCacheWriteTraits : public UnpackedInstTraits { @@ -2500,7 +2501,7 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); @@ -2510,5 +2511,5 @@ TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheWriteTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc similarity index 99% rename from src/tir/schedule/primitive/compute_at.cc rename to src/s_tir/schedule/primitive/compute_at.cc index 420876637de0..4f48c1054671 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using support::NDIntSet; @@ -827,7 +828,7 @@ struct ComputeAtTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct 
::tvm::s_tir::UnpackedInstTraits; }; struct ReverseComputeAtTraits : public UnpackedInstTraits { @@ -856,11 +857,11 @@ struct ReverseComputeAtTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(ComputeAtTraits); TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeAtTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/s_tir/schedule/primitive/compute_inline.cc similarity index 99% rename from src/tir/schedule/primitive/compute_inline.cc rename to src/s_tir/schedule/primitive/compute_inline.cc index c60954deaf59..c704fe134aa9 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/s_tir/schedule/primitive/compute_inline.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of 'A[f(i, j, k, ...)] = g(i, j, k, ...)', @@ -1724,7 +1725,7 @@ struct ComputeInlineTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct ReverseComputeInlineTraits : public UnpackedInstTraits { @@ -1747,7 +1748,7 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(ComputeInlineTraits); @@ -1777,10 +1778,10 @@ struct FuseReductionEpilogueTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(FuseReductionEpilogueTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/s_tir/schedule/primitive/decompose_padding.cc similarity index 98% rename from src/tir/schedule/primitive/decompose_padding.cc rename to src/s_tir/schedule/primitive/decompose_padding.cc index 7cf4466939eb..c1c8b751af83 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/s_tir/schedule/primitive/decompose_padding.cc @@ -18,11 +18,12 @@ */ #include -#include "../../transforms/ir_utils.h" +#include "../../../tir/transforms/ir_utils.h" #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
\brief Information used to create new padding block */ struct PaddingSBlockInfo { @@ -537,7 +538,7 @@ bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.schedule.CanDecomposePadding", [](Schedule self, SBlockRV block_rv, LoopRV loop_rv) { + "s_tir.schedule.CanDecomposePadding", [](Schedule self, SBlockRV block_rv, LoopRV loop_rv) { return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); }); } @@ -567,10 +568,10 @@ struct DecomposPaddingTraits : public UnpackedInstTraits } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(DecomposPaddingTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/s_tir/schedule/primitive/for_kind.cc similarity index 97% rename from src/tir/schedule/primitive/for_kind.cc rename to src/s_tir/schedule/primitive/for_kind.cc index 01cdb084950f..90ec40b05712 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/s_tir/schedule/primitive/for_kind.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class WrongBlockIterTypeError : public ScheduleError { public: @@ -223,7 +224,7 @@ struct ParallelTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct VectorizeTraits : public UnpackedInstTraits { @@ -246,7 +247,7 @@ struct VectorizeTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct BindTraits : public UnpackedInstTraits { @@ -271,7 +272,7 @@ struct BindTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct UnrollTraits : public UnpackedInstTraits { @@ -292,7 +293,7 @@ struct UnrollTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(ParallelTraits); @@ -300,5 +301,5 @@ TVM_REGISTER_INST_KIND_TRAITS(VectorizeTraits); TVM_REGISTER_INST_KIND_TRAITS(BindTraits); TVM_REGISTER_INST_KIND_TRAITS(UnrollTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/s_tir/schedule/primitive/get_block_loop.cc similarity index 93% rename from src/tir/schedule/primitive/get_block_loop.cc rename to src/s_tir/schedule/primitive/get_block_loop.cc index 28293624b1d8..bf13bb7795cd 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/s_tir/schedule/primitive/get_block_loop.cc @@ -20,7 +20,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; ffi::Array GetSBlocks(const ScheduleState& self, const ffi::String& name, const GlobalVar& gv) { @@ -82,17 +83,17 @@ ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& p ffi::Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - return tir::GetProducers(block_sref, self->GetSBlockScope(scope_root)); + return GetProducers(block_sref, self->GetSBlockScope(scope_root)); } ffi::Array 
GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - return tir::GetConsumers(block_sref, self->GetSBlockScope(scope_root)); + return GetConsumers(block_sref, self->GetSBlockScope(scope_root)); } ffi::Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { const auto* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); - return tir::GetOutputBlocks(self, scope_block); + return GetOutputBlocks(self, scope_block); } /******** InstructionKind Registration ********/ @@ -120,7 +121,7 @@ struct GetSBlockTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct GetLoopsTraits : public UnpackedInstTraits { @@ -144,7 +145,7 @@ struct GetLoopsTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct GetChildBlocksTraits : public UnpackedInstTraits { @@ -177,7 +178,7 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct GetProducersTraits : public UnpackedInstTraits { @@ -201,7 +202,7 @@ struct GetProducersTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct GetConsumersTraits : public UnpackedInstTraits { @@ -225,7 +226,7 @@ struct GetConsumersTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct GetOutputBlocksTraits : public UnpackedInstTraits { @@ -249,7 +250,7 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(GetSBlockTraits); @@ -259,5 +260,5 @@ TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits); TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits); TVM_REGISTER_INST_KIND_TRAITS(GetOutputBlocksTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/hide_buffer_access.cc b/src/s_tir/schedule/primitive/hide_buffer_access.cc similarity index 97% rename from src/tir/schedule/primitive/hide_buffer_access.cc rename to src/s_tir/schedule/primitive/hide_buffer_access.cc index 98805845b6ea..a2b104ec1db8 100644 --- a/src/tir/schedule/primitive/hide_buffer_access.cc +++ b/src/s_tir/schedule/primitive/hide_buffer_access.cc @@ -16,11 +16,12 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include "../../transforms/ir_utils.h" +#include "../../../tir/transforms/ir_utils.h" #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Error Classes ********/ @@ -163,10 +164,10 @@ struct UnsafeHideBufferAccessTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(UnsafeHideBufferAccessTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc similarity index 99% rename from src/tir/schedule/primitive/layout_transformation.cc rename to src/s_tir/schedule/primitive/layout_transformation.cc index 707d4e22a886..11a7903851fc 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -27,7 +27,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! \brief Planning stage prior to rewriting in TransformLayoutRewriter * @@ -1626,7 +1627,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct TransformBlockLayoutTraits : public UnpackedInstTraits { @@ -1666,7 +1667,7 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct SetAxisSeparatorTraits : public UnpackedInstTraits { @@ -1702,12 +1703,12 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits); TVM_REGISTER_INST_KIND_TRAITS(TransformBlockLayoutTraits); TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/s_tir/schedule/primitive/loop_transformation.cc similarity index 99% rename from src/tir/schedule/primitive/loop_transformation.cc rename to src/s_tir/schedule/primitive/loop_transformation.cc index 96ea4d2527d1..36f90ce60614 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
\brief Append a new predicate to the each child of type BlockRealize (not recursively) */ class BlockPredicateAppender : public StmtMutator { @@ -695,10 +696,10 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref // Replace existing loop with the newly created common block self->Replace(loop_sref, common, {}); StmtSRef scope_sref = self->stmt2ref.at(common->block.get()); - StmtSRef scope_root = tir::GetScopeRoot(self, scope_sref, /*require_stage_pipeline=*/false); + StmtSRef scope_root = GetScopeRoot(self, scope_sref, /*require_stage_pipeline=*/false); bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); // Update the SRefTree for the newly created common block - self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, scope_root)); + self->UpdateScopeSBlockInfo(GetSBlockRealize(self, scope_root)); self->block_info[scope_root].affine_binding = scope_block_affine_binding; // Collect the SRef for each partitioned loop and return @@ -1215,7 +1216,7 @@ struct SplitTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct LoopPartitionTraits : public UnpackedInstTraits { @@ -1254,7 +1255,7 @@ struct LoopPartitionTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct MergeTraits : public UnpackedInstTraits { @@ -1286,7 +1287,7 @@ struct MergeTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct FuseTraits : public UnpackedInstTraits { @@ -1320,7 +1321,7 @@ struct FuseTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct ReorderTraits : public UnpackedInstTraits { @@ -1351,7 +1352,7 @@ struct ReorderTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct AddUnitLoopTraits : public UnpackedInstTraits { @@ -1382,7 +1383,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); @@ -1392,5 +1393,5 @@ TVM_REGISTER_INST_KIND_TRAITS(FuseTraits); TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits); TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/s_tir/schedule/primitive/pad_einsum.cc similarity index 99% rename from src/tir/schedule/primitive/pad_einsum.cc rename to src/s_tir/schedule/primitive/pad_einsum.cc index 7fd28445a812..bffb3e6da659 100644 --- a/src/tir/schedule/primitive/pad_einsum.cc +++ b/src/s_tir/schedule/primitive/pad_einsum.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
* \brief Check if buffer indices are all Vars and expr @@ -506,10 +507,10 @@ struct PadEinsumTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(PadEinsumTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc similarity index 99% rename from src/tir/schedule/primitive/read_write_at.cc rename to src/s_tir/schedule/primitive/read_write_at.cc index a8325c09e692..8b55141689be 100644 --- a/src/tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; using support::NDIntSet; @@ -384,7 +385,7 @@ struct ReadAtTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct WriteAtTraits : public UnpackedInstTraits { @@ -414,11 +415,11 @@ struct WriteAtTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc similarity index 99% rename from src/tir/schedule/primitive/reduction.cc rename to src/s_tir/schedule/primitive/reduction.cc index fafc646682bf..6e54f928c908 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -21,7 +21,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief A helper class to create a new scope that contains decomposed init body @@ -1317,7 +1318,7 @@ struct DecomposeReductionTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct RFactorTraits : public UnpackedInstTraits { @@ -1343,7 +1344,7 @@ struct RFactorTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits); @@ -1354,12 +1355,12 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.schedule.RegisterReducer", + "s_tir.schedule.RegisterReducer", [](int n_buffers, ffi::Function combiner_getter, ffi::Function identity_getter) { ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), std::move(identity_getter)); }); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/reorder_block_iter_var.cc b/src/s_tir/schedule/primitive/reorder_block_iter_var.cc similarity index 97% rename from src/tir/schedule/primitive/reorder_block_iter_var.cc rename to src/s_tir/schedule/primitive/reorder_block_iter_var.cc index 2a61734c44ef..4a76d6c51192 100644 --- a/src/tir/schedule/primitive/reorder_block_iter_var.cc +++ b/src/s_tir/schedule/primitive/reorder_block_iter_var.cc @@ -19,7 +19,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! 
* \brief The reorder index is not a valid permutation of @@ -140,10 +141,10 @@ struct ReorderBlockIterVarTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(ReorderBlockIterVarTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc similarity index 98% rename from src/tir/schedule/primitive/rolling_buffer.cc rename to src/s_tir/schedule/primitive/rolling_buffer.cc index 2b463207cf41..c5c41262f243 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; namespace { @@ -443,7 +444,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf self->Replace(scope_root_sref, new_scope_root, info.block_reuse); // Step 7. Regenerate block info from the root block, because `region_cover` for the target block // and `stage_pipeline` for the root block are no longer satisfied after rolling buffer injection. - self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); + self->UpdateScopeSBlockInfo(GetSBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); } struct RollingBufferTraits : public UnpackedInstTraits { @@ -468,9 +469,9 @@ struct RollingBufferTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(RollingBufferTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/primitive/sampling.cc b/src/s_tir/schedule/primitive/sampling.cc similarity index 98% rename from src/tir/schedule/primitive/sampling.cc rename to src/s_tir/schedule/primitive/sampling.cc index de09aa03dc0f..40bfff408aca 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/s_tir/schedule/primitive/sampling.cc @@ -22,7 +22,8 @@ #include "../utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; struct PrimeTable { /*! \brief The table contains prime numbers in [2, kMaxPrime) */ @@ -417,7 +418,7 @@ std::vector SamplePartitionedTile( return result; } -tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, +tir::StmtSRef SampleComputeLocation(s_tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, const StmtSRef& block_sref, ffi::Optional* decision) { // Step 1. Collect all possible compute-at locations. 
@@ -480,7 +481,7 @@ struct SampleCategoricalTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct SamplePerfectTileTraits : public UnpackedInstTraits { @@ -511,7 +512,7 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct SamplePartitionedTileTraits : public UnpackedInstTraits { @@ -544,7 +545,7 @@ struct SamplePartitionedTileTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; struct SampleComputeLocationTraits : public UnpackedInstTraits { @@ -573,7 +574,7 @@ struct SampleComputeLocationTraits : public UnpackedInstTraits - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); @@ -581,5 +582,5 @@ TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); TVM_REGISTER_INST_KIND_TRAITS(SamplePartitionedTileTraits); TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits); -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/s_tir/schedule/schedule.cc similarity index 67% rename from src/tir/schedule/schedule.cc rename to src/s_tir/schedule/schedule.cc index 636aa4dfc54b..39a01ff3e7af 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/s_tir/schedule/schedule.cc @@ -20,7 +20,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; @@ -51,14 +52,14 @@ StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleGetMod", &ScheduleNode::mod) - .def_method("tir.schedule.ScheduleGetState", &ScheduleNode::state) - .def_method("tir.schedule.ScheduleGetTrace", &ScheduleNode::trace) - .def_method("tir.schedule.ScheduleGetFuncWorkingOn", &ScheduleNode::func_working_on) - .def_method("tir.schedule.ScheduleCopy", &ScheduleNode::Copy) - .def_method("tir.schedule.ScheduleSeed", &ScheduleNode::Seed) - .def_method("tir.schedule.ScheduleForkSeed", &ScheduleNode::ForkSeed) - .def_method("tir.schedule.ScheduleWorkOn", &ScheduleNode::WorkOn); + .def_method("s_tir.schedule.ScheduleGetMod", &ScheduleNode::mod) + .def_method("s_tir.schedule.ScheduleGetState", &ScheduleNode::state) + .def_method("s_tir.schedule.ScheduleGetTrace", &ScheduleNode::trace) + .def_method("s_tir.schedule.ScheduleGetFuncWorkingOn", &ScheduleNode::func_working_on) + .def_method("s_tir.schedule.ScheduleCopy", &ScheduleNode::Copy) + .def_method("s_tir.schedule.ScheduleSeed", &ScheduleNode::Seed) + .def_method("s_tir.schedule.ScheduleForkSeed", &ScheduleNode::ForkSeed) + .def_method("s_tir.schedule.ScheduleWorkOn", &ScheduleNode::WorkOn); } /**************** (FFI) Constructor ****************/ @@ -66,16 +67,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.SBlockRV", []() { return SBlockRV(); }) - .def("tir.schedule.LoopRV", []() { return LoopRV(); }) - .def("tir.schedule.ConcreteSchedule", + .def("s_tir.schedule.SBlockRV", []() { return SBlockRV(); }) + .def("s_tir.schedule.LoopRV", []() { return LoopRV(); }) + .def("s_tir.schedule.ConcreteSchedule", [](IRModule mod, 
support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Concrete(mod, debug_mask, seed, static_cast(error_render_level), enable_check); }) - .def("tir.schedule.TracedSchedule", + .def("s_tir.schedule.TracedSchedule", [](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Traced(mod, seed, debug_mask, @@ -89,7 +90,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.ScheduleGet", + .def("s_tir.schedule.ScheduleGet", [](Schedule self, ObjectRef obj) -> ObjectRef { if (auto loop_rv = obj.as()) { return self->Get(loop_rv.value()); @@ -104,7 +105,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { << obj->GetTypeKey() << ". Its value is: " << obj; throw; }) - .def("tir.schedule.ScheduleGetSRef", + .def("s_tir.schedule.ScheduleGetSRef", [](Schedule self, ObjectRef obj) -> ffi::Optional { if (auto loop_rv = obj.as()) { return self->GetSRef(loop_rv.value()); @@ -118,7 +119,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; }) - .def("tir.schedule.ScheduleRemoveRV", [](Schedule self, ObjectRef obj) -> void { + .def("s_tir.schedule.ScheduleRemoveRV", [](Schedule self, ObjectRef obj) -> void { if (auto loop_rv = obj.as()) { return self->RemoveRV(loop_rv.value()); } @@ -137,20 +138,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleSampleCategorical", &ScheduleNode::SampleCategorical) - .def_method("tir.schedule.ScheduleSamplePerfectTile", &ScheduleNode::SamplePerfectTile) - .def_method("tir.schedule.ScheduleSamplePartitionedTile", + .def_method("s_tir.schedule.ScheduleSampleCategorical", &ScheduleNode::SampleCategorical) + .def_method("s_tir.schedule.ScheduleSamplePerfectTile", &ScheduleNode::SamplePerfectTile) + .def_method("s_tir.schedule.ScheduleSamplePartitionedTile", &ScheduleNode::SamplePartitionedTile) - .def_method("tir.schedule.ScheduleSampleComputeLocation", + .def_method("s_tir.schedule.ScheduleSampleComputeLocation", &ScheduleNode::SampleComputeLocation); } /******** (FFI) Get blocks & loops ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleGetSBlock", &ScheduleNode::GetSBlock) - .def_method("tir.schedule.ScheduleGetLoops", &ScheduleNode::GetLoops) - .def("tir.schedule.ScheduleGetChildBlocks", + .def_method("s_tir.schedule.ScheduleGetSBlock", &ScheduleNode::GetSBlock) + .def_method("s_tir.schedule.ScheduleGetLoops", &ScheduleNode::GetLoops) + .def("s_tir.schedule.ScheduleGetChildBlocks", [](Schedule self, ObjectRef rv) { if (auto block_rv = rv.as()) { return self->GetChildBlocks(block_rv.value()); @@ -162,21 +163,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { << rv->GetTypeKey() << ". 
Its value is: " << rv; throw; }) - .def_method("tir.schedule.ScheduleGetProducers", &ScheduleNode::GetProducers) - .def_method("tir.schedule.ScheduleGetConsumers", &ScheduleNode::GetConsumers) - .def_method("tir.schedule.ScheduleGetOutputBlocks", &ScheduleNode::GetOutputBlocks); + .def_method("s_tir.schedule.ScheduleGetProducers", &ScheduleNode::GetProducers) + .def_method("s_tir.schedule.ScheduleGetConsumers", &ScheduleNode::GetConsumers) + .def_method("s_tir.schedule.ScheduleGetOutputBlocks", &ScheduleNode::GetOutputBlocks); } /******** (FFI) Transform loops ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleMerge", &ScheduleNode::Merge) - .def_method("tir.schedule.ScheduleFuse", &ScheduleNode::Fuse) - .def_method("tir.schedule.ScheduleSplit", &ScheduleNode::Split) - .def_method("tir.schedule.ScheduleLoopPartition", &ScheduleNode::LoopPartition) - .def_method("tir.schedule.ScheduleReorder", &ScheduleNode::Reorder) - .def_method("tir.schedule.ScheduleReorderBlockIterVar", &ScheduleNode::ReorderBlockIterVar) - .def("tir.schedule.ScheduleAddUnitLoop", [](Schedule self, ObjectRef rv) -> LoopRV { + .def_method("s_tir.schedule.ScheduleMerge", &ScheduleNode::Merge) + .def_method("s_tir.schedule.ScheduleFuse", &ScheduleNode::Fuse) + .def_method("s_tir.schedule.ScheduleSplit", &ScheduleNode::Split) + .def_method("s_tir.schedule.ScheduleLoopPartition", &ScheduleNode::LoopPartition) + .def_method("s_tir.schedule.ScheduleReorder", &ScheduleNode::Reorder) + .def_method("s_tir.schedule.ScheduleReorderBlockIterVar", &ScheduleNode::ReorderBlockIterVar) + .def("s_tir.schedule.ScheduleAddUnitLoop", [](Schedule self, ObjectRef rv) -> LoopRV { if (auto loop_rv = rv.as()) { return self->AddUnitLoop(loop_rv.value()); } else if (auto block_rv = rv.as()) { @@ -192,22 +193,22 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleParallel", &ScheduleNode::Parallel) - .def_method("tir.schedule.ScheduleVectorize", &ScheduleNode::Vectorize) - .def_method("tir.schedule.ScheduleBind", &ScheduleNode::Bind) - .def_method("tir.schedule.ScheduleUnroll", &ScheduleNode::Unroll); + .def_method("s_tir.schedule.ScheduleParallel", &ScheduleNode::Parallel) + .def_method("s_tir.schedule.ScheduleVectorize", &ScheduleNode::Vectorize) + .def_method("s_tir.schedule.ScheduleBind", &ScheduleNode::Bind) + .def_method("s_tir.schedule.ScheduleUnroll", &ScheduleNode::Unroll); } /******** (FFI) Insert cache stages ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleCacheRead", &ScheduleNode::CacheRead) - .def_method("tir.schedule.ScheduleCacheWrite", &ScheduleNode::CacheWrite) - .def_method("tir.schedule.ScheduleReindexCacheRead", &ScheduleNode::ReindexCacheRead) - .def_method("tir.schedule.ScheduleReindexCacheWrite", &ScheduleNode::ReindexCacheWrite) - .def_method("tir.schedule.ScheduleCacheInplace", &ScheduleNode::CacheInplace) - .def_method("tir.schedule.ScheduleCacheIndex", &ScheduleNode::CacheIndex) - .def("tir.schedule.ScheduleReIndex", + .def_method("s_tir.schedule.ScheduleCacheRead", &ScheduleNode::CacheRead) + .def_method("s_tir.schedule.ScheduleCacheWrite", &ScheduleNode::CacheWrite) + .def_method("s_tir.schedule.ScheduleReindexCacheRead", &ScheduleNode::ReindexCacheRead) + .def_method("s_tir.schedule.ScheduleReindexCacheWrite", &ScheduleNode::ReindexCacheWrite) + 
.def_method("s_tir.schedule.ScheduleCacheInplace", &ScheduleNode::CacheInplace) + .def_method("s_tir.schedule.ScheduleCacheIndex", &ScheduleNode::CacheIndex) + .def("s_tir.schedule.ScheduleReIndex", [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type) { return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); @@ -217,40 +218,41 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleReadAt", &ScheduleNode::ReadAt) - .def_method("tir.schedule.ScheduleWriteAt", &ScheduleNode::WriteAt); + .def_method("s_tir.schedule.ScheduleReadAt", &ScheduleNode::ReadAt) + .def_method("s_tir.schedule.ScheduleWriteAt", &ScheduleNode::WriteAt); } /******** (FFI) Compute location ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleComputeAt", &ScheduleNode::ComputeAt) - .def_method("tir.schedule.ScheduleReverseComputeAt", &ScheduleNode::ReverseComputeAt) - .def_method("tir.schedule.ScheduleComputeInline", &ScheduleNode::ComputeInline) - .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline) - .def_method("tir.schedule.ScheduleFuseReductionEpilogue", + .def_method("s_tir.schedule.ScheduleComputeAt", &ScheduleNode::ComputeAt) + .def_method("s_tir.schedule.ScheduleReverseComputeAt", &ScheduleNode::ReverseComputeAt) + .def_method("s_tir.schedule.ScheduleComputeInline", &ScheduleNode::ComputeInline) + .def_method("s_tir.schedule.ScheduleReverseComputeInline", + &ScheduleNode::ReverseComputeInline) + .def_method("s_tir.schedule.ScheduleFuseReductionEpilogue", &ScheduleNode::FuseReductionEpilogue); } /******** (FFI) Reduction ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleDecomposeReduction", &ScheduleNode::DecomposeReduction) - .def_method("tir.schedule.ScheduleRFactor", &ScheduleNode::RFactor); + .def_method("s_tir.schedule.ScheduleDecomposeReduction", &ScheduleNode::DecomposeReduction) + .def_method("s_tir.schedule.ScheduleRFactor", &ScheduleNode::RFactor); } /******** (FFI) SBlock annotation ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleStorageAlign", &ScheduleNode::StorageAlign) - .def_method("tir.schedule.ScheduleSetScope", &ScheduleNode::SetScope) - .def_method("tir.schedule.ScheduleUnsafeSetDType", &ScheduleNode::UnsafeSetDType); + .def_method("s_tir.schedule.ScheduleStorageAlign", &ScheduleNode::StorageAlign) + .def_method("s_tir.schedule.ScheduleSetScope", &ScheduleNode::SetScope) + .def_method("s_tir.schedule.ScheduleUnsafeSetDType", &ScheduleNode::UnsafeSetDType); } /******** (FFI) Blockize & Tensorize ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.ScheduleBlockize", + .def("s_tir.schedule.ScheduleBlockize", [](Schedule self, ObjectRef target, bool preserve_unit_iters) { if (auto loop_rv = target.as()) { return self->Blockize(loop_rv.value(), preserve_unit_iters); @@ -259,7 +261,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); }) - .def("tir.schedule.ScheduleTensorize", + .def("s_tir.schedule.ScheduleTensorize", [](Schedule self, ObjectRef rv, ffi::String intrin, bool preserve_unit_iters) { if (auto block_rv = rv.as()) { 
self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); @@ -276,7 +278,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.ScheduleAnnotate", + .def("s_tir.schedule.ScheduleAnnotate", [](Schedule self, ObjectRef rv, const ffi::String& ann_key, const Any& ann_val) { if (auto block_rv = rv.as()) { return self->Annotate(block_rv.value(), ann_key, ann_val); @@ -288,8 +290,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { << rv->GetTypeKey() << ". Its value is: " << rv; throw; }) - .def("tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, - const ffi::String& ann_key) { + .def("s_tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, + const ffi::String& ann_key) { if (auto block_rv = rv.as()) { return self->Unannotate(block_rv.value(), ann_key); } @@ -306,7 +308,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.ScheduleTransformLayout", + .def("s_tir.schedule.ScheduleTransformLayout", [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform) { @@ -314,8 +316,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { static_cast(buffer_index_type), index_map, pad_value, assume_injective_transform); }) - .def_method("tir.schedule.ScheduleTransformBlockLayout", &ScheduleNode::TransformBlockLayout) - .def("tir.schedule.ScheduleSetAxisSeparator", + .def_method("s_tir.schedule.ScheduleTransformBlockLayout", + &ScheduleNode::TransformBlockLayout) + .def("s_tir.schedule.ScheduleSetAxisSeparator", [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type, const ffi::Array& axis_separators) { return self->SetAxisSeparator(block_rv, buffer_index, @@ -328,26 +331,27 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleDecomposePadding", &ScheduleNode::DecomposePadding) - .def_method("tir.schedule.SchedulePadEinsum", &ScheduleNode::PadEinsum); + .def_method("s_tir.schedule.ScheduleDecomposePadding", &ScheduleNode::DecomposePadding) + .def_method("s_tir.schedule.SchedulePadEinsum", &ScheduleNode::PadEinsum); } /******** (FFI) Buffer transformation ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("tir.schedule.ScheduleRollingBuffer", &ScheduleNode::RollingBuffer); + refl::GlobalDef().def_method("s_tir.schedule.ScheduleRollingBuffer", + &ScheduleNode::RollingBuffer); } /******** (FFI) Misc ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleEnterPostproc", &ScheduleNode::EnterPostproc) - .def_method("tir.schedule.ScheduleUnsafeHideBufferAccess", + .def_method("s_tir.schedule.ScheduleEnterPostproc", &ScheduleNode::EnterPostproc) + .def_method("s_tir.schedule.ScheduleUnsafeHideBufferAccess", &ScheduleNode::UnsafeHideBufferAccess); } /******** (FFI) Annotate buffer access ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.ScheduleAnnotateBufferAccess", + refl::GlobalDef().def("s_tir.schedule.ScheduleAnnotateBufferAccess", [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map) { return self->AnnotateBufferAccess( @@ -356,5 +360,5 @@ 
TVM_FFI_STATIC_INIT_BLOCK() { }); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/s_tir/schedule/state.cc similarity index 99% rename from src/tir/schedule/state.cc rename to src/s_tir/schedule/state.cc index 47845be9a516..e9fdeec1c445 100644 --- a/src/tir/schedule/state.cc +++ b/src/s_tir/schedule/state.cc @@ -21,7 +21,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; TVM_FFI_STATIC_INIT_BLOCK() { ScheduleStateNode::RegisterReflection(); } @@ -1019,19 +1020,19 @@ TVM_DLL ffi::Array GetCachedFlags(const ScheduleState& self, const StmtSRe TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.ScheduleState", + .def("s_tir.schedule.ScheduleState", [](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { return ScheduleState(mod, debug_mask, enable_check); }) - .def_method("tir.schedule.ScheduleStateGetSBlockScope", &ScheduleStateNode::GetSBlockScope) - .def_method("tir.schedule.ScheduleStateReplace", &ScheduleStateNode::Replace) - .def("tir.schedule.ScheduleStateGetSRef", + .def_method("s_tir.schedule.ScheduleStateGetSBlockScope", &ScheduleStateNode::GetSBlockScope) + .def_method("s_tir.schedule.ScheduleStateReplace", &ScheduleStateNode::Replace) + .def("s_tir.schedule.ScheduleStateGetSRef", [](ScheduleState self, Stmt stmt) -> ffi::Optional { auto it = self->stmt2ref.find(stmt.get()); return it != self->stmt2ref.end() ? it->second : ffi::Optional(std::nullopt); }) - .def("tir.schedule.ScheduleStateGetCachedFlags", GetCachedFlags); + .def("s_tir.schedule.ScheduleStateGetCachedFlags", GetCachedFlags); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/trace.cc b/src/s_tir/schedule/trace.cc similarity index 96% rename from src/tir/schedule/trace.cc rename to src/s_tir/schedule/trace.cc index ee9ae29f09c6..cf1e01b0f11a 100644 --- a/src/tir/schedule/trace.cc +++ b/src/s_tir/schedule/trace.cc @@ -21,7 +21,8 @@ #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; TVM_FFI_STATIC_INIT_BLOCK() { TraceNode::RegisterReflection(); } @@ -561,7 +562,7 @@ struct EnterPostprocTraits : public UnpackedInstTraits { } template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct ::tvm::s_tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); @@ -571,13 +572,13 @@ TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.Trace", + .def("s_tir.schedule.Trace", [](ffi::Optional> insts, ffi::Optional> decisions) { return Trace(insts.value_or(ffi::Array()), decisions.value_or({})); }) - .def_method("tir.schedule.TraceGetDecision", &TraceNode::GetDecision) - .def("tir.schedule.TraceAppend", + .def_method("s_tir.schedule.TraceGetDecision", &TraceNode::GetDecision) + .def("s_tir.schedule.TraceAppend", [](Trace self, Instruction inst, ffi::Optional decision) { if (decision.defined()) { return self->Append(inst, decision.value()); @@ -585,14 +586,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { return self->Append(inst); } }) - .def_method("tir.schedule.TracePop", &TraceNode::Pop) - .def_method("tir.schedule.TraceApplyToSchedule", &TraceNode::ApplyToSchedule) - .def_method("tir.schedule.TraceAsJSON", &TraceNode::AsJSON) - .def_method("tir.schedule.TraceAsPython", &TraceNode::AsPython) - .def_method("tir.schedule.TraceWithDecision", 
&TraceNode::WithDecision) - .def_method("tir.schedule.TraceSimplified", &TraceNode::Simplified) - .def("tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule); + .def_method("s_tir.schedule.TracePop", &TraceNode::Pop) + .def_method("s_tir.schedule.TraceApplyToSchedule", &TraceNode::ApplyToSchedule) + .def_method("s_tir.schedule.TraceAsJSON", &TraceNode::AsJSON) + .def_method("s_tir.schedule.TraceAsPython", &TraceNode::AsPython) + .def_method("s_tir.schedule.TraceWithDecision", &TraceNode::WithDecision) + .def_method("s_tir.schedule.TraceSimplified", &TraceNode::Simplified) + .def("s_tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc similarity index 98% rename from src/tir/schedule/traced_schedule.cc rename to src/s_tir/schedule/traced_schedule.cc index 178dbbeec8ec..f6e91ebf85b5 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -19,7 +19,8 @@ #include "./traced_schedule.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, @@ -57,7 +58,7 @@ ExprRV TracedScheduleNode::SampleCategorical(const ffi::Array& candidat const ffi::Array& probs, ffi::Optional decision) { ExprRV result = - CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); + CreateRV(::tvm::s_tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{}, @@ -72,8 +73,8 @@ ffi::Array TracedScheduleNode::SamplePerfectTile( ffi::Optional> decision) { // use None RV object to denotes auto-infer tile factors. 
ffi::Array results = - CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, - max_innermost_factor, &decision), + CreateRV(::tvm::s_tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision), /*convert_negone_to_none=*/true); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -87,7 +88,7 @@ ffi::Array TracedScheduleNode::SamplePerfectTile( ffi::Array TracedScheduleNode::SamplePartitionedTile( const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, ffi::Optional> decision) { - ffi::Array results = CreateRV(tir::SamplePartitionedTile( + ffi::Array results = CreateRV(::tvm::s_tir::SamplePartitionedTile( &this->rand_state_, this->GetSRef(loop_rv), n, partition_pos, innerpart_factor, &decision)); static const InstructionKind& kind = InstructionKind::Get("SamplePartitionedTile"); @@ -102,8 +103,8 @@ ffi::Array TracedScheduleNode::SamplePartitionedTile( LoopRV TracedScheduleNode::SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision) { - LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, - this->GetSRef(block_rv), &decision)); + LoopRV result = CreateRV(::tvm::s_tir::SampleComputeLocation( + this->state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -800,5 +801,5 @@ void TracedScheduleNode::AnnotateBufferAccess(const SBlockRV& block_rv, int buff /*outputs=*/{})); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.h b/src/s_tir/schedule/traced_schedule.h similarity index 97% rename from src/tir/schedule/traced_schedule.h rename to src/s_tir/schedule/traced_schedule.h index b3f08ed1f06a..fd0027ac8d91 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/s_tir/schedule/traced_schedule.h @@ -16,13 +16,14 @@ * specific language governing permissions and limitations * under the License. 
*/ -#ifndef TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ -#define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ +#ifndef TVM_S_TIR_SCHEDULE_TRACED_SCHEDULE_H_ +#define TVM_S_TIR_SCHEDULE_TRACED_SCHEDULE_H_ #include "./concrete_schedule.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; class TracedScheduleNode : public ConcreteScheduleNode { friend class Schedule; @@ -151,7 +152,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { BufferIndexType buffer_index_type, const IndexMap& index_map) final; }; -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ +#endif // TVM_S_TIR_SCHEDULE_TRACED_SCHEDULE_H_ diff --git a/src/tir/schedule/transform.cc b/src/s_tir/schedule/transform.cc similarity index 97% rename from src/tir/schedule/transform.cc rename to src/s_tir/schedule/transform.cc index 0127a288698b..a6cc76203d91 100644 --- a/src/tir/schedule/transform.cc +++ b/src/s_tir/schedule/transform.cc @@ -19,11 +19,12 @@ #include -#include "../transforms/ir_utils.h" +#include "../../tir/transforms/ir_utils.h" #include "./utils.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Annotation ********/ @@ -324,13 +325,14 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, ffi::GetRef(leaf_block), ffi::GetRef(scope_block)); } -ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::SBlockRV& block_rv, +ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv, const ffi::String& intrin_name, bool allow_padding) { - ffi::Optional opt_tensorize_info = + ffi::Optional opt_tensorize_info = GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding); if (!opt_tensorize_info) return std::nullopt; - const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); + const TensorizeInfoNode* info = opt_tensorize_info.value().get(); if (info->block_iter_paddings.defined()) { // We have to track whether each producer or consumer is padded. // To do so, we first record all the Block's. @@ -448,7 +450,7 @@ ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir:: TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.TileWithTensorIntrin", TileWithTensorIntrin); + refl::GlobalDef().def("s_tir.schedule.TileWithTensorIntrin", TileWithTensorIntrin); } /******** BlockBufferAccessSimplifier ********/ @@ -570,8 +572,8 @@ ffi::Optional NormalizePrimFunc(Schedule sch) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.NormalizePrimFunc", NormalizePrimFunc); + refl::GlobalDef().def("s_tir.schedule.NormalizePrimFunc", NormalizePrimFunc); } -} // namespace tir +} // namespace s_tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/s_tir/schedule/transform.h similarity index 93% rename from src/tir/schedule/transform.h rename to src/s_tir/schedule/transform.h index 23a1dd0486a6..6451d69354b4 100644 --- a/src/tir/schedule/transform.h +++ b/src/s_tir/schedule/transform.h @@ -16,21 +16,22 @@ * specific language governing permissions and limitations * under the License. 
*/ -#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ -#define TVM_TIR_SCHEDULE_TRANSFORM_H_ +#ifndef TVM_S_TIR_SCHEDULE_TRANSFORM_H_ +#define TVM_S_TIR_SCHEDULE_TRANSFORM_H_ -#include -#include +#include +#include #include #include #include #include "../../arith/ir_mutator_with_analyzer.h" -#include "../ir/functor_common.h" +#include "../../tir/ir/functor_common.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /******** Annotation ********/ @@ -217,10 +218,10 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ * \return LoopRV corresponding to the outermost loop of a * block tiled according to the given intrin, std::nullopt if a valid loop mapping is not found */ -ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, - const tir::SBlockRV& block_rv, - const ffi::String& intrin_name, - bool allow_padding = false); +ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv, + const ffi::String& intrin_name, + bool allow_padding = false); /******** SBlock mutation ********/ @@ -255,7 +256,7 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const BufferLoadNode* op) final; }; -} // namespace tir +} // namespace s_tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ +#endif // TVM_S_TIR_SCHEDULE_TRANSFORM_H_ diff --git a/src/tir/schedule/utils.h b/src/s_tir/schedule/utils.h similarity index 95% rename from src/tir/schedule/utils.h rename to src/s_tir/schedule/utils.h index 06752a09098e..a82978435d27 100644 --- a/src/tir/schedule/utils.h +++ b/src/s_tir/schedule/utils.h @@ -16,20 +16,20 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_TIR_SCHEDULE_UTILS_H_ -#define TVM_TIR_SCHEDULE_UTILS_H_ +#ifndef TVM_S_TIR_SCHEDULE_UTILS_H_ +#define TVM_S_TIR_SCHEDULE_UTILS_H_ #include #include #include #include +#include +#include +#include +#include #include #include #include -#include -#include -#include -#include #include #include @@ -50,7 +50,8 @@ #include "./transform.h" namespace tvm { -namespace tir { +namespace s_tir { +using namespace tvm::tir; /*! * \brief Convert an array of loop StmtSRefs to an array of loops * \param loop_srefs The loop StmtSRefs to be converted @@ -318,16 +319,17 @@ inline bool HasAnn(const StmtSRef& sref, const ffi::String& ann_key, bool ann_va * \note Before invoking this helper function, make sure that the block has only spatial and * reduction loop axes. */ -inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::SBlockRV& block_rv, - tir::LoopRV* fused_reduce_loop, +inline void ReorderAndFuseReductionLoops(const s_tir::Schedule& sch, + const s_tir::SBlockRV& block_rv, + s_tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) { - ffi::Array loops = sch->GetLoops(block_rv); + ffi::Array loops = sch->GetLoops(block_rv); ffi::Array loop_srefs; - for (const tir::LoopRV& loop_rv : loops) { + for (const s_tir::LoopRV& loop_rv : loops) { loop_srefs.push_back(sch->GetSRef(loop_rv)); } - ffi::Array new_order; + ffi::Array new_order; // Step 1. Add spatial loops. *num_spatial_loops = 0; for (size_t i = 0; i < loops.size(); ++i) { @@ -337,7 +339,7 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::SB } } // Step 2. Add reduction loops. 
-  ffi::Array reduction_loops;
+  ffi::Array reduction_loops;
   for (size_t i = 0; i < loops.size(); ++i) {
     if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
       new_order.push_back(loops[i]);
@@ -431,7 +433,7 @@ void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array<
  */
 int GetNumValidInstructions(const ffi::Array& insts, bool remove_postproc);
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
-#endif  // TVM_TIR_SCHEDULE_UTILS_H_
+#endif  // TVM_S_TIR_SCHEDULE_UTILS_H_
diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc
index 8b9f71f4d93f..1f8d0817d572 100644
--- a/src/s_tir/transform/compact_buffer_region.cc
+++ b/src/s_tir/transform/compact_buffer_region.cc
@@ -35,8 +35,8 @@
 #include "../../support/arena.h"
 #include "../../support/nd_int_set.h"
 #include "../../support/utils.h"
-#include "../../tir/schedule/utils.h"
 #include "../../tir/transforms/ir_utils.h"
+#include "../schedule/utils.h"
 namespace tvm {
 namespace s_tir {
diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc
index fbcece7ff2a7..3cd2f8b8ee77 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -29,8 +29,8 @@
 #include
 #include "../../support/utils.h"
-#include "../../tir/schedule/utils.h"
 #include "../../tir/transforms/ir_utils.h"
+#include "../schedule/utils.h"
 namespace tvm {
 namespace s_tir {
diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc
index 7338b0887970..9a907c0caae8 100644
--- a/src/s_tir/transform/lower_cross_thread_reduction.cc
+++ b/src/s_tir/transform/lower_cross_thread_reduction.cc
@@ -28,8 +28,8 @@
 #include "../../runtime/thread_storage_scope.h"
 #include "../../support/utils.h"
-#include "../../tir/schedule/analysis.h"
 #include "../../tir/transforms/ir_utils.h"
+#include "../schedule/analysis.h"
 namespace tvm {
 namespace s_tir {
@@ -102,7 +102,7 @@ bool IsDominantBlock(const SBlock& scope_block, const SBlock& block) {
  * \param scope_block The scope block of the input block
  * \param analyzer The analyzer
  * \return A boolean indicating whether the input block is a reduction block.
- * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * \note A similar check has been implemented in "src/s_tir/schedule/analysis.h", but that check is
 * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
 * check again.
 */
diff --git a/src/s_tir/transform/manifest_shared_memory_local_stage.cc b/src/s_tir/transform/manifest_shared_memory_local_stage.cc
index 4d8ae4471952..0d409d34d058 100644
--- a/src/s_tir/transform/manifest_shared_memory_local_stage.cc
+++ b/src/s_tir/transform/manifest_shared_memory_local_stage.cc
@@ -36,7 +36,7 @@
 #include
 #include "../../runtime/thread_storage_scope.h"
-#include "../../tir/schedule/transform.h"
+#include "../schedule/transform.h"
 #include "tvm/tir/stmt.h"
 namespace tvm {
diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc
index 98bb90156569..871f82d5e80d 100644
--- a/src/s_tir/transform/memhammer_lower_auto_copy.cc
+++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc
@@ -30,8 +30,8 @@
 #include
 #include "../../runtime/thread_storage_scope.h"
-#include "../../tir/schedule/utils.h"
 #include "../../tir/transforms/ir_utils.h"
+#include "../schedule/utils.h"
 #include "./memhammer_rewrite_rule.h"
 #include "tvm/tir/stmt.h"
diff --git a/src/s_tir/transform/memhammer_rewrite_rule.h b/src/s_tir/transform/memhammer_rewrite_rule.h
index 974db627c4a8..90662dc17538 100644
--- a/src/s_tir/transform/memhammer_rewrite_rule.h
+++ b/src/s_tir/transform/memhammer_rewrite_rule.h
@@ -29,7 +29,7 @@
 #include
-#include "../../tir/schedule/utils.h"
+#include "../schedule/utils.h"
 namespace tvm {
 namespace s_tir {
diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc
index 06deb7934ad0..2a4a23d707f5 100644
--- a/src/tir/analysis/oob_checker.cc
+++ b/src/tir/analysis/oob_checker.cc
@@ -25,7 +25,7 @@
 #include
 #include "../../arith/ir_visitor_with_analyzer.h"
-#include "../schedule/error.h"
+#include "../../s_tir/schedule/error.h"
 namespace tvm {
 namespace tir {
@@ -38,7 +38,7 @@ struct OOBLocation {
   arith::IntSet shape_bounds;
 };
-class OOBError : public ScheduleError {
+class OOBError : public s_tir::ScheduleError {
  public:
   OOBError(IRModule mod, std::vector locations) : mod_(mod), locations_(locations) {}
   ffi::String FastErrorString() const final { return "Out of bound memory access"; }
diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc
index 9b7442233de3..c12a14c13237 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -31,11 +31,11 @@ namespace transform {
  * \param max_thread_per_block The maximum number of threads per block.
  * \param max_threadblocks The maximum number of threadblocks.
  */
-void ThreadBind(tir::Schedule sch, const tir::SBlockRV& block, int64_t max_thread_per_block,
+void ThreadBind(s_tir::Schedule sch, const s_tir::SBlockRV& block, int64_t max_thread_per_block,
                 int64_t max_threadblocks = 256) {
   // fetch the loops
-  ffi::Array<tir::LoopRV> loops = sch->GetLoops(block);
-  for (const tir::LoopRV& loop : loops) {
+  ffi::Array<s_tir::LoopRV> loops = sch->GetLoops(block);
+  for (const s_tir::LoopRV& loop : loops) {
     // skip block if already scheduled
     if (sch->Get(loop)->thread_binding.defined()) {
       return;
@@ -47,7 +47,7 @@ void ThreadBind(tir::Schedule sch, const tir::SBlockRV& block, int64_t max_threa
   // so loops.size() == 0 && iters.size() == 1
   ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1));
 
-  ffi::Array<tir::LoopRV> data_parallel_loops;
+  ffi::Array<s_tir::LoopRV> data_parallel_loops;
   // only fuse data parallel loops
   for (size_t i = 0; i < loops.size(); ++i) {
     if (iters[i]->iter_type == tir::IterVarType::kDataPar) {
@@ -61,21 +61,21 @@ void ThreadBind(tir::Schedule sch, const tir::SBlockRV& block, int64_t max_threa
                        : sch->AddUnitLoop(loops[0]));
   }
   // fuse all data parallel loops
-  tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
+  s_tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
   int64_t product = std::numeric_limits<int64_t>::max();
   if (sch->Get(fused)->extent->IsInstance<IntImmNode>()) {
     product = sch->Get(fused)->extent.as<IntImmNode>()->value;
   }
   // schedule the fused loop
   if (product > max_thread_per_block * max_threadblocks) {
-    ffi::Array<tir::LoopRV> splits = sch->Split(
+    ffi::Array<s_tir::LoopRV> splits = sch->Split(
         fused,
         /*factors=*/{std::nullopt, Integer(max_threadblocks), Integer(max_thread_per_block)});
     sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]});
     sch->Bind(splits[1], "blockIdx.x");
     sch->Bind(splits[2], "threadIdx.x");
   } else {
-    ffi::Array<tir::LoopRV> splits = sch->Split(
+    ffi::Array<s_tir::LoopRV> splits = sch->Split(
         fused, /*factors=*/{std::nullopt, Integer(std::min(product, max_thread_per_block))});
     sch->Bind(splits[0], "blockIdx.x");
     sch->Bind(splits[1], "threadIdx.x");
@@ -123,8 +123,8 @@ bool IsScheduledOnGPU(const BaseFunc& func) {
 Pass DefaultGPUSchedule() {
   auto pass_func =  //
       [=](IRModule m, PassContext pc) {
-        tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0,
-                                                  tir::ScheduleErrorRenderLevel::kDetail);
+        s_tir::Schedule sch = s_tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0,
+                                                      s_tir::ScheduleErrorRenderLevel::kDetail);
         for (const auto& [gv, func] : m->functions) {
           if (func->IsInstance<PrimFuncNode>() && !func->HasNonzeroAttr(attr::kIsScheduled) &&
               IsScheduledOnGPU(func)) {
@@ -146,8 +146,8 @@ Pass DefaultGPUSchedule() {
             int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue();
 
             sch->WorkOn(gv->name_hint);
-            ffi::Array<tir::SBlockRV> blocks = meta_schedule::SBlockCollector::Collect(sch);
-            for (const tir::SBlockRV& block : blocks) {
+            ffi::Array<s_tir::SBlockRV> blocks = meta_schedule::SBlockCollector::Collect(sch);
+            for (const s_tir::SBlockRV& block : blocks) {
               auto childs = sch->GetChildBlocks(block);
               if (!childs.empty()) {
                 continue;
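
The scheduling calls above map one-to-one onto the Python side. A rough
equivalent of the large-extent branch of ThreadBind, written against
tvm.s_tir.Schedule (assuming the method names mirror tir.Schedule; the kernel
and the limits max_threadblocks=256 / max_thread_per_block=1024 are
illustrative, not taken from this patch):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def add_one(A: T.Buffer((1048576,), "float32"), B: T.Buffer((1048576,), "float32")):
        for i in range(1048576):
            with T.block("B"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + T.float32(1)

    sch = tvm.s_tir.Schedule(add_one)
    (i,) = sch.get_loops(sch.get_block("B"))
    fused = sch.fuse(i, preserve_unit_iters=False)
    # extent 2**20 > 256 * 1024, so split as {rest, blockIdx, threadIdx} and reorder
    rest, bx, tx = sch.split(fused, factors=[None, 256, 1024])
    sch.reorder(bx, tx, rest)
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")
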
diff --git a/tests/python/dlight/test_primitives.py b/tests/python/dlight/test_primitives.py
index 640a31789855..57213f410c2d 100644
--- a/tests/python/dlight/test_primitives.py
+++ b/tests/python/dlight/test_primitives.py
@@ -52,7 +52,7 @@ def main(p0: T.Buffer((), "int32"), T_stack: T.Buffer((T.int64(3),), "int32")):
 
 @tvm.testing.requires_cuda
 def test_normalize_primfunc_with_scalar():
     sch = tvm.s_tir.Schedule(main)
-    f_normalize_prim_func = tvm.get_global_func("tir.schedule.NormalizePrimFunc")
+    f_normalize_prim_func = tvm.get_global_func("s_tir.schedule.NormalizePrimFunc")
     assert f_normalize_prim_func(sch)
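
Downstream code that fetches the schedule packed functions by registration key
needs the same one-line change as the test above, since the keys move with the
namespace; a minimal sketch:

    import tvm
    # old key was "tir.schedule.NormalizePrimFunc"
    f = tvm.get_global_func("s_tir.schedule.NormalizePrimFunc")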