Compute attributions w.r.t the predicted logit, not the predicted loss #4882
sarahwie wants to merge 1 commit into allenai:main from
Conversation
Force-pushed from 9c43e97 to ef337d2
Realized this breaks Hotflip, since that relies on the loss. Also, I am not sure how the input reduction method is intended to be calculated (the paper just says "outputs"), but it will change with this, too.
| ) | ||
|
|
||
| loss = outputs["loss"] | ||
| predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))] |
The trouble with doing it this way is that it hard-codes assumptions about the model's outputs which may not be true. The test failure you're getting is because of this. This method has to be generic enough to work for any model. Querying the "loss" key is ok, because that key is already required by the Trainer; nothing else is guaranteed to be in the output, so we can't hard-code anything else.

Maybe a better way of accomplishing what you want is to allow the caller to specify the output key, with a default value of "loss". Then it would be the model's responsibility to make sure that the value in that key is a single number on which we can call .backward(). E.g., you could imagine adding a target_logit key in your model class, and then using that key when calling get_gradients().

We could get by with less model modification if we added a second flag that says whether to take an argmax of the values in that key, but that gets a bit messy, because then you're always getting gradients of the model's prediction, completely ignoring whatever label was given in the input instance. This breaks a lot of assumptions in other methods in the code (which I think is what you were referring to when you said this breaks Hotflip), so I don't really like this option.
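A minimal sketch of that idea outside of AllenNLP, assuming a standalone get_gradients-style helper that takes the output key as a parameter; the `target_logit` key and the toy classifier below are hypothetical illustrations, not the library's actual API:

```python
import torch


class ToyClassifier(torch.nn.Module):
    """Hypothetical model that exposes the proposed 'target_logit' output key."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.linear = torch.nn.Linear(num_features, num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, features: torch.Tensor, label: torch.Tensor) -> dict:
        logits = self.linear(features)
        predicted_index = int(logits[0].argmax())
        return {
            "logits": logits,
            "loss": self.loss_fn(logits, label),
            # Scalar logit of the predicted class, so .backward() can be called on it directly.
            "target_logit": logits[0, predicted_index],
        }


def get_gradients(
    model: torch.nn.Module,
    features: torch.Tensor,
    label: torch.Tensor,
    key: str = "loss",
) -> torch.Tensor:
    """Back-propagate from outputs[key]; the default keeps the current loss-based behaviour."""
    features = features.clone().requires_grad_(True)
    outputs = model(features, label)
    outputs[key].backward()
    return features.grad


# Usage: gradients w.r.t. the predicted class' logit instead of the loss.
model = ToyClassifier(num_features=4, num_classes=3)
x = torch.randn(1, 4)
y = torch.tensor([1])
grads = get_gradients(model, x, y, key="target_logit")
```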
Thanks for the feedback! I agree that using a key is straightforward. I'll refactor.
Is this still an active project? Can we help in any way?
Compute gradient attributions with respect to the predicted class' logit rather than the loss, so the gradient no longer depends on how close the loss is to 0; when the loss is near 0 (a confident, correct prediction), the gradient vanishes and the attributions become uninformative.
See for justification:
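An illustrative toy example (my own sketch, separate from the linked justification): for a confident, correct prediction the cross-entropy loss is already near 0, so its gradient with respect to the input is near 0, while the gradient of the predicted class' logit remains informative.

```python
import torch

# Toy 2-class linear model with weights chosen so the input below is
# classified correctly with very high confidence (loss ~ 0).
weight = torch.tensor([[10.0, 0.0], [0.0, 10.0]])
x = torch.tensor([[5.0, 0.0]], requires_grad=True)
label = torch.tensor([0])

logits = x @ weight.T                       # tensor([[50., 0.]])
loss = torch.nn.functional.cross_entropy(logits, label)
loss.backward()
print(x.grad)   # ~0 everywhere: the loss gradient has vanished, so attributions are uninformative

x.grad = None
logits = x @ weight.T
predicted_logit = logits[0, int(logits.argmax())]
predicted_logit.backward()
print(x.grad)   # non-zero: the predicted logit's gradient still reflects input importance
```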