Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 0 additions & 212 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,55 +146,6 @@ def run_and_verify_func(config, target="cuda", run_module=True, data_type="float
assert_result_dict_holds(result_dict, data_type)


def run_and_verify_model(model, run_module):
    """Compile and run a Gluon model-zoo network with and without TensorRT.

    Builds *model* (an mxnet model-zoo name, e.g. "resnet18_v1") via the MXNet
    frontend, compiles it with both the "vm" and "graph" executors, each with
    and without TensorRT partitioning, and — when *run_module* is true — checks
    that all four results agree via ``assert_result_dict_holds``.

    Parameters
    ----------
    model : str
        Model-zoo model name passed to ``mxnet.gluon.model_zoo.vision.get_model``.
    run_module : bool
        When False, only compile; skip execution and result comparison.
    """
    # NOTE: only get_model is needed; the previous `import mxnet as mx` was unused.
    from mxnet.gluon.model_zoo.vision import get_model

    def check_trt_used(mod):
        # Partitioning must have produced exactly one subgraph named "tensorrt_0".
        num_trt_subgraphs = sum(
            1 for gv in mod.get_global_vars() if gv.name_hint == "tensorrt_0"
        )
        assert num_trt_subgraphs == 1

    def compile_and_run(mod, params, i_data, mode="vm", use_trt=True):
        # Compile with the requested executor, optionally offloading to TensorRT,
        # and run once on i_data (or return None when run_module is False).
        assert mode in ["graph", "vm"]

        if use_trt:
            mod, config = tensorrt.partition_for_tensorrt(mod, params)
            check_trt_used(mod)
            pass_config = {"relay.ext.tensorrt.options": config}
        else:
            pass_config = {}

        with tvm.transform.PassContext(opt_level=3, config=pass_config):
            func = relay.create_executor(
                mode, mod=mod, device=tvm.cuda(0), target="cuda"
            ).evaluate()

        return func(i_data, **params) if run_module else None

    dtype = "float32"
    input_shape = (1, 3, 224, 224)
    i_data = np.random.uniform(-1, 1, input_shape).astype(dtype)
    block = get_model(model, pretrained=True)
    mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)

    result_dict = dict()
    for mode in ["vm", "graph"]:
        for use_trt in [True, False]:
            result_key = mode + ("_trt" if use_trt else "")
            result_dict[result_key] = compile_and_run(
                mod, params, i_data, mode=mode, use_trt=use_trt
            )

    if run_module:
        assert_result_dict_holds(result_dict)


