66 * to you under the Apache License, Version 2.0 (the
77 * "License"); you may not use this file except in compliance
88 * with the License. You may obtain a copy of the License at
9- *
9+ *
1010 * http://www.apache.org/licenses/LICENSE-2.0
11- *
11+ *
1212 * Unless required by applicable law or agreed to in writing,
1313 * software distributed under the License is distributed on an
1414 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
2424#include < tvm/runtime/packed_func.h>
2525#include < tvm/runtime/registry.h>
2626#include < tvm/runtime/ndarray.h>
27+
2728#include < chrono>
29+ #include < sstream>
2830#include " ../graph_runtime.h"
2931
3032namespace tvm {
@@ -39,40 +41,23 @@ namespace runtime {
3941class GraphRuntimeDebug : public GraphRuntime {
4042 public:
4143 /* !
42- * \brief Run each operation and get the output.
43- * \param index The index of op which needs to be run.
44- * \return the elapsed time.
45- */
46- double DebugRun (size_t index) {
47- CHECK (index < op_execs_.size ());
48- TVMContext ctx = data_entry_[entry_id (index, 0 )]->ctx ;
49- auto tbegin = std::chrono::high_resolution_clock::now ();
50- if (op_execs_[index]) {
51- op_execs_[index]();
52- }
53- TVMSynchronize (ctx.device_type , ctx.device_id , nullptr );
54- auto tend = std::chrono::high_resolution_clock::now ();
55- double time = std::chrono::duration_cast<std::chrono::duration<double > >(
56- tend - tbegin).count ();
57- return time;
58- }
59-
60- /* !
61- * \brief Run each operation in the graph and print out the runtime per op.
44+ * \brief Run each operation in the graph and get the time per op for all ops.
6245 * \param number The number of times to run this function for taking average.
6346 * \param repeat The number of times to repeat the measurement.
64- In total, the function will be invoked (1 + number x repeat) times,
65- where the first one is warmed up and will be discarded in case
66- there is lazy initialization.
47+ * In total, the function will be invoked (1 + number x repeat) times,
48+ * where the first one is warmed up and will be discarded in case
49+ * there is lazy initialization.
6750 * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
68- By default, one `repeat` contains `number` runs. If this parameter is set,
69- the parameters `number` will be dynamically adjusted to meet the
70- minimum duration requirement of one `repeat`.
51+ * By default, one `repeat` contains `number` runs. If this parameter is set,
52+ * the parameter `number` will be dynamically adjusted to meet the
53+ * minimum duration requirement of one `repeat`.
54+ * \return Comma-separated string containing the elapsed time per op for the last
55+ * iteration only, because returning a long string over rpc can be expensive.
7156 */
72- void RunIndividual (int number, int repeat, int min_repeat_ms) {
57+ std::string RunIndividual (int number, int repeat, int min_repeat_ms) {
7358 // warmup run
7459 GraphRuntime::Run ();
75-
60+ std::ostringstream os;
7661 std::vector<double > time_per_op (op_execs_.size (), 0 );
7762 for (int i = 0 ; i < repeat; ++i) {
7863 std::chrono::time_point<
@@ -96,7 +81,7 @@ class GraphRuntimeDebug : public GraphRuntime {
9681 auto op_tend = std::chrono::high_resolution_clock::now ();
9782 double op_duration = std::chrono::duration_cast<
9883 std::chrono::duration<double > >(op_tend - op_tbegin).count ();
99- time_per_op[index] += op_duration * 1000 ; // ms
84+ time_per_op[index] += op_duration * 1e6 ; // us
10085 }
10186 }
10287 }
@@ -105,16 +90,20 @@ class GraphRuntimeDebug : public GraphRuntime {
10590 (tend - tbegin).count () * 1000 ;
10691 } while (duration_ms < min_repeat_ms);
10792
108- LOG (INFO) << " Repeat : " << i;
93+ LOG (INFO) << " Iteration : " << i;
10994 int op = 0 ;
11095 for (size_t index = 0 ; index < time_per_op.size (); index++) {
11196 if (op_execs_[index]) {
11297 time_per_op[index] /= number;
11398 LOG (INFO) << " Op #" << op++ << " " << GetNodeName (index) << " : "
114- << time_per_op[index] << " ms /iter" ;
99+ << time_per_op[index] << " us /iter" ;
115100 }
116101 }
117102 }
103+ for (size_t index = 0 ; index < time_per_op.size (); index++) {
104+ os << time_per_op[index] << " ," ;
105+ }
106+ return os.str ();
118107 }
119108
120109 /* !
@@ -182,11 +171,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(
182171 const std::string& name,
183172 const std::shared_ptr<ModuleNode>& sptr_to_self) {
184173 // return member functions during query.
185- if (name == " debug_run" ) {
186- return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
187- *rv = this ->DebugRun (static_cast <size_t >(args[0 ].operator int64_t ()));
188- });
189- } else if (name == " get_output_by_layer" ) {
174+ if (name == " get_output_by_layer" ) {
190175 return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
191176 *rv = this ->GetOutputByLayer (args[0 ], args[1 ]);
192177 });
@@ -206,7 +191,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(
206191 CHECK_GT (number, 0 );
207192 CHECK_GT (repeat, 0 );
208193 CHECK_GE (min_repeat_ms, 0 );
209- this ->RunIndividual (number, repeat, min_repeat_ms);
194+ *rv = this ->RunIndividual (number, repeat, min_repeat_ms);
210195 });
211196 } else {
212197 return GraphRuntime::GetFunction (name, sptr_to_self);
0 commit comments