MultiLabelSequenceClassificationExplainer potentially bugged. #107

@rowanvanhoof

Description

Before I describe the issue, I should say that data science is complicated and I could be misunderstanding your code. I see that the multi-label explainer calls a binomial explainer once per label. I recognize that some multi-label implementations simply train n binomial classifiers and weigh each of their predictions; if that was the thought process behind this package, then my issue below is a misunderstanding of how the code should be used. In that case, I wasn't able to trace through the code to determine whether there is any further processing of the n softmax outputs.

I have looked into the source code here and it appears that the multi-label explainer is actually a multi-class explainer. Multi-label problems involve data where one instance can have 0, 1, or more labels (e.g. a photo of a dog can be labeled as a dog, an animal, and a mammal), while multi-class problems involve data where each instance has exactly one label chosen from many possible classes. The key difference is that multi-class labels are mutually exclusive, which calls for a different activation function than multi-label.

The multi-label explainer uses a softmax function to process the logits from the classifier. That is correct for multi-class, because softmax turns the logits into a single probability distribution, meaning only one label can receive high confidence (a probability above 0.5). Multi-label problems require a different activation function, such as the sigmoid, which converts each logit into an independent probability; these probabilities do not necessarily sum to 1, so the model can be highly confident in several labels at once.
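To make the contrast concrete, here is a minimal sketch (the logit values are made up for illustration) showing that softmax lets at most one label clear a 0.5 threshold while sigmoid can let several:

```python
import torch

# Hypothetical logits for one example with 3 candidate labels,
# where the model is genuinely confident in the first two.
logits = torch.tensor([3.0, 2.5, -4.0])

# Softmax squeezes the scores into one probability distribution:
# the labels compete for mass, so at most one can exceed 0.5.
softmax_probs = torch.softmax(logits, dim=-1)

# Sigmoid scores each label independently, so several labels can
# clear the 0.5 decision threshold at the same time.
sigmoid_probs = torch.sigmoid(logits)
```

With these logits, softmax leaves only the first label above 0.5, while sigmoid puts both of the first two labels above 0.5, which is the behavior a multi-label classifier needs.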

Additionally, I noticed that the multi-label explainer makes one prediction per label, even though a multi-label classifier returns the logits for all labels in a single prediction. Calling predict with the same input 6 times in a row produces 6 identical sets of logits, so I am wondering whether this is redundant. I am aware that some other explainers, such as eli5, must do this to accurately compute attributions for several classes, but I wasn't sure whether the repetition here exists for the same reason.
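The redundancy claim is easy to demonstrate with a toy model (a stand-in linear head, not the package's actual classifier): one forward pass already yields the logits for every label, and repeated calls with the same input return identical tensors:

```python
import torch

torch.manual_seed(0)
# Stand-in for a classifier head: one linear layer mapping a
# 10-dim feature vector to logits for 6 labels (illustrative only).
model = torch.nn.Linear(10, 6)
x = torch.randn(1, 10)

with torch.no_grad():
    logits_once = model(x)                   # one pass, all 6 labels
    repeated = [model(x) for _ in range(6)]  # same call, 6 times

# Every repeated call returns logits identical to the first one.
all_identical = all(torch.equal(logits_once, r) for r in repeated)
```

Whether the explainer can actually cache this single pass depends on how its attribution backend targets each label, which is the open question above.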

I modified the explainer to run 6 times using the sigmoid function built into PyTorch instead of the softmax function, and it works as expected. If this is a bug, I can submit a pull request with the changes I have made.
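For reference, the shape of the change is roughly the following (a hedged sketch; `forward_for_label` and its signature are illustrative, not the package's actual internals):

```python
import torch

# Sketch of the modification: in the forward function used for
# attribution, take the score for `label_idx` from a sigmoid over
# the logits instead of from a softmax.
def forward_for_label(logits: torch.Tensor, label_idx: int) -> torch.Tensor:
    return torch.sigmoid(logits)[:, label_idx]

logits = torch.tensor([[3.0, 2.5, -4.0]])
score = forward_for_label(logits, 1)  # sigmoid(2.5)
```

Because each label's score is now independent of the others, attributions for one label no longer shift when an unrelated label's logit changes, which matches multi-label semantics.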
