Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def relay_to_relax(
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
) -> tvm.IRModule:
"""Change relay IRModule to relax MSCGraph.

Expand All @@ -239,6 +240,8 @@ def relay_to_relax(
The config for build MSCGraph.
opt_config: dict
The config for optimize the relay before translate.
build_folder: MSCDirectory
The folder for saving scripts and datas.

Returns
-------
Expand All @@ -254,4 +257,4 @@ def relay_to_relax(
opt_config=opt_config,
)

return to_relax(graph, weights, codegen_config={"from_relay": True})
return to_relax(graph, weights, codegen_config={"from_relay": True}, build_folder=build_folder)
48 changes: 47 additions & 1 deletion python/tvm/contrib/msc/core/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import os
import shutil
import json
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Any, Tuple
import numpy as np

import tvm
from .arguments import load_dict
from .info import cast_array, is_array
from .namespace import MSCFramework


def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], style="dict") -> Any:
Expand Down Expand Up @@ -64,6 +65,51 @@ def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], styl
raise TypeError("Unexpected style " + str(style))


def random_data(
info: Union[List, Tuple, dict],
framework: str = MSCFramework.MSC,
device: str = "cpu",
max_val: int = None,
) -> Any:
"""Create random data from info

Parameters
----------
info: list| tuple| dict
The data info.
framework: str
The framework.
device: str
The device.
"""

if isinstance(info, (tuple, list)):
if len(info) == 1:
info = {"name": "data", "shape": info[0], "dtype": "float32"}
elif len(info) == 2:
info = {"name": "data", "shape": info[0], "dtype": info[1]}
elif len(info) == 3:
info = {"name": info[0], "shape": info[1], "dtype": info[2]}
else:
raise Exception("Unexpected info " + str(info))
assert isinstance(info, dict) and all(
key in info for key in ["shape", "dtype"]
), "shape and dtype should be given to create randome data"
if info["dtype"] in ("int32", "int64"):
if max_val is None:
data = np.zeros(info["shape"]).astype(info["dtype"])
else:
data = np.random.randint(0, high=max_val, size=info["shape"]).astype(info["dtype"])
elif info["dtype"] == "bool":
data = np.random.rand(*info["shape"]).astype("float32")
data = np.where(data >= 0.5, True, False)
else:
data = np.random.rand(*info["shape"]).astype(info["dtype"])
if max_val is not None:
data *= max_val
return cast_array(data, framework, device=device)


class BaseDataLoader(object):
"""Basic dataset loader for MSC

Expand Down
18 changes: 15 additions & 3 deletions python/tvm/contrib/msc/framework/torch/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm.contrib.msc.core.ir.graph import MSCGraph
from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs
from tvm.contrib.msc.core.codegen import relay_to_relax
from tvm.contrib.msc.core import utils as msc_utils


def set_weight_alias(graph: MSCGraph) -> MSCGraph:
Expand Down Expand Up @@ -70,6 +71,7 @@ def from_torch(
opt_config: Optional[Dict[str, str]] = None,
as_msc: bool = True,
custom_convert_map: dict = None,
build_folder: msc_utils.MSCDirectory = None,
) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.

