Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

#include <vector>

Expand Down
20 changes: 10 additions & 10 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
#include <tvm/ir/module.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/runtime/object.h>
#include <tvm/s_tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/trace.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

#include <filesystem>
#include <memory>
Expand Down Expand Up @@ -114,7 +114,7 @@ class MeasureCandidate;
class TuningRecordNode : public runtime::Object {
public:
/*! \brief The trace tuned. */
tir::Trace trace;
s_tir::Trace trace;
/*! \brief The workload. */
Workload workload{ffi::UnsafeInit()};
/*! \brief The profiling result in seconds. */
Expand Down Expand Up @@ -166,7 +166,7 @@ class TuningRecord : public runtime::ObjectRef {
\param target The target of the tuning record.
\param args_info The argument information of the tuning record.
*/
TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
TVM_DLL explicit TuningRecord(s_tir::Trace trace, Workload workload,
ffi::Optional<ffi::Array<FloatImm>> run_secs,
ffi::Optional<Target> target,
ffi::Optional<ffi::Array<ArgInfo>> args_info);
Expand Down Expand Up @@ -251,8 +251,8 @@ class DatabaseNode : public runtime::Object {
* \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; std::nullopt if not found.
*/
virtual ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const ffi::String& workload_name);
virtual ffi::Optional<s_tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const ffi::String& workload_name);
/*!
* \brief Query the best IRModule of the given workload from the database.
* \param mod The IRModule to be searched for.
Expand Down Expand Up @@ -343,7 +343,7 @@ class PyDatabaseNode : public DatabaseNode {
* \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; std::nullopt if not found.
*/
using FQuerySchedule = ffi::TypedFunction<ffi::Optional<tir::Schedule>(
using FQuerySchedule = ffi::TypedFunction<ffi::Optional<s_tir::Schedule>(
const IRModule&, const Target&, const ffi::String&)>;
/*!
* \brief The function type of `QueryIRModule` method.
Expand Down Expand Up @@ -432,8 +432,8 @@ class PyDatabaseNode : public DatabaseNode {
}
}

ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const ffi::String& workload_name) final {
ffi::Optional<s_tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const ffi::String& workload_name) final {
if (f_query_schedule == nullptr) {
return DatabaseNode::QuerySchedule(mod, target, workload_name);
} else {
Expand Down Expand Up @@ -483,7 +483,7 @@ class Database : public runtime::ObjectRef {
* and returns a boolean indicating if the schedule is successful.
* \param mod_eq_name A string to specify the module equality testing and hashing method.
*/
TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(tir::Schedule)> schedule_fn,
TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(s_tir::Schedule)> schedule_fn,
ffi::String mod_eq_name = "structural");
/*!
* \brief Create a default database that uses JSON file for tuning records.
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/measure_candidate.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {
Expand All @@ -33,7 +33,7 @@ namespace meta_schedule {
class MeasureCandidateNode : public runtime::Object {
public:
/*! \brief The schedule for measurement. */
tir::Schedule sch;
s_tir::Schedule sch;
/*! \brief The argument information, e.g., (shape, dtype) for tensors. */
ffi::Array<ArgInfo> args_info;

Expand All @@ -57,7 +57,7 @@ class MeasureCandidate : public runtime::ObjectRef {
* \param sch The schedule for measurement.
* \param args_info The argument information, e.g., (shape, dtype) for tensors.
*/
TVM_DLL MeasureCandidate(tir::Schedule sch, ffi::Array<ArgInfo> args_info);
TVM_DLL MeasureCandidate(s_tir::Schedule sch, ffi::Array<ArgInfo> args_info);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(MeasureCandidate, ObjectRef, MeasureCandidateNode);
};

Expand Down
16 changes: 8 additions & 8 deletions include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/object.h>
#include <tvm/s_tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/trace.h>
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -58,8 +58,8 @@ class MutatorNode : public runtime::Object {
* \param rand_state The random state for mutation.
* \return None if mutator failed, otherwise return the mutated trace.
*/
virtual ffi::Optional<tir::Trace> Apply(
const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0;
virtual ffi::Optional<s_tir::Trace> Apply(
const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0;

/*!
* \brief Clone the mutator.
Expand Down Expand Up @@ -87,8 +87,8 @@ class Mutator : public runtime::ObjectRef {
* \param trace The given trace for mutation.
* \return None if mutator failed, otherwise return the mutated trace.
*/
using FApply = ffi::TypedFunction<ffi::Optional<tir::Trace>(
const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
using FApply = ffi::TypedFunction<ffi::Optional<s_tir::Trace>(
const s_tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
/*!
* \brief Clone the mutator.
* \return The cloned mutator.
Expand Down Expand Up @@ -168,8 +168,8 @@ class PyMutatorNode : public MutatorNode {
}

void InitializeWithTuneContext(const TuneContext& context) final;
ffi::Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final;
ffi::Optional<s_tir::Trace> Apply(
const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) final;
Mutator Clone() const final;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyMutator", PyMutatorNode, MutatorNode);
};
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -56,7 +56,7 @@ class PostprocNode : public runtime::Object {
* \param sch The schedule to be post processed.
* \return Whether the postprocessor was successfully applied.
*/
virtual bool Apply(const tir::Schedule& sch) = 0;
virtual bool Apply(const s_tir::Schedule& sch) = 0;

/*!
* \brief Clone the postprocessor.
Expand Down Expand Up @@ -84,7 +84,7 @@ class Postproc : public runtime::ObjectRef {
* \param sch The schedule to be post processed.
* \return Whether the postprocessor was successfully applied.
*/
using FApply = ffi::TypedFunction<bool(const tir::Schedule&)>;
using FApply = ffi::TypedFunction<bool(const s_tir::Schedule&)>;
/*!
* \brief Clone the postprocessor.
* \return The cloned postprocessor.
Expand Down Expand Up @@ -205,7 +205,7 @@ class PyPostprocNode : public PostprocNode {
}

void InitializeWithTuneContext(const TuneContext& context) final;
bool Apply(const tir::Schedule& sch) final;
bool Apply(const s_tir::Schedule& sch) final;
Postproc Clone() const final;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyPostproc", PyPostprocNode, PostprocNode);
};
Expand Down
17 changes: 9 additions & 8 deletions include/tvm/meta_schedule/schedule/cuda/thread_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#ifndef TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
#define TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_

#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

#include <algorithm>
#include <limits>
Expand All @@ -35,8 +35,8 @@ namespace meta_schedule {
* \param thread_extents The candidate thread extents.
* \return A sampler that returns a random thread extent.
*/
std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
ffi::Array<Integer> thread_extents);
std::function<s_tir::ExprRV(int64_t)> MakeFactorSampler(s_tir::Schedule sch,
ffi::Array<Integer> thread_extents);

/*!
* \brief Bind blockIdx.x and threadIdx.x to the given loop
Expand All @@ -47,9 +47,10 @@ std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
* \param get_factor A function that returns the tiling factor.
* \return The binded loops in the order of blockIdx.x, threadIdx.x, and the rest.
*/
ffi::Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, //
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<tir::ExprRV(int64_t)> get_factor = nullptr);
ffi::Array<s_tir::LoopRV> BindSpatialLoop(
s_tir::Schedule sch, s_tir::LoopRV loop, //
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<s_tir::ExprRV(int64_t)> get_factor = nullptr);

/*!
* \brief Bind the given block if it is not bound to blockIdx or threadIdx.
Expand All @@ -59,9 +60,9 @@ ffi::Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, //
* \param max_threads_per_block The maximum number of threads allowed.
* \param get_factor A function that returns the tiling factor.
*/
void BindBlockThreadIdx(tir::Schedule sch, tir::SBlockRV block, //
void BindBlockThreadIdx(s_tir::Schedule sch, s_tir::SBlockRV block, //
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<tir::ExprRV(int64_t max_extent)> get_factor = nullptr);
std::function<s_tir::ExprRV(int64_t max_extent)> get_factor = nullptr);

} // namespace meta_schedule
} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/schedule/generic/winograd.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#ifndef TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_
#define TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_

#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {
Expand All @@ -29,7 +29,7 @@ namespace meta_schedule {
* If there is a constant winograd transform matrix, inline it.
* \return The only producer block.
*/
tir::SBlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::SBlockRV block);
s_tir::SBlockRV GetWinogradProducerAndInlineConst(s_tir::Schedule sch, s_tir::SBlockRV block);

} // namespace meta_schedule
} // namespace tvm
Expand Down
11 changes: 6 additions & 5 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -60,7 +60,8 @@ class ScheduleRuleNode : public runtime::Object {
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
virtual ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::SBlockRV& block) = 0;
virtual ffi::Array<s_tir::Schedule> Apply(const s_tir::Schedule& sch,
const s_tir::SBlockRV& block) = 0;

