Skip to content

Commit b140a6f

Browse files
committed
Add cooldown interval logic for the profiling functionality.
1 parent eb49311 commit b140a6f

File tree

17 files changed

+395
-122
lines changed

17 files changed

+395
-122
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,6 @@ endif()
394394
if(USE_PROFILER)
395395
message(STATUS "Build with profiler...")
396396

397-
add_definitions(-DUSE_PROFILER=1)
398397
tvm_file_glob(GLOB RUNTIME_GRAPH_EXECUTOR_DEBUG_SRCS src/runtime/graph_executor/debug/*.cc)
399398
list(APPEND RUNTIME_SRCS ${RUNTIME_GRAPH_EXECUTOR_DEBUG_SRCS})
400399
set_source_files_properties(${RUNTIME_GRAPH_EXECUTOR_SRCS}

include/tvm/runtime/profiling.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,10 +554,16 @@ PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, i
554554
* minimum duration requirement of one `repeat`.
555555
* i.e., When the run time of one `repeat` falls below this time,
556556
* the `number` parameter will be automatically increased.
557-
* \param f_preproc The function to be executed before we excetute time evaluator.
557+
* \param cooldown_interval_ms The cooldown interval in milliseconds between the number of repeats
558+
* defined by `repeats_to_cooldown`.
559+
* \param repeats_to_cooldown The number of repeats before the
560+
* cooldown is activated.
561+
* \param f_preproc The function to be executed before we execute time
562+
* evaluator.
558563
* \return f_timer A timer function.
559564
*/
560565
PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms,
566+
int cooldown_interval_ms, int repeats_to_cooldown,
561567
PackedFunc f_preproc = nullptr);
562568

563569
} // namespace profiling

python/tvm/auto_scheduler/testing/tune_relay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def f_per_layer(rt_mod, dev, input_data):
239239
graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
240240
print("|graph_nodes| = ", len(graph_nodes))
241241
print("|graph_time| = ", len(graph_time))
242-
graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
242+
graph_nodes_time = {k: float(np.mean(v)) for k, v in zip(graph_nodes, graph_time)}
243243
for k, v in graph_nodes_time.items():
244244
print(f"{k} : {v:.3f}")
245245

python/tvm/contrib/debugger/debug_executor.py

Lines changed: 122 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,19 @@ def _run_per_layer(self):
222222
output_tensors.append(self._get_node_output(i, j))
223223
self.debug_datum.update_output_tensors(output_tensors)
224224

225-
def _run_debug(self):
225+
def _run_debug(self, number, repeat, min_repeat_ms, cooldown_interval_ms, repeats_to_cooldown):
226226
"""Execute the node specified with index will be executed.
227227
Each debug output will be copied to the buffer
228228
Time consumed for each execution will be set as debug output.
229229
"""
230230
# Get timing.
231-
self.debug_datum._time_list = [[float(t)] for t in self.run_individual(10, 1, 1)]
231+
self.debug_datum._time_list = self.run_individual(
232+
number=number,
233+
repeat=repeat,
234+
min_repeat_ms=min_repeat_ms,
235+
cooldown_interval_ms=cooldown_interval_ms,
236+
repeats_to_cooldown=repeats_to_cooldown,
237+
)
232238

233239
# Get outputs.
234240
self._run_per_layer()
@@ -259,31 +265,123 @@ def debug_get_output(self, node, out=None):
259265

260266
self._debug_get_output(node_index, out)
261267

