Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 223 additions & 61 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,61 +740,140 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var:
bias = args[2] if len(args) > 2 else None
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))

def _conv1d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]

def _conv1d_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
) -> relax.Var:
conv1d = self.block_builder.emit(
relax.op.nn.conv1d(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCW",
kernel_layout="OIW",
out_dtype="float32",
)
)

if module.bias is None:
if bias is None:
return conv1d

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1))

return self.block_builder.emit(relax.op.add(conv1d, bias))

def _conv3d(self, node: fx.node.Node) -> relax.Var:
def _conv1d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

conv3d = self.block_builder.emit(
relax.op.nn.conv3d(
return self._conv1d_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv1d_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv1d_transpose_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
) -> relax.Var:
conv1d_transpose = self.block_builder.emit(
relax.op.nn.conv1d_transpose(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCDHW",
kernel_layout="OIDHW",
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCW",
kernel_layout="OIW",
out_dtype="float32",
)
)

if module.bias is None:
return conv3d
if bias is None:
return conv1d_transpose

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
bias = relax.op.reshape(bias, (1, -1, 1))
return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))

return self.block_builder.emit(relax.op.add(conv3d, bias))
def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

return self._conv1d_transpose_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv1d_transpose_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv2d_impl(
self,
Expand Down Expand Up @@ -826,71 +905,150 @@ def _conv2d_impl(
bias = relax.op.reshape(bias, (1, -1, 1, 1))
return self.block_builder.emit(relax.op.add(conv2d, bias))

def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
def _conv2d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

conv1d_transpose = self.block_builder.emit(
relax.op.nn.conv1d_transpose(
return self._conv2d_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv2d_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv2d_transpose_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
) -> relax.Var:
conv2d_transpose = self.block_builder.emit(
relax.op.nn.conv2d_transpose(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCW",
kernel_layout="OIW",
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="float32",
)
)

if module.bias is None:
return conv1d_transpose
if bias is None:
return conv2d_transpose

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1))

return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
bias = relax.op.reshape(bias, (1, -1, 1, 1))
return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))

def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

conv2d_transpose = self.block_builder.emit(
relax.op.nn.conv2d_transpose(
return self._conv2d_transpose_impl(
x,
weight,
bias=bias,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv2d_transpose_impl(
x,
weight,
bias=bias,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)

def _conv3d_impl(
self,
x: relax.Expr,
weight: relax.Expr,
bias: Optional[relax.Expr],
strides: Optional[Tuple],
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
):
conv3d = self.block_builder.emit(
relax.op.nn.conv3d(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCHW",
kernel_layout="OIHW",
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NCDHW",
kernel_layout="OIDHW",
out_dtype="float32",
)
)

if module.bias is None:
return conv2d_transpose

bias = self.params[module.bias]
if bias is None:
return conv3d
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1, 1))

return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
return self.block_builder.emit(relax.op.add(conv3d, bias))

def _conv2d(self, node: fx.node.Node) -> relax.Var:
def _conv3d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None
if module.bias is not None:
bias = self.params[module.bias]

return self._conv2d_impl(
return self._conv3d_impl(
x,
weight,
bias=bias,
Expand All @@ -900,7 +1058,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var:
groups=module.groups,
)

def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
def _conv3d_functional(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
Expand All @@ -909,7 +1067,7 @@ def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
padding = args[4] if len(args) > 4 else 0
dilation = args[5] if len(args) > 5 else 1
groups = args[6] if len(args) > 6 else 1
return self._conv2d_impl(
return self._conv3d_impl(
x,
weight,
bias=bias,
Expand Down Expand Up @@ -1482,7 +1640,11 @@ def create_convert_map(self):
"type": self._type,
"astype": self._type,
"matmul": self._matmul,
"conv1d": self._conv1d_functional,
"conv_transpose1d": self._conv1d_transpose_functional,
"conv2d": self._conv2d_functional,
"conv_transpose2d": self._conv2d_transpose_functional,
"conv3d": self._conv3d_functional,
"linear": self._linear_functional,
"addmm": self._addmm,
"baddbmm": self._baddbmm,
Expand Down
Loading