/*!
* \brief Deep clone the schedule rule.
Expand Down Expand Up @@ -89,8 +90,8 @@ class ScheduleRule : public runtime::ObjectRef {
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
using FApply =
ffi::TypedFunction<ffi::Array<tir::Schedule>(const tir::Schedule&, const tir::SBlockRV&)>;
using FApply = ffi::TypedFunction<ffi::Array<s_tir::Schedule>(const s_tir::Schedule&,
const s_tir::SBlockRV&)>;
/*!
* \brief Get the schedule rule as string with name.
* \return The string of the schedule rule.
Expand Down Expand Up @@ -343,7 +344,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
}

void InitializeWithTuneContext(const TuneContext& context) final;
ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::SBlockRV& block) final;
ffi::Array<s_tir::Schedule> Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block) final;
ScheduleRule Clone() const final;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyScheduleRule", PyScheduleRuleNode,
ScheduleRuleNode);
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -98,7 +98,7 @@ class SearchStrategyNode : public runtime::Object {
* and reset the search strategy.
*/
virtual void PreTuning(int max_trials, int num_trials_per_iter,
const ffi::Array<tir::Schedule>& design_spaces,
const ffi::Array<s_tir::Schedule>& design_spaces,
const ffi::Optional<Database>& database,
const ffi::Optional<CostModel>& cost_model) = 0;