def test_tensorrt_simple(run_module):
for dtype in SUPPORTED_DTYPES:
xshape = (1, 3, 2, 2)
Expand Down Expand Up @@ -278,113 +229,6 @@ def test_tensorrt_not_compatible(run_module):
results = func(x_data)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_tensorrt_serialize_graph_executor(run_module):
    """Round-trip a TensorRT-partitioned module through graph-executor serialization.

    Builds resnet18_v1 with TensorRT partitioning, saves the graph JSON, the
    compiled library, and the params to disk, reloads them, and — when
    *run_module* is true — verifies the reloaded artifacts produce the same
    output as the in-memory ones.
    """
    # NOTE: only get_model is needed; the previous `import mxnet as mx` was unused.
    from mxnet.gluon.model_zoo.vision import get_model

    data_shape = (1, 3, 224, 224)
    data_type = "float32"
    i_data = np.random.uniform(0, 1, data_shape).astype(data_type)
    block = get_model("resnet18_v1", pretrained=True)
    mod, params = relay.frontend.from_mxnet(block, shape={"data": data_shape}, dtype=data_type)
    mod, config = tensorrt.partition_for_tensorrt(mod)
    tmpdir = utils.tempdir()

    def compile_graph(mod, params):
        # Build for CUDA with the TensorRT options produced by partitioning.
        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
            graph, lib, params = relay.build(mod, params=params, target="cuda")
            params = runtime.save_param_dict(params)
        return graph, lib, params

    def run_graph(graph, lib, params):
        # Execute one inference on i_data and return output 0.
        mod_ = graph_executor.create(graph, lib, device=tvm.cuda(0))
        mod_.load_params(params)
        mod_.run(data=i_data)
        res = mod_.get_output(0)
        return res

    def save_graph(graph, lib, params):
        # Serialize graph JSON, params blob, and compiled library to tmpdir.
        with open(tmpdir.relpath("compiled.json"), "w") as f_graph_json:
            f_graph_json.write(graph)
        with open(tmpdir.relpath("compiled.params"), "wb") as f_params:
            f_params.write(params)
        lib.export_library(tmpdir.relpath("compiled.so"))

    def load_graph():
        # Deserialize the artifacts written by save_graph.
        with open(tmpdir.relpath("compiled.json"), "r") as f_graph_json:
            graph = f_graph_json.read()
        with open(tmpdir.relpath("compiled.params"), "rb") as f_params:
            params = bytearray(f_params.read())
        lib = tvm.runtime.load_module(tmpdir.relpath("compiled.so"))
        return graph, lib, params

    # Test serialization with graph executor
    graph, lib, graph_params = compile_graph(mod, params)
    save_graph(graph, lib, graph_params)
    loaded_graph, loaded_lib, loaded_params = load_graph()

    if run_module:
        result_dict = dict()
        result_dict["graph"] = run_graph(graph, lib, graph_params)
        result_dict["graph_ref"] = run_graph(loaded_graph, loaded_lib, loaded_params)
        assert_result_dict_holds(result_dict)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_tensorrt_serialize_vm(run_module):
    """Round-trip a TensorRT-partitioned module through VM serialization.

    Builds resnet18_v1 with TensorRT partitioning, saves the VM bytecode and
    compiled library to disk, reloads them, and — when *run_module* is true —
    verifies the reloaded executable produces the same output as the in-memory
    one.
    """
    # NOTE: only get_model is needed; the previous `import mxnet as mx` was unused.
    from mxnet.gluon.model_zoo.vision import get_model

    data_shape = (1, 3, 224, 224)
    data_type = "float32"
    i_data = np.random.uniform(0, 1, data_shape).astype(data_type)
    block = get_model("resnet18_v1", pretrained=True)
    mod, params = relay.frontend.from_mxnet(block, shape={"data": data_shape}, dtype=data_type)
    mod, config = tensorrt.partition_for_tensorrt(mod)
    tmpdir = utils.tempdir()

    def compile_vm(mod, params):
        # Compile to a VM executable with the TensorRT options applied.
        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
            vm_exec = relay.vm.compile(mod, target="cuda", params=params)
        code, lib = vm_exec.save()
        return code, lib

    def run_vm(code, lib):
        # Reconstruct the executable and invoke "main" once on i_data.
        vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
        vm = VirtualMachine(vm_exec, tvm.cuda(0))
        result = vm.invoke("main", data=i_data)
        return result

    def save_vm(code, lib):
        # Persist the compiled library and the VM bytecode to tmpdir.
        lib.export_library(tmpdir.relpath("path_lib.so"))
        with open(tmpdir.relpath("path_code.ro"), "wb") as fo:
            fo.write(code)

    def load_vm():
        lib = tvm.runtime.load_module(tmpdir.relpath("path_lib.so"))
        # Use a context manager so the bytecode file handle is closed promptly
        # (the previous version leaked the handle from a bare open()).
        with open(tmpdir.relpath("path_code.ro"), "rb") as fi:
            code = bytearray(fi.read())
        return lib, code

    # Test serialization with VM
    code_vm, lib_vm = compile_vm(mod, params)
    save_vm(code_vm, lib_vm)
    loaded_lib_vm, loaded_code_vm = load_vm()

    if run_module:
        result_dict = dict()
        result_dict["vm"] = run_vm(code_vm, lib_vm)
        result_dict["vm_ref"] = run_vm(loaded_code_vm, loaded_lib_vm)
        assert_result_dict_holds(result_dict)


def test_conv1d(run_module):
def get_graph(
x_shape=((1, 3, 224)),
Expand Down Expand Up @@ -1302,62 +1146,6 @@ def get_graph(
)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_alexnet(run_module):
    """End-to-end TensorRT offload check for the model-zoo "alexnet" network."""
    run_and_verify_model("alexnet", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_resnet18_v1(run_module):
    """End-to-end TensorRT offload check for the model-zoo "resnet18_v1" network."""
    run_and_verify_model("resnet18_v1", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_resnet18_v2(run_module):
    """End-to-end TensorRT offload check for the model-zoo "resnet18_v2" network."""
    run_and_verify_model("resnet18_v2", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_squeezenet(run_module):
    """End-to-end TensorRT offload check for the model-zoo "squeezenet1.0" network."""
    run_and_verify_model("squeezenet1.0", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_mobilenet(run_module):
    """End-to-end TensorRT offload check for the model-zoo "mobilenet0.25" network."""
    run_and_verify_model("mobilenet0.25", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_mobilenet_v2(run_module):
    """End-to-end TensorRT offload check for the model-zoo "mobilenetv2_0.25" network."""
    run_and_verify_model("mobilenetv2_0.25", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_vgg11(run_module):
    """End-to-end TensorRT offload check for the model-zoo "vgg11" network."""
    run_and_verify_model("vgg11", run_module)


@pytest.mark.xfail(
    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
def test_densenet121(run_module):
    """End-to-end TensorRT offload check for the model-zoo "densenet121" network."""
    run_and_verify_model("densenet121", run_module)


@pytest.mark.xfail(
reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
)
Expand Down