Expand All @@ -93,6 +95,8 @@ def from_torch(
Set to to return msc graph, otherwise relax mod
custom_convert_map: dict
The convert map for plugin
build_folder: MSCDirectory
The folder for saving scripts and datas.

Returns
-------
Expand All @@ -102,9 +106,15 @@ def from_torch(
The weights from the IRModule.
"""

# try to symbolic_trace
if via_relax:
input_info = normalize_inputs(input_info)
graph_model, params = torch.fx.symbolic_trace(model), None
try:
graph_model = torch.fx.symbolic_trace(model)
except: # pylint: disable=bare-except
via_relax = False

if via_relax:
input_info, params = normalize_inputs(input_info), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
else:
Expand All @@ -122,7 +132,9 @@ def from_torch(
relay_mod, params = tvm.relay.frontend.from_pytorch(
scripted_model, shape_list, custom_convert_map=custom_convert_map
)
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config)
relax_mod = relay_to_relax(
relay_mod, params, trans_config, build_config, opt_config, build_folder=build_folder
)
if not as_msc:
return relax_mod, params
graph, weights = from_relax(relax_mod, trans_config=trans_config, build_config=build_config)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/contrib/msc/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import json
from typing import Any, Union, List, Tuple
import traceback
import numpy as np

from tvm.contrib.msc.core.tools import get_tool_cls, BaseTool
from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey
Expand Down Expand Up @@ -678,7 +677,7 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any:
def get_random():
def _to_data(inp):
shape = [1 if isinstance(d, str) else d for d in inp[1]]
return np.random.rand(*shape).astype(inp[2])
return msc_utils.random_data([shape, inp[2]])

for _ in range(max_batch):
yield {i[0]: _to_data(i) for i in self._config["inputs"]}
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,20 @@ def _reshape(self, node: fx.Node) -> relax.Var:
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.reshape(x, dims))

def _scatter(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if len(node.args) == 1:
dim = node.kwargs["dim"]
index = self.env[node.kwargs["index"]]
src = self.env[node.kwargs["src"]]
elif len(node.args) == 4:
dim = node.args[1]
index = self.env[node.args[2]]
src = self.env[node.args[3]]
else:
raise Exception("Unexpected args " + str(node.args))
return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))

def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
Expand All @@ -801,6 +815,24 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
return self.block_builder.emit(relax.op.squeeze(x, dim))

def _stack(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
in_args = args[0]
assert all(
a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
), "Expect all dim at {} to be the same, get {}".format(
axis, [a.struct_info.shape for a in args]
)
cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
s_shape = []
for idx, s in enumerate(cat.struct_info.shape):
if idx == axis:
s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
else:
s_shape.append(s)
return self.block_builder.emit(relax.op.reshape(cat, s_shape))

def _tile(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,11 @@ def create_convert_map(
"permute": self._permute,
"repeat": self._repeat,
"reshape": self._reshape,
"scatter": self._scatter,
"size": self._size,
"split": self._split,
"squeeze": self._squeeze,
"stack": self._stack,
"tile": self._tile,
"transpose": self._transpose,
"unsqueeze": lambda node: self.block_builder.emit(
Expand Down
56 changes: 53 additions & 3 deletions src/contrib/msc/framework/torch/torch_opcode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,28 @@ class TorchConstantCodeGen : public TorchOpCode {

protected:
void CodeGenInit() final {
const auto& dtype = node()->OutputAt(0)->DTypeName();
const auto& ref_name = StringUtils::Replace(node()->name, ".", "_");
if (node()->HasAttr("scalar")) {
if (node()->OutputAt(0)->DTypeName() == "int32") {
if (dtype == "int32") {
stack_.assign(module_ref(), node()->GetTypeAttr<int>("scalar"));
} else if (node()->OutputAt(0)->DTypeName() == "int64") {
} else if (dtype == "int64") {
stack_.assign(module_ref(), node()->GetTypeAttr<int64_t>("scalar"));
} else if (node()->OutputAt(0)->DTypeName() == "float32") {
} else if (dtype == "float32") {
stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
}
} else if (dtype == "int32") {
stack_.func_call("register_buffer", "", "self")
.call_arg(DocUtils::ToStr(ref_name))
.inplace_start("torch.IntTensor")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
.inplace_end();
} else if (dtype == "int64") {
stack_.func_call("register_buffer", "", "self")
.call_arg(DocUtils::ToStr(ref_name))
.inplace_start("torch.LongTensor")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
.inplace_end();
} else {
stack_.func_call("torch.Tensor", "data")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
Expand Down Expand Up @@ -565,6 +579,39 @@ class TorchSimpleCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchSimpleCodeGen);
};

class TorchScatterElementsCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchScatterElementsCodeGen)

protected:
void CodeGenForward() final {
if (node()->InputAt(1)->DTypeName() == "int32") {
stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
}
stack_.op_call()
.op_input_arg()
.op_arg<int>("axis", "dim")
.op_input_arg(1, "index")
.op_input_arg(2, "src");
}
};

class TorchScatterNDCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchScatterNDCodeGen)

protected:
void CodeGenForward() final {
if (node()->InputAt(1)->DTypeName() == "int32") {
stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
}
// relax add extra dim for indices
if (node()->InputAt(1)->Ndim() == node()->OutputAt(0)->Ndim()) {
stack_.func_call("squeeze", IdxInput(1), IdxInput(1)).call_arg(-1);
}
stack_.assign(DocUtils::ToIndex(IdxInput(0), IdxInput(1)), IdxInput(2))
.assign(IdxNode(), IdxInput(0));
}
};

class TorchSplitCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchSplitCodeGen)

Expand Down Expand Up @@ -719,6 +766,9 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
map->emplace("permute_dims", std::make_shared<TorchPermuteDimsCodeGen>("", "torch.permute"));
map->emplace("repeat", std::make_shared<TorchRepeatCodeGen>("", "repeat"));
map->emplace("reshape", std::make_shared<TorchReshapeCodeGen>("", "torch.reshape"));
map->emplace("scatter_elements",
std::make_shared<TorchScatterElementsCodeGen>("", "torch.scatter"));
map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));

Expand Down
46 changes: 46 additions & 0 deletions src/contrib/msc/framework/tvm/relax_opcode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,34 @@ class RelaxReshapeCodeGen : public RelaxOpCode {
}
};

class RelaxScatterElementsCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxScatterElementsCodeGen)

protected:
void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_arg<int>("axis"); }
};

class RelaxScatterNDCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxScatterNDCodeGen)

protected:
void CodeGenBuild() final {
if (config()->from_relay) {
size_t ndim = node()->InputAt(1)->Ndim();
std::vector<size_t> axes;
axes.push_back(ndim - 1);
for (size_t i = 0; i < ndim - 1; i++) {
axes.push_back(i);
}
stack_.func_call("relax.op.permute_dims", IdxInput(1))
.call_arg(IdxInput(1))
.call_arg(DocUtils::ToList(axes));
BuilderEmit(IdxInput(1), "permute_" + std::to_string(node()->index));
}
stack_.op_call().op_inputs_arg(false).op_str_arg("mode", "reduction");
}
};

class RelaxResize2dCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxResize2dCodeGen)

Expand Down Expand Up @@ -626,6 +654,20 @@ class RelaxSplitCodeGen : public RelaxOpCode {
}
};

class RelaxStackCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxStackCodeGen)

protected:
void CodeGenBuild() final {
stack_.op_call().op_inputs_arg().op_arg<int>("axis");
BuilderEmit(IdxNode(), "cat_" + std::to_string(node()->index));
const auto& out_shape = GetPrims(node()->OutputAt(0));
stack_.func_call("relax.op.reshape", IdxNode())
.call_arg(IdxNode())
.call_arg(DocUtils::ToList(out_shape), "shape");
}
};

class RelaxTakeCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen)

Expand Down Expand Up @@ -763,7 +805,11 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<RelaxOpCode>>>
map->emplace("permute_dims", std::make_shared<RelaxPermuteDimsCodeGen>("relax.op.permute_dims"));
map->emplace("repeat", std::make_shared<RelaxRepeatCodeGen>("relax.op.repeat"));
map->emplace("reshape", std::make_shared<RelaxReshapeCodeGen>("relax.op.reshape"));
map->emplace("scatter_elements",
std::make_shared<RelaxScatterElementsCodeGen>("relax.op.scatter_elements"));
map->emplace("scatter_nd", std::make_shared<RelaxScatterNDCodeGen>("relax.op.scatter_nd"));
map->emplace("split", std::make_shared<RelaxSplitCodeGen>("relax.op.split"));
map->emplace("stack", std::make_shared<RelaxStackCodeGen>("relax.op.concat"));
map->emplace("strided_slice",
std::make_shared<RelaxStridedSliceCodeGen>("relax.op.strided_slice"));
map->emplace("take", std::make_shared<RelaxTakeCodeGen>("relax.op.take"));
Expand Down
Loading