Expand Down Expand Up @@ -148,7 +148,7 @@ class SearchStrategy : public runtime::ObjectRef {
* \brief The function type of `PreTuning` method.
*/
using FPreTuning = ffi::TypedFunction<void(
int max_trials, int num_trials_per_iter, const ffi::Array<tir::Schedule>&,
int max_trials, int num_trials_per_iter, const ffi::Array<s_tir::Schedule>&,
const ffi::Optional<Database>&, const ffi::Optional<CostModel>&)>;
/*! \brief The function type of `PostTuning` method. */
using FPostTuning = ffi::TypedFunction<void()>;
Expand Down Expand Up @@ -255,7 +255,7 @@ class PySearchStrategyNode : public SearchStrategyNode {

void InitializeWithTuneContext(const TuneContext& context) final;
void PreTuning(int max_trials, int num_trials_per_iter,
const ffi::Array<tir::Schedule>& design_spaces,
const ffi::Array<s_tir::Schedule>& design_spaces,
const ffi::Optional<Database>& database,
const ffi::Optional<CostModel>& cost_model) final;
void PostTuning() final;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#include <tvm/meta_schedule/postproc.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/runtime/object.h>
#include <tvm/s_tir/schedule/schedule.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -105,7 +105,7 @@ class SpaceGeneratorNode : public runtime::Object {
* \param mod The module used for design space generation.
* \return The generated design spaces, i.e., schedules.
*/
virtual ffi::Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;
virtual ffi::Array<s_tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;

/*!
* \brief Clone the space generator.
Expand Down Expand Up @@ -140,7 +140,7 @@ class SpaceGenerator : public runtime::ObjectRef {
* \param mod The module used for design space generation.
* \return The generated design spaces, i.e., schedules.
*/
using FGenerateDesignSpace = ffi::TypedFunction<ffi::Array<tir::Schedule>(const IRModule&)>;
using FGenerateDesignSpace = ffi::TypedFunction<ffi::Array<s_tir::Schedule>(const IRModule&)>;
/*!
* \brief The function type of `Clone` method.
* \return The cloned space generator.
Expand Down Expand Up @@ -232,7 +232,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
}

void InitializeWithTuneContext(const TuneContext& context) final;
ffi::Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
ffi::Array<s_tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
SpaceGenerator Clone() const final;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PySpaceGenerator", PySpaceGeneratorNode,
SpaceGeneratorNode);
Expand Down
Loading
Loading