Skip to content

Commit 51d704d

Browse files
save
1 parent 57cd27f commit 51d704d

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,12 @@ def dense_grad(orig, grad):
290290
collapse_sum_like(data * transpose(grad), weight)]
291291

292292

293+
@register_gradient("reshape")
294+
def reshape_grad(orig, grad):
295+
"""Gradient of reshape"""
296+
return [reshape_like(grad, orig.args[0])]
297+
298+
293299
@register_gradient("nn.batch_flatten")
294300
def batch_flatten_grad(orig, grad):
295301
"""Returns grad reshaped to data dims"""

tests/python/relay/test_op_grad_level3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import numpy as np
18+
import pytest
1819

1920
import tvm
2021
from tvm import relay
@@ -58,6 +59,4 @@ def test_negative_grad():
5859

5960

6061
if __name__ == "__main__":
61-
test_clip()
62-
test_transpose_grad()
63-
test_negative_grad()
62+
pytest.main()

tests/python/relay/test_op_level3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tvm
2222
from tvm import relay
2323
from tvm.relay import create_executor, transform
24-
from tvm.relay.testing import ctx_list
24+
from tvm.relay.testing import ctx_list, check_grad
2525

2626
def run_infer_type(expr):
2727
mod = relay.Module.from_expr(expr)
@@ -247,6 +247,7 @@ def verify_reshape(shape, newshape, oshape):
247247
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
248248

249249
func = relay.Function([x], z)
250+
check_grad(func)
250251
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
251252
ref_res = np.reshape(x_data, oshape)
252253
for target, ctx in ctx_list():

0 commit comments

Comments
 (0)