Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/tvm/meta_schedule/apply_history_best.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
#ifndef TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
#define TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>

namespace tvm {
Expand All @@ -36,7 +43,7 @@ class ApplyHistoryBestNode : public runtime::Object {
/*! \brief The logging function to be used */
PackedFunc logging_func;

void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); }
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("database", &database); }
/*!
* \brief Query the best entry from the database
* \param task_name The name of the task to be queried
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/meta_schedule/arg_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
#define TVM_META_SCHEDULE_ARG_INFO_H_

#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/function.h>

namespace tvm {
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
#define TVM_META_SCHEDULE_BUILDER_H_

#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>

namespace tvm {
Expand Down
34 changes: 13 additions & 21 deletions include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@
#ifndef TVM_META_SCHEDULE_COST_MODEL_H_
#define TVM_META_SCHEDULE_COST_MODEL_H_

#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/schedule/schedule.h>

#include <vector>

Expand Down Expand Up @@ -126,28 +134,12 @@ class PyCostModelNode : public CostModelNode {
// `f_as_string` is not visited
}

void Load(const String& path) {
ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!";
f_load(path);
}

void Save(const String& path) {
ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
f_save(path);
}
void Load(const String& path);
void Save(const String& path);
void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
const Array<RunnerResult>& results) {
ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
f_update(context, candidates, results);
}

const Array<RunnerResult>& results);
std::vector<double> Predict(const TuneContext& context,
const Array<MeasureCandidate>& candidates) {
ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!";
std::vector<double> result(candidates.size(), 0.0);
f_predict(context, candidates, result.data());
return result;
}
const Array<MeasureCandidate>& candidates);

static constexpr const char* _type_key = "meta_schedule.PyCostModel";
TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode);
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
#ifndef TVM_META_SCHEDULE_DATABASE_H_
#define TVM_META_SCHEDULE_DATABASE_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/trace.h>

Expand Down
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/extracted_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
#ifndef TVM_META_SCHEDULE_EXTRACTED_TASK_H_
#define TVM_META_SCHEDULE_EXTRACTED_TASK_H_

#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

namespace tvm {
Expand All @@ -38,7 +43,7 @@ class ExtractedTaskNode : public runtime::Object {
/*! \brief Weight of the task */
int weight;

void VisitAttrs(AttrVisitor* v) {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("task_name", &task_name);
v->Visit("mod", &mod);
v->Visit("target", &target);
Expand Down
13 changes: 8 additions & 5 deletions include/tvm/meta_schedule/feature_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_

#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -76,10 +82,7 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
}

Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
const Array<MeasureCandidate>& candidates) {
ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!";
return f_extract_from(context, candidates);
}
const Array<MeasureCandidate>& candidates) final;

static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor";
TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode);
Expand Down
11 changes: 7 additions & 4 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -94,10 +100,7 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results);
}
const Array<RunnerResult>& results);

static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
Expand Down
67 changes: 67 additions & 0 deletions include/tvm/meta_schedule/measure_candidate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_
#define TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_

#include <tvm/meta_schedule/arg_info.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

/*! \brief The schedule (with input shapes) to be measured. */
class MeasureCandidateNode : public runtime::Object {
public:
/*! \brief The schedule for measurement. */
tir::Schedule sch;
/*! \brief The argument information, e.g., (shape, dtype) for tensors. */
Array<ArgInfo> args_info;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("sch", &sch);
v->Visit("args_info", &args_info);
}

static constexpr const char* _type_key = "meta_schedule.MeasureCandidate";
TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object);
};

/*!
* \brief Managed reference to MeasureCandidateNode.
* \sa MeasureCandidateNode
*/
class MeasureCandidate : public runtime::ObjectRef {
public:
/*!
* \brief Constructor of MeasureCandidate.
* \param sch The schedule for measurement.
* \param args_info The argument information, e.g., (shape, dtype) for tensors.
*/
TVM_DLL MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_
18 changes: 8 additions & 10 deletions include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
#ifndef TVM_META_SCHEDULE_MUTATOR_H_
#define TVM_META_SCHEDULE_MUTATOR_H_

#include <tvm/node/reflection.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -89,17 +95,9 @@ class PyMutatorNode : public MutatorNode {
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyMutator's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

void InitializeWithTuneContext(const TuneContext& context) final;
Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final {
ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
return this->f_apply(trace, *rand_state);
}
support::LinearCongruentialEngine::TRandState* rand_state) final;

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
Expand Down
15 changes: 5 additions & 10 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
#ifndef TVM_META_SCHEDULE_POSTPROC_H_
#define TVM_META_SCHEDULE_POSTPROC_H_

#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
Expand Down Expand Up @@ -88,16 +91,8 @@ class PyPostprocNode : public PostprocNode {
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyPostproc's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

bool Apply(const tir::Schedule& sch) final {
ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
return this->f_apply(sch);
}
void InitializeWithTuneContext(const TuneContext& context) final;
bool Apply(const tir::Schedule& sch) final;

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@

#include <tvm/ir/expr.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

namespace tvm {
namespace meta_schedule {
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_

#include <tvm/ir/expr.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
Expand Down Expand Up @@ -90,16 +98,8 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyScheduleRule's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final {
ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!";
return this->f_apply(sch, block);
}
void InitializeWithTuneContext(const TuneContext& context) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
Expand Down
Loading