[Relay/TOPI][OP] Add arange op in Relay and TOPI#2621
Conversation
include/tvm/relay/attrs/transform.h
Outdated
| .describe("Start of interval. The interval includes this value."); | ||
| TVM_ATTR_FIELD(stop) | ||
| .describe("Stop of interval. The interval does not include this value."); | ||
| TVM_ATTR_FIELD(start).set_default(make_const(Int(32), 1)) |
| relay.arange(5) = [0, 1, 2, 3, 4] | ||
| relay.arange(1, 5) = [1, 2, 3, 4] | ||
| relay.arange(1, 5, 1.5) = [1, 2.5, 4] | ||
| """ |
There was a problem hiding this comment.
should we consider cases like start > stop and step <= 0, here? I think we probably need to at least warning or raise exceptions for step == 0
There was a problem hiding this comment.
Added the sanity check in the new commit
| CHECK_EQ(types.size(), 1); | ||
| const ArangeAttrs* param = attrs.as<ArangeAttrs>(); | ||
| IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( | ||
| tvm::cast(tvm::Float(32), param->stop - param->start) / param->step)); |
There was a problem hiding this comment.
step is not necessary to be constant during the compilation time. So probably we should rely on IR to capture this?
| std::string name = "tensor", | ||
| std::string tag = kInjective) { | ||
| Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( | ||
| tvm::cast(tvm::Float(32), stop - start) / step)); |
There was a problem hiding this comment.
not sure if we need to check if step == 0, probably it is enough if we checked if before
There was a problem hiding this comment.
divide by 0 should be captured by IR when step is constant.
https://github.com/dmlc/tvm/blob/master/src/lang/ir_operator.cc#L202
|
Hey @yzhliu @zhreshold, could you help review this PR? |
| verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3)) | ||
|
|
||
|
|
||
| def test_arange(): |
There was a problem hiding this comment.
can we also have a relay frontend test from mxnet arange
There was a problem hiding this comment.
Added in the new commit.
python/tvm/relay/op/transform.py
Outdated
| return _make.full_like(data, fill_value) | ||
|
|
||
|
|
||
| def arange(stop, start=None, step=1, dtype="float32"): |
There was a problem hiding this comment.
the doc does not match the arg trick though
There was a problem hiding this comment.
Made a note in docs in the new commit
src/relay/op/tensor/transform.cc
Outdated
| IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( | ||
| tvm::cast(tvm::Float(32), param->stop - param->start) / param->step)); | ||
| if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) { | ||
| CHECK_GT(val->value, 0) << "Invalid arange inputs"; |
There was a problem hiding this comment.
suggest to also print related params
There was a problem hiding this comment.
Fixed in the new commit
|
Thanks all. This is merged. |
* Add arange op * Update docs * Fix bug * add sanity check in relay and mxnet frontend mapping * lint * nits * pylint * don't allow empty output from arange * Remove empty test for arange * Fix bug and update doc
* Add arange op * Update docs * Fix bug * add sanity check in relay and mxnet frontend mapping * lint * nits * pylint * don't allow empty output from arange * Remove empty test for arange * Fix bug and update doc
* Add arange op * Update docs * Fix bug * add sanity check in relay and mxnet frontend mapping * lint * nits * pylint * don't allow empty output from arange * Remove empty test for arange * Fix bug and update doc
Currently I put start, stop, step in Relay attributes since it is required to infer the output shape. Later if Relay supports unknown dimension like
Any, we can move them into inputs of arange op instead of attributes.This PR relies on #2615.