Skip to content

Commit 655fc38

Browse files
Tristan Konolige authored and mehrdadh committed
[ROOFLINE] Roofline analysis over RPC (apache#11252)
* [ROOFLINE] Roofline analysis over RPC

  Run roofline analysis on remote devices if requested. Peak flops and peak bandwidth estimation are done on the remote device.
* allocate testing arrays directly on device and randomly fill
* forgot to include remote
* lower flops ratio, machine may be using multiple threads
* forgot fill
1 parent 1c42f85 commit 655fc38

2 files changed

Lines changed: 146 additions & 18 deletions

File tree

python/tvm/utils/roofline.py

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,32 @@
1818
from typing import Dict, Union, Optional
1919
import numpy as np
2020

21-
from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform
21+
from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
2222
from ..target import Target
2323
from ..runtime import profiler_vm, profiling, Device, num_threads
2424
from ..script import tir as T
2525
from ..ir.instrument import pass_instrument
2626
from ..ir.expr import GlobalVar
27+
from ..rpc.base import RPC_SESS_MASK
28+
from ..rpc.client import RPCSession
29+
from ..contrib import utils
2730

2831

29-
def _create_args(mod: IRModule, dev: Device, func_name: str = "main"):
32+
def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
33+
if dev.device_type >= RPC_SESS_MASK:
34+
random_fill = remote.get_function("tvm.contrib.random.random_fill")
35+
else:
36+
random_fill = get_global_func("tvm.contrib.random.random_fill")
37+
assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
3038
args = []
3139
for arg in mod[func_name].params:
32-
args.append(
33-
nd.array(
34-
np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
35-
device=dev,
36-
)
40+
ary = nd.empty(
41+
[x.value for x in arg.type_annotation.shape],
42+
arg.type_annotation.dtype,
43+
device=dev,
3744
)
45+
random_fill(ary)
46+
args.append(ary)
3847
return args
3948

4049

@@ -103,6 +112,7 @@ def estimate_peak_fma_flops(
103112
dev: Device,
104113
vec_width: Optional[int] = None,
105114
num_vector_registers: Optional[int] = None,
115+
remote: Optional[RPCSession] = None,
106116
) -> float:
107117
"""
108118
Estimate the maximum number of FLOP/s this target/device combo is capable
@@ -123,6 +133,9 @@ def estimate_peak_fma_flops(
123133
num_vector_registers : Optional[int]
124134
Number of vector registers on the underlying hardware. Will try to
125135
infer if no value is provided.
136+
remote : Optional[RPCSession]
137+
Remote session used to upload artifacts for runtime evaluation. Must be
138+
the same session used to create `dev`.
126139
127140
Returns
128141
-------
@@ -146,7 +159,23 @@ def estimate_peak_fma_flops(
146159
)
147160
with transform.PassContext(opt_level=3):
148161
f = build(specialized, target=target)
149-
a = nd.array(np.ones((nthreads, num_vector_registers, vec_width), dtype="float32"), device=dev)
162+
163+
# upload to remote if running over rpc
164+
if dev.device_type >= RPC_SESS_MASK:
165+
if remote is None:
166+
raise RuntimeError("A RPCSession must be provided when using a remote device.")
167+
temp = utils.tempdir()
168+
path = temp.relpath("peak_fma_flops.tar")
169+
f.export_library(path)
170+
remote.upload(path)
171+
f = remote.load_module("peak_fma_flops.tar")
172+
random_fill = remote.get_function("tvm.contrib.random.random_fill")
173+
else:
174+
random_fill = get_global_func("tvm.contrib.random.random_fill")
175+
assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
176+
177+
a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
178+
random_fill(a)
150179
times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
151180
flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops
152181
flop_s = flops / times.min
@@ -171,7 +200,12 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.
171200
B[i, l, j] += A[i, k, l, j]
172201

173202

174-
def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
203+
def estimate_peak_bandwidth(
204+
target: Target,
205+
dev: Device,
206+
vec_width: Optional[int] = None,
207+
remote: Optional[RPCSession] = None,
208+
) -> float:
175209
"""Estimate peak memory bandwidth of a target/device combo.
176210
177211
Peak bandwidth is estimated by running a small experiment on the underlying
@@ -187,6 +221,9 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
187221
Device to measure peak bandwidth on.
188222
vec_width : Optional[int]
189223
Vector unit width, determined from target if not supplied.
224+
remote : Optional[RPCSession]
225+
Remote session used to upload artifacts for runtime evaluation. Must be
226+
the same session used to create `dev`.
190227
191228
Returns
192229
-------
@@ -207,13 +244,30 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
207244
)
208245
with transform.PassContext(opt_level=3):
209246
f = build(specialized, target=target)
247+
248+
# upload to remote if running over rpc
249+
if dev.device_type >= RPC_SESS_MASK:
250+
if remote is None:
251+
raise RuntimeError("A RPCSession must be provided when using a remote device.")
252+
temp = utils.tempdir()
253+
path = temp.relpath("peak_bandwidth.tar")
254+
f.export_library(path)
255+
remote.upload(path)
256+
f = remote.load_module("peak_bandwidth.tar")
257+
random_fill = remote.get_function("tvm.contrib.random.random_fill")
258+
else:
259+
random_fill = get_global_func("tvm.contrib.random.random_fill")
260+
assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
261+
210262
threads = num_threads()
211263
# Data size needs to be larger than last level of cache. We don't have a
212264
# way of getting cache sizes, so this number should give us a large enough
213265
# size.
214266
size = 10**8 // (4 * threads * vec_width)
215-
a = nd.array(np.ones((threads, size, 4, vec_width), dtype="float32"), device=dev)
216-
b = nd.array(np.ones((threads, vec_width, 4), dtype="float32"), device=dev)
267+
a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
268+
random_fill(a)
269+
b = nd.empty((threads, vec_width, 4), dtype="float32", device=dev)
270+
random_fill(b)
217271
times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
218272
return a.numpy().size * 4 / times.min # 4 bytes per float32
219273

@@ -241,6 +295,7 @@ def roofline_from_existing(
241295
tir_functions: Dict[GlobalVar, tir.PrimFunc],
242296
target: Target,
243297
dev: Device,
298+
remote: Optional[RPCSession] = None,
244299
) -> profiling.Report:
245300
"""Add roofline and other estimated statistics to an existing profiling report.
246301
@@ -290,6 +345,9 @@ def roofline_from_existing(
290345
TVM target that `report` was generated with.
291346
dev : Device
292347
Device that `report` was generated with.
348+
remote : Optional[RPCSession]
349+
Remote session used to upload artifacts for runtime evaluation. Must be
350+
the same session used to create `dev`.
293351
294352
Returns
295353
-------
@@ -299,8 +357,8 @@ def roofline_from_existing(
299357
:py:func:`roofline_analysis` for more information on which metrics
300358
are included.
301359
"""
302-
peak_bandwidth = estimate_peak_bandwidth(target, dev)
303-
peak_flops = estimate_peak_fma_flops(target, dev)
360+
peak_bandwidth = estimate_peak_bandwidth(target, dev, remote=remote)
361+
peak_flops = estimate_peak_fma_flops(target, dev, remote=remote)
304362

305363
ridge_point = peak_flops / peak_bandwidth
306364

@@ -346,7 +404,11 @@ def roofline_from_existing(
346404

347405

348406
def roofline_analysis(
349-
mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
407+
mod: IRModule,
408+
params: Dict[str, nd.NDArray],
409+
target: Union[str, Target],
410+
dev: Device,
411+
remote: Optional[RPCSession] = None,
350412
) -> profiling.Report:
351413
"""
352414
Create a profiling report that contains roofline and other estimated
@@ -385,6 +447,10 @@ def roofline_analysis(
385447
dev : Device
386448
Device to run on.
387449
450+
remote : Optional[RPCSession]
451+
Remote session used to upload artifacts for runtime evaluation. Must be
452+
the same session used to create `dev`.
453+
388454
Returns
389455
-------
390456
@@ -405,9 +471,18 @@ def roofline_analysis(
405471
config=pass_ctx.config,
406472
):
407473
lib = relay.vm.compile(mod, params=params, target=target)
474+
# upload to remote if running over rpc
475+
if dev.device_type >= RPC_SESS_MASK:
476+
if remote is None:
477+
raise RuntimeError("A RPCSession must be provided when using a remote device.")
478+
temp = utils.tempdir()
479+
path = temp.relpath("roofline_lib.tar")
480+
lib.mod.export_library(path)
481+
remote.upload(path)
482+
lib = remote.load_module("roofline_lib.tar")
408483
vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
409484

410-
args = _create_args(mod, dev)
485+
args = _create_args(mod, dev, remote=remote)
411486
report = vmexec.profile(*args)
412487

413-
return roofline_from_existing(report, save_tir.functions, target, dev)
488+
return roofline_from_existing(report, save_tir.functions, target, dev, remote=remote)

tests/python/unittest/test_runtime_profiling.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,23 @@ def test_estimate_peak_fma_flops(target, dev):
267267
flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
268268
# Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
269269
assert (
270-
flops > 10**9 * tvm.runtime.num_threads() and flops < 10**14
271-
), f"FLOP/s should be between 10^9 * num_threads and 10^14, but it is {flops}"
270+
flops > 10**9 and flops < 10**14
271+
), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
272272

273273

274+
def test_estimate_peak_fma_flops_rpc():
275+
target = "llvm -mattr=+fma,+avx2"
276+
server = rpc.Server(key="profiling")
277+
remote = rpc.connect("127.0.0.1", server.port, key="profiling")
278+
dev = remote.device(target)
279+
flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev, remote=remote)
280+
# Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
281+
assert (
282+
flops > 10**9 and flops < 10**14
283+
), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
284+
285+
286+
@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
274287
@tvm.testing.parametrize_targets("llvm")
275288
def test_estimate_peak_bandwidth(target, dev):
276289
# This test uses vectorized instructions so we need a target that supports them
@@ -284,6 +297,20 @@ def test_estimate_peak_bandwidth(target, dev):
284297
), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
285298

286299

300+
@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
301+
def test_estimate_peak_bandwidth_rpc():
302+
target = "llvm -mattr=+fma,+avx2"
303+
server = rpc.Server(key="profiling")
304+
remote = rpc.connect("127.0.0.1", server.port, key="profiling")
305+
dev = remote.device(target)
306+
bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev, remote=remote)
307+
# Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
308+
# GB/s, so this should leave enough wiggle room.
309+
assert (
310+
bandwidth > 10**9 and bandwidth < 10**12
311+
), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
312+
313+
287314
@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
288315
@tvm.testing.parametrize_targets("llvm")
289316
def test_roofline_analysis(target, dev):
@@ -304,6 +331,32 @@ def test_roofline_analysis(target, dev):
304331
assert call["Percent of Theoretical Optimal"].ratio >= 0
305332

306333

334+
@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
335+
def test_roofline_analysis_rpc():
336+
target = "llvm"
337+
338+
a = relay.var("a", relay.TensorType((512, 512), "float32"))
339+
b = relay.var("b", relay.TensorType((512, 512), "float32"))
340+
c = relay.nn.dense(a, b)
341+
mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
342+
params = {}
343+
344+
server = rpc.Server(key="profiling")
345+
remote = rpc.connect("127.0.0.1", server.port, key="profiling")
346+
dev = remote.device(target)
347+
348+
report = tvm.utils.roofline_analysis(mod, params, target, dev, remote=remote)
349+
350+
assert "Bound" in report.table()
351+
assert "Percent of Theoretical Optimal" in report.table()
352+
for call in report.calls:
353+
if "Percent of Theoretical Optimal" in call:
354+
# Ideally we'd like a little tighter bound here, but it is hard to
355+
# know how well this dense will perform without tuning. And we
356+
# don't have an operator that uses a specific number of flops.
357+
assert call["Percent of Theoretical Optimal"].ratio >= 0
358+
359+
307360
if __name__ == "__main__":
308361
import sys
309362
import pytest

0 commit comments

Comments
 (0)