[Relay][Training] Add gradient for Crossentropy #3925
Conversation
python/tvm/relay/op/nn/_nn.py
Outdated
reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
(extra blank lines follow here in the diff)
nit: remove extra blank line (only two are needed)
@vinx13 I have addressed your comment. Can you review again?
pack_dtype, out_dtype, unipolar)


def cross_entropy(predictions, targets):
This should have a docstring. You should also mention that this is cross-entropy without softmax, as many frameworks equate cross-entropy to cross-entropy from logits
@MarisaKirisame can you act on @SWu's comment (and also put it in the RELAY_REGISTER_OP section)?
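A minimal sketch of what the requested docstring might look like for the Python wrapper; the exact wording and the `_make.cross_entropy` call are assumptions based on the usual Relay op wrapper pattern, not the final code:

```python
def cross_entropy(predictions, targets):
    """Compute cross entropy between predictions and targets.

    Note: this op does NOT apply softmax to ``predictions``; it expects
    probabilities rather than logits (many frameworks equate
    "cross entropy" with "cross entropy from logits").

    Parameters
    ----------
    predictions : relay.Expr
        Predicted probabilities of shape (batch, num_classes).

    targets : relay.Expr
        Target distribution with the same shape as ``predictions``.

    Returns
    -------
    result : relay.Expr
        The computed cross entropy loss.
    """
    # _make.cross_entropy is assumed here, following the usual Relay wrapper pattern.
    return _make.cross_entropy(predictions, targets)
```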
src/relay/op/nn/nn.cc
Outdated
RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE)
| .describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE) | |
| .describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE) |
src/relay/op/nn/nn.cc
Outdated
| << "y shape=" << y->shape; | ||
| CHECK(reporter->AssertEQ(x->shape[1], y->shape[1])) | ||
| << "CrossEntropy: shapes of x and y is inconsistent, " | ||
| << "x shape=, " << x->shape |
| << "x shape=, " << x->shape | |
| << "x shape = " << x->shape << ", " |
and can this be done for all of the above instances?
python/tvm/relay/op/nn/_nn.py
Outdated
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    x, y = inputs
    return [-topi.sum(topi.log(x) * y / x.shape[0])]
Would it be more efficient and numerically stable to divide by the batch size after the sum?
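A rough sketch of the variant the reviewer is suggesting, dividing by the batch size once after the reduction instead of scaling every element before it (same signature as the compute above; not necessarily the final code):

```python
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    x, y = inputs
    # Reduce first, then divide the result by the batch size once,
    # instead of dividing every element before the sum.
    return [-topi.sum(topi.log(x) * y) / x.shape[0]]
```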
@MarisaKirisame The schedule should be injective; can you check whether the CUDA schedules are properly called?
@vinx13 How can I do that? I am not really familiar with TVM's low-level internals.
@MarisaKirisame I will take a look
@MarisaKirisame I guess this is caused by the use of references, which makes fusion and scheduling difficult. But I couldn't reproduce the error on master; can you try rebasing?
Force-pushed d112e48 to 9f3c850
python/tvm/relay/op/nn/_nn.py
Outdated
reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


reg.register_schedule("nn.cross_entropy", schedule_injective)
the schedule actually should be schedule_reduce (in relay.op._reduce)
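A hedged sketch of what the suggested registration could look like; the exact import path and the `_schedule_reduce` name are assumptions based only on the reviewer's pointer to relay.op._reduce:

```python
# Assumed import: the generic reduce schedule that lives in relay.op._reduce
# per the comment above (name and path are assumptions).
from tvm.relay.op._reduce import _schedule_reduce

reg.register_schedule("nn.cross_entropy", _schedule_reduce)
```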
from tvm.relay.testing import check_grad


def test_crossentropy_grad():
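A hedged sketch of how such a gradient test might use check_grad; the variable shapes are illustrative assumptions, and in practice the inputs need to be positive probabilities so the log inside cross entropy stays finite:

```python
from tvm import relay
from tvm.relay.testing import check_grad


def test_crossentropy_grad():
    # Illustrative shapes: a batch of 2 samples over 5 classes (assumption).
    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    # check_grad numerically checks the registered gradient of the function.
    check_grad(relay.Function([x, y], relay.nn.cross_entropy(x, y)))
```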
@MarisaKirisame see my comment, the schedule should be reduce
ping @MarisaKirisame
Force-pushed 9f3c850 to 3517d33
@vinx13 Sorry, I was pushing the training work to a private branch. I have addressed the issues.
Force-pushed 3517d33 to 9680232
@MarisaKirisame This might be a flaky test case; can you restart the CI?
@vinx13 I will restart it right now. Just FYI, I also got the same error last time.
Force-pushed 71986be to e884dfc
@MarisaKirisame You can try increasing the rtol of the failing test.
@vinx13 It works now.
@vinx13 I have acted on the comment.
* save save redo max test save address comment fix
* address comment
* increase rtol
* address review comment
* master: (21 commits)
  [Fix][VM] Fix VM invoke with set_params (apache#4079)
  [QNN] Refactor fixed point multiplication in requantize (apache#4073)
  Fix match case in Python-side expr functor (apache#4037)
  Hide symbols from dependent libraries if HIDE_PRIVATE_SYMBOLS is ON. (apache#4041)
  Add gradient for log-softmax (apache#4069)
  [DOC] Fix typos in tutorials (apache#4066)
  dicrease the complexity of CalcDep from exponential to linear (apache#4053)
  [Relay][AlterOp] Minor refactor. (apache#4064)
  [Relay][AlterOp] Improving support for broadcast layout alteration. (apache#4040)
  Add parses support for zeros_like tflite operator (apache#4042)
  [Bugfix][TF] reset graph after getting tag of savedmodel (apache#4055)
  [Relay][VM] Add more passes to VMCompiler (apache#4058)
  [Relay][VM] Add autotvm context when compile (apache#4062)
  [Bugfix] Fix target host for vm compiler (apache#4057)
  [Relay][Training] Add gradient for Crossentropy (apache#3925)
  [llvm] switch to use Align for llvm trunk (apache#4051)
  [Relay][TopHub] Add switch to disable TopHub download (apache#4015)
  [Relay][Op] Add instance norm op (apache#4004)
  [QNN][Relay] Calling Dialect passes from inside Relay Build API. (apache#3971)
  [RELAY/PASS] Fix the extent for the post_stmt in the loop partition (apache#3734)
  ...
@vinx13 @junrushao1994 @SWu can you guys help review?