262-
def run(self, **input_dict):
268+
# pylint: disable=arguments-differ
269+
def run(
270+
self,
271+
number=10,
272+
repeat=1,
273+
min_repeat_ms=1,
274+
cooldown_interval_ms=0,
275+
repeats_to_cooldown=1,
276+
**input_dict,
277+
):
263278
"""Run forward execution of the graph with debug
264279
265280
Parameters
266281
----------
282+
number: int, optional
283+
The number of times to run this function for taking average.
284+
We call these runs as one `repeat` of measurement.
285+
286+
repeat: int, optional
287+
The number of times to repeat the measurement.
288+
In total, the function will be invoked (1 + number x repeat) times,
289+
where the first one is warm up and will be discarded.
290+
The returned result contains `repeat` costs,
291+
each of which is an average of `number` costs.
292+
293+
min_repeat_ms: int, optional
294+
The minimum duration of one `repeat` in milliseconds.
295+
By default, one `repeat` contains `number` runs. If this parameter is set,
296+
the parameters `number` will be dynamically adjusted to meet the
297+
minimum duration requirement of one `repeat`.
298+
i.e., When the run time of one `repeat` falls below this time, the `number` parameter
299+
will be automatically increased.
300+
301+
cooldown_interval_ms: int, optional
302+
The cooldown interval in milliseconds between the number of repeats defined by
303+
`repeats_to_cooldown`.
304+
305+
repeats_to_cooldown: int, optional
306+
The number of repeats before the cooldown is activated.
307+
267308
input_dict : dict of str to NDArray
268309
List of input values to be feed to
269310
"""
270311
if input_dict:
271312
self.set_input(**input_dict)
272313

273314
# Step 1. Execute the graph
274-
self._run_debug()
315+
self._run_debug(
316+
number=number,
317+
repeat=repeat,
318+
min_repeat_ms=min_repeat_ms,
319+
cooldown_interval_ms=cooldown_interval_ms,
320+
repeats_to_cooldown=repeats_to_cooldown,
321+
)
275322
# Step 2. Dump the output tensors to the dump folder
276323
self.debug_datum.dump_output_tensor()
277324
# Step 3. Dump the Chrome trace to the dump folder
278325
self.debug_datum.dump_chrome_trace()
279326
# Step 4. Display the collected information
280327
self.debug_datum.display_debug_result()
281328

282-
def run_individual(self, number, repeat=1, min_repeat_ms=0):
283-
ret = self._run_individual(number, repeat, min_repeat_ms)
284-
return ret.strip(",").split(",") if ret else []
329+
def run_individual(
330+
self, number, repeat=1, min_repeat_ms=0, cooldown_interval_ms=0, repeats_to_cooldown=1
331+
):
332+
"""Run each operation in the graph and get the time per op for all ops.
333+
334+
number: int
335+
The number of times to run this function for taking average.
336+
We call these runs as one `repeat` of measurement.
285337
286-
def run_individual_node(self, index, number=10, repeat=1, min_repeat_ms=0):
338+
repeat: int, optional
339+
The number of times to repeat the measurement.
340+
In total, the function will be invoked (1 + number x repeat) times,
341+
where the first one is warm up and will be discarded.
342+
The returned result contains `repeat` costs,
343+
each of which is an average of `number` costs.
344+
345+
min_repeat_ms: int, optional
346+
The minimum duration of one `repeat` in milliseconds.
347+
By default, one `repeat` contains `number` runs. If this parameter is set,
348+
the parameters `number` will be dynamically adjusted to meet the
349+
minimum duration requirement of one `repeat`.
350+
i.e., When the run time of one `repeat` falls below this time, the `number` parameter
351+
will be automatically increased.
352+
353+
cooldown_interval_ms: int, optional
354+
The cooldown interval in milliseconds between the number of repeats defined by
355+
`repeats_to_cooldown`.
356+
357+
repeats_to_cooldown: int, optional
358+
The number of repeats before the cooldown is activated.
359+
360+
Returns
361+
-------
362+
A 2-dimensional array where the dimensions are: the index of the operation and
363+
the repeat of the measurement.
364+
"""
365+
ret = self._run_individual(
366+
number, repeat, min_repeat_ms, cooldown_interval_ms, repeats_to_cooldown
367+
)
368+
results = []
369+
for node_data in ret.strip(";").split(";"):
370+
results.append([])
371+
for repeat_data in node_data.strip(",").split(","):
372+
if repeat_data:
373+
results[-1].append(float(repeat_data))
374+
return results
375+
376+
def run_individual_node(
377+
self,
378+
index,
379+
number=10,
380+
repeat=1,
381+
min_repeat_ms=0,
382+
cooldown_interval_ms=0,
383+
repeats_to_cooldown=1,
384+
):
287385
"""Benchmark a single node in the serialized graph.
288386
289387
This does not do any data transfers and uses arrays already on the device.
@@ -304,27 +402,34 @@ def run_individual_node(self, index, number=10, repeat=1, min_repeat_ms=0):
304402
The returned result contains `repeat` costs,
305403
each of which is an average of `number` costs.
306404
307-
min_repeat_ms: int, optional
405+
min_repeat_ms : int, optional
308406
The minimum duration of one `repeat` in milliseconds.
309407
By default, one `repeat` contains `number` runs. If this parameter is set,
310408
the parameters `number` will be dynamically adjusted to meet the
311409
minimum duration requirement of one `repeat`.
312410
i.e., When the run time of one `repeat` falls below this time, the `number` parameter
313411
will be automatically increased.
314412
413+
cooldown_interval_ms: int, optional
414+
The cooldown interval in milliseconds between the number of repeats defined by
415+
`repeats_to_cooldown`.
416+
417+
repeats_to_cooldown: int, optional
418+
The number of repeats before the cooldown is activated.
419+
315420
Returns
316421
-------
317422
A module BenchmarkResult
318423
"""
319424
# Results are returned as serialized strings which we deserialize
320-
ret = self._run_individual_node(index, number, repeat, min_repeat_ms)
321-
answer = []
322-
for value in ret.split(","):
323-
if value.strip() == "":
324-
continue
325-
answer.append(float(value))
326-
327-
return BenchmarkResult(answer)
425+
ret = self._run_individual_node(
426+
index, number, repeat, min_repeat_ms, cooldown_interval_ms, repeats_to_cooldown
427+
)
428+
results = []
429+
for repeat_data in ret.replace(" ", "").strip(",").split(","):
430+
if repeat_data:
431+
results.append(float(repeat_data))
432+
return BenchmarkResult(results)
328433

