Skip to content

Commit d2a7f93

Browse files
author
Tristan Konolige
authored
[ROOFLINE] Calculate roofline from existing TIR PrimFunc (#11238)
Refactor roofline_analysis to use a pass instrument to save TIR code from compilation for feature extraction. This should support different compilation pipelines and avoids recompiling the module twice.
1 parent 9e404f0 commit d2a7f93

File tree

2 files changed

+146
-48
lines changed

2 files changed

+146
-48
lines changed

python/tvm/utils/roofline.py

Lines changed: 145 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from ..target import Target
2323
from ..runtime import profiler_vm, profiling, Device, num_threads
2424
from ..script import tir as T
25+
from ..ir.instrument import pass_instrument
26+
from ..ir.expr import GlobalVar
2527

2628

2729
def _create_args(mod: IRModule, dev: Device, func_name: str = "main"):
@@ -36,16 +38,6 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main"):
3638
return args
3739

3840

39-
def _estimated_features(mod: IRModule, params: Dict[str, nd.NDArray], target: Target):
40-
comp = relay.vm.VMCompiler()
41-
mod, params = comp.optimize(mod, params=params, target=target)
42-
return {
43-
prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
44-
for name, prim in mod.functions.items()
45-
if isinstance(prim, tir.PrimFunc)
46-
}
47-
48-
4941
def _detect_vec_width_registers(
5042
target: Target, vec_width: Optional[int], num_vector_registers: Optional[int]
5143
):
@@ -226,60 +218,98 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
226218
return a.numpy().size * 4 / times.min # 4 bytes per float32
227219

228220

229-
def roofline_analysis(
230-
mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
221+
@pass_instrument
222+
class SaveLoweredTIR:
223+
"""Save TIR functions from right before final lowering. Right now this
224+
means right before tir.MakePackedAPI."""
225+
226+
def __init__(self):
227+
self.functions = {}
228+
self.done = False
229+
230+
def run_after_pass(self, mod, info):
231+
if not self.done:
232+
if info.name == "tir.MakePackedAPI":
233+
self.done = True
234+
else:
235+
for v, func in mod.functions.items():
236+
self.functions[v] = func
237+
238+
239+
def roofline_from_existing(
240+
report: profiling.Report,
241+
tir_functions: Dict[GlobalVar, tir.PrimFunc],
242+
target: Target,
243+
dev: Device,
231244
) -> profiling.Report:
232-
"""
233-
Create a profiling report that contains roofline and other estimated
234-
statistics from running a module on the VM.
245+
"""Add roofline and other estimated statistics to an existing profiling report.
235246
236-
These statistics are calculated by analyzing the lowered TIR of each
237-
operator, so they are estimates of the true values. The statistics are:
238-
- Bound: Is the operator memory or compute bound. This is computed by
239-
assuming that the operator could perfectly cache all loads -- each byte
240-
of memory is only loaded once.
241-
- Percent of Theoretical Optimal: What percent of theoretical optimal for
242-
the bound. i.e. percent of peak memory bandwidth if memory bound,
243-
percent of peak FLOP/s if compute bound.
244-
- Loaded Bytes: estimation of the number of bytes loaded from main memory.
245-
- Estimated Flops: estimated number of floating point operations.
246-
- Arithmetic Intensity: ratio of FLOPs per byte of data.
247-
- FLOP/s: floating point operations per second.
248-
- Bandwidth: Number of bytes loaded per second.
247+
:py:func:`roofline_analysis` should always be used instead of this function
248+
unless you need a custom compilation pipeline.
249249
250-
Parameters
251-
----------
252-
mod : IRModule
253-
Uncompiled input module>
250+
Calculating roofline statistics requires features extracted from the TIR
251+
functions in addition to per-operator runtime information (`report`) of the
252+
same TIR features. The features and TIR functions are not included with the
253+
compiled library used to generate the per-operator runtime. It is essential
254+
that the per-operator information comes from the exact same compilation
255+
pipeline as the TIR functions.
254256
255-
params : Dict[str, nd.NDArray]
256257
257-
target : Union[str, Target]
258-
Target to run on.
258+
Example
259+
-------
260+
261+
.. code:: python
262+
263+
import tvm
264+
import tvm.relay
265+
266+
mod, params = tvm.relay.testing.mlp.get_workload()
267+
268+
# it is recommended to use SaveLoweredTIR to get out the tir primfuncs
269+
save_tir = tvm.utils.roofline.SaveLoweredTIR()
270+
with tvm.transform.PassContext(opt_level=3, instruments=[save_tir]):
271+
lib = relay.vm.compile(mod, params=params, target=target)
272+
273+
vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
274+
report = vmexec.profile(*inputs)
275+
276+
roofline_report = roofline_from_existing(report, save_tir.functions, target, dev)
259277
278+
279+
Parameters
280+
----------
281+
report : Report
282+
Existing profiling report from :py:meth:`VirtualMachineProfiler.profile`.
283+
tir_functions : Dict[GlobalVar, PrimFunc]
284+
TIR primfuncs from the module run to generate `report`. It is necessary
285+
that these functions come before the `tir.MakePackedAPI` pass and are
286+
compatible with auto_scheduler featurization.
287+
:py:class:`SaveLoweredTIR` is the recommended way to collect these
288+
functions.
289+
target : Target
290+
TVM target that `report` was generated with.
260291
dev : Device
261-
Device to run on.
292+
Device that `report` was generated with.
262293
263294
Returns
264295
-------
265-
266-
report : profiling.Report
267-
Profiling report which includes the estimated statistics.
296+
profiling.Report
297+
New profiling report that includes all information from `report`
298+
along with additional roofline metrics. See
299+
:py:func:`roofline_analysis` for more information on which metrics
300+
are included.
268301
"""
269-
if isinstance(target, str):
270-
target = Target(target)
271302
peak_bandwidth = estimate_peak_bandwidth(target, dev)
272303
peak_flops = estimate_peak_fma_flops(target, dev)
273304

274305
ridge_point = peak_flops / peak_bandwidth
275306

276-
all_features = _estimated_features(mod, params, target)
277-
278-
lib = relay.vm.compile(mod, params=params, target=target)
279-
vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
307+
all_features = {
308+
prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
309+
for name, prim in tir_functions.items()
310+
if isinstance(prim, tir.PrimFunc) and "hash" in prim.attrs.keys()
311+
}
280312

281-
args = _create_args(mod, dev)
282-
report = vmexec.profile(*args)
283313
new_calls = []
284314
for call in report.calls:
285315
if "Hash" in call.keys():
@@ -313,3 +343,71 @@ def roofline_analysis(
313343
else:
314344
new_calls.append(call)
315345
return profiling.Report(new_calls, report.device_metrics)
346+
347+
348+
def roofline_analysis(
349+
mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
350+
) -> profiling.Report:
351+
"""
352+
Create a profiling report that contains roofline and other estimated
353+
statistics from running a module on the VM.
354+
355+
The roofline model measures how close an operator gets to the best possible
356+
memory bandwidth or FLOP/s depending on whether it is memory or compute
357+
bound. This computation uses the runtime of the operator along with two
358+
numbers extracted from the TIR code: bytes of memory touched and number of
359+
floating point operations.
360+
361+
These statistics are calculated by analyzing the lowered TIR of each
362+
operator, so they are estimates of the true values. The statistics are:
363+
- Bound: Is the operator memory or compute bound. This is computed by
364+
assuming that the operator could perfectly cache all loads -- each byte
365+
of memory is only loaded once.
366+
- Percent of Theoretical Optimal: What percent of theoretical optimal for
367+
the bound. i.e. percent of peak memory bandwidth if memory bound,
368+
percent of peak FLOP/s if compute bound.
369+
- Loaded Bytes: estimation of the number of bytes loaded from main memory.
370+
- Estimated Flops: estimated number of floating point operations.
371+
- Arithmetic Intensity: ratio of FLOPs per byte of data.
372+
- FLOP/s: floating point operations per second.
373+
- Bandwidth: Number of bytes loaded per second.
374+
375+
Parameters
376+
----------
377+
mod : IRModule
378+
Uncompiled input module.
379+
380+
params : Dict[str, nd.NDArray]
381+
382+
target : Union[str, Target]
383+
Target to run on.
384+
385+
dev : Device
386+
Device to run on.
387+
388+
Returns
389+
-------
390+
391+
report : profiling.Report
392+
Profiling report which includes the estimated statistics.
393+
"""
394+
if isinstance(target, str):
395+
target = Target(target)
396+
397+
save_tir = SaveLoweredTIR()
398+
# copy existing context but add our instrument
399+
pass_ctx = transform.PassContext.current()
400+
with transform.PassContext(
401+
opt_level=pass_ctx.opt_level,
402+
required_pass=pass_ctx.required_pass,
403+
disabled_pass=pass_ctx.disabled_pass,
404+
instruments=list(pass_ctx.instruments) + [save_tir],
405+
config=pass_ctx.config,
406+
):
407+
lib = relay.vm.compile(mod, params=params, target=target)
408+
vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
409+
410+
args = _create_args(mod, dev)
411+
report = vmexec.profile(*args)
412+
413+
return roofline_from_existing(report, save_tir.functions, target, dev)

src/auto_scheduler/feature.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
740740
// TODO(tkonolige): add arithmetic counts from this statement to counts of inner stores.
741741
ana_.Bind(node->var, node->value);
742742
ICHECK(variable_definition_stack_.size() > 0)
743-
<< "Variable definition out size of a for loop is not handled by feature extraction";
743+
<< "Variable definition outside of a for loop is not handled by feature extraction";
744744
variable_definition_stack_.back().push_back(std::make_tuple(node->var, node->value));
745745
StmtExprVisitor::VisitStmt_(node);
746746
}

0 commit comments

Comments
 (0)