329434
def profile(self, collectors=None, **input_dict):
330435
"""Run forward execution of the graph and collect overall and per-op

python/tvm/contrib/debugger/debug_result.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,10 @@ def get_graph_node_dtypes(self):
114114
def get_output_tensors(self):
115115
"""Get the output tensors of each operation in numpy format"""
116116
eid = 0
117-
order = 0
118117
output_tensors = {}
119-
for i, (node, time) in enumerate(zip(self._nodes_list, self._time_list)):
118+
for i, node in enumerate(self._nodes_list):
120119
num_outputs = self.get_graph_node_output_num(node)
121120
for j in range(num_outputs):
122-
order += time[0]
123121

124122
# the node name is not unique, so we need a consistent
125123
# indexing based on the list ordering in the nodes
@@ -157,7 +155,7 @@ def s_to_us(t):
157155
return t * 10**6
158156

159157
starting_times = np.zeros(len(self._time_list) + 1)
160-
starting_times[1:] = np.cumsum([times[0] for times in self._time_list])
158+
starting_times[1:] = np.cumsum([np.mean(times) for times in self._time_list])
161159

162160
def node_to_events(node, times, starting_time):
163161
return [
@@ -170,7 +168,7 @@ def node_to_events(node, times, starting_time):
170168
),
171169
ChromeTraceEvent(
172170
# Use start + duration instead of end to ensure precise timings.
173-
ts=s_to_us(times[0] + starting_time),
171+
ts=s_to_us(np.mean(times) + starting_time),
174172
tid=1,
175173
pid=1,
176174
ph="E",
@@ -205,12 +203,31 @@ def _dump_graph_json(self, graph):
205203

206204
def get_debug_result(self, sort_by_time=True):
207205
"""Return the debugger result"""
208-
header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Shape", "Inputs", "Outputs"]
209-
lines = ["---------", "---", "--------", "-------", "-----", "------", "-------"]
206+
header = [
207+
"Node Name",
208+
"Ops",
209+
"Time(us)",
210+
"Time(%)",
211+
"Shape",
212+
"Inputs",
213+
"Outputs",
214+
"Measurements(us)",
215+
]
216+
lines = [
217+
"---------",
218+
"---",
219+
"--------",
220+
"-------",
221+
"-----",
222+
"------",
223+
"-------",
224+
"----------------",
225+
]
210226
eid = 0
211227
data = []
212-
total_time = sum(time[0] for time in self._time_list)
228+
total_time = sum([np.mean(time) for time in self._time_list])
213229
for node, time in zip(self._nodes_list, self._time_list):
230+
time_mean = np.mean(time)
214231
num_outputs = self.get_graph_node_output_num(node)
215232
for j in range(num_outputs):
216233
op = node["op"]
@@ -219,11 +236,12 @@ def get_debug_result(self, sort_by_time=True):
219236
continue
220237
name = node["name"]
221238
shape = str(self._output_tensor_list[eid].shape)
222-
time_us = round(time[0] * 1e6, 3)
223-
time_percent = round(((time[0] / total_time) * 100), 3)
239+
time_us = round(time_mean * 1e6, 3)
240+
time_percent = round(((time_mean / total_time) * 100), 3)
224241
inputs = str(node["attrs"]["num_inputs"])
225242
outputs = str(node["attrs"]["num_outputs"])
226-
node_data = [name, op, time_us, time_percent, shape, inputs, outputs]
243+
measurements = str([round(repeat_data * 1e6, 3) for repeat_data in time])
244+
node_data = [name, op, time_us, time_percent, shape, inputs, outputs, measurements]
227245
data.append(node_data)
228246
eid += 1
229247

@@ -232,7 +250,7 @@ def get_debug_result(self, sort_by_time=True):
232250
data = sorted(data, key=lambda x: x[2], reverse=True)
233251
# Insert a row for total time at the end.
234252
rounded_total_time_us = round(total_time * 1e6, 3)
235-
data.append(["Total_time", "-", rounded_total_time_us, "-", "-", "-", "-", "-"])
253+
data.append(["Total_time", "-", rounded_total_time_us, "-", "-", "-", "-", "-", "-"])
236254

237255
fmt = ""
238256
for i, _ in enumerate(header):

python/tvm/contrib/graph_executor.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,8 @@ def benchmark(
356356
number=5,
357357
min_repeat_ms=None,
358358
end_to_end=False,
359+
cooldown_interval_ms=0,
360+
repeats_to_cooldown=1,
359361
**kwargs,
360362
):
361363
"""Calculate runtime of a function by repeatedly calling it.
@@ -395,7 +397,7 @@ def benchmark(
395397
`number` should be increased when the runtime of the function is small (less than a 1/10
396398
of a millisecond).
397399
398-
min_repeat_ms : Optional[float]
400+
min_repeat_ms : Optional[int]
399401
If set, the inner loop will be run until it takes longer than `min_repeat_ms`
400402
milliseconds. This can be used to ensure that the function is run enough to get an
401403
accurate measurement.
@@ -405,6 +407,13 @@ def benchmark(
405407
returned tensors in the total runtime. This will give accurate timings for end to end
406408
workloads.
407409
410+
cooldown_interval_ms: Optional[int]
411+
The cooldown interval in milliseconds between the number of repeats defined by
412+
`repeats_to_cooldown`.
413+
414+
repeats_to_cooldown: Optional[int]
415+
The number of repeats before the cooldown is activated.
416+
408417
kwargs : Dict[str, Object]
409418
Named arguments to the function. These are cached before running timing code, so that
410419
data transfer costs are not counted in the runtime.
@@ -432,5 +441,11 @@ def benchmark(
432441
if kwargs:
433442
self.set_input(**kwargs)
434443
return self.module.time_evaluator(
435-
func_name, device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms
444+
func_name,
445+
device,
446+
repeat=repeat,
447+
number=number,
448+
min_repeat_ms=min_repeat_ms,
449+
cooldown_interval_ms=cooldown_interval_ms,
450+
repeats_to_cooldown=repeats_to_cooldown,
436451
)()

python/tvm/meta_schedule/testing/tune_relay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def f_per_layer(rt_mod, dev, input_data):
204204
graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
205205
print("|graph_nodes| = ", len(graph_nodes))
206206
print("|graph_time| = ", len(graph_time))
207-
graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
207+
graph_nodes_time = {k: float(np.mean(v)) for k, v in zip(graph_nodes, graph_time)}
208208
for k, v in graph_nodes_time.items():
209209
print(f"{k} : {v:.3f}")
210210

0 commit comments

Comments
 (0)