5 changes: 4 additions & 1 deletion .gitignore
@@ -128,4 +128,7 @@ test.html
 api.py
 .vscode
 notebooks/explore_captum.ipynb
-*.html
+*.html
+
+# Pycharm
+.idea/
40 changes: 38 additions & 2 deletions test/test_sequence_classification_explainer.py
@@ -1,5 +1,5 @@
 import pytest
-from IPython.core.display import HTML
+
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers_interpret import SequenceClassificationExplainer
 from transformers_interpret.errors import (
@@ -37,7 +37,7 @@ def test_sequence_classification_explainer_init_bert():
     assert seq_explainer.attribution_type == "lig"
     assert seq_explainer.label2id == BERT_MODEL.config.label2id
     assert seq_explainer.id2label == BERT_MODEL.config.id2label
-    assert seq_explainer.attributions == None
+    assert seq_explainer.attributions is None


 def test_sequence_classification_explainer_init_attribution_type_error():
@@ -49,6 +49,26 @@ def test_sequence_classification_explainer_init_attribution_type_error():
     )


+def test_sequence_classification_explainer_init_with_custom_labels():
+    labels = ["label_1", "label_2"]
+    seq_explainer = SequenceClassificationExplainer(
+        DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=labels
+    )
+    assert len(labels) == len(seq_explainer.id2label)
+    assert len(labels) == len(seq_explainer.label2id)
+    for (k1, v1), (k2, v2) in zip(
+        seq_explainer.id2label.items(), seq_explainer.label2id.items()
+    ):
+        assert v1 in labels and k2 in labels
+
+
+def test_sequence_classification_explainer_init_custom_labels_size_error():
+    with pytest.raises(ValueError):
+        SequenceClassificationExplainer(
+            DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["few_labels"]
+        )
+
+
 def test_sequence_classification_explainer_attribution_type_unset_before_run():
     explainer_string = "I love you , I like you"

@@ -185,6 +205,22 @@ def test_sequence_classification_explain_on_cls_name():
     assert seq_explainer.predicted_class_name == "POSITIVE"


+def test_sequence_classification_explain_on_cls_name_with_custom_labels():
+    explainer_string = "I love you , I like you"
+    seq_explainer = SequenceClassificationExplainer(
+        DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["sad", "happy"]
+    )
+    seq_explainer._run(explainer_string, class_name="sad")
+    assert seq_explainer.predicted_class_index == 1
+    assert seq_explainer.predicted_class_index != seq_explainer.selected_index
+    assert (
+        seq_explainer.predicted_class_name
+        != seq_explainer.id2label[seq_explainer.selected_index]
+    )
+    assert seq_explainer.predicted_class_name != "sad"
+    assert seq_explainer.predicted_class_name == "happy"
+
+
 def test_sequence_classification_explain_on_cls_name_not_in_dict():
     explainer_string = "I love you , I like you"
     seq_explainer = SequenceClassificationExplainer(
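For reference, the behaviour these tests pin down looks like the sketch below when driven by hand. The checkpoint name is an assumption standing in for the test module's DISTILBERT_* fixtures, which this diff does not show:

# Hypothetical stand-in for the DISTILBERT_* fixtures used in the
# tests above; the checkpoint name is an assumption, not part of this PR.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# custom_labels renames the checkpoint's two classes index-wise,
# so id 0 -> "sad" and id 1 -> "happy".
explainer = SequenceClassificationExplainer(
    model, tokenizer, custom_labels=["sad", "happy"]
)
print(explainer.id2label)  # {0: 'sad', 1: 'happy'}
print(explainer.label2id)  # {'sad': 0, 'happy': 1}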
4 changes: 3 additions & 1 deletion transformers_interpret/explainer.py
@@ -200,7 +200,9 @@ def _set_available_embedding_types(self):
         if hasattr(self.model_embeddings, "position_embeddings"):
             self.position_embeddings = self.model_embeddings.position_embeddings
         if hasattr(self.model_embeddings, "token_type_embeddings"):
-            self.token_type_embeddings = self.model_embeddings.token_type_embeddings
+            self.token_type_embeddings = (
+                self.model_embeddings.token_type_embeddings
+            )

     def __str__(self):
         s = f"{self.__class__.__name__}("
35 changes: 30 additions & 5 deletions transformers_interpret/explainers/sequence_classification.py
@@ -1,8 +1,6 @@
 import warnings
-from enum import Enum
-from typing import Union
+from typing import Dict, List, Optional, Tuple, Union

-import captum
 import torch
 from captum.attr import visualization as viz
 from torch.nn.modules.sparse import Embedding
@@ -50,12 +48,16 @@ def __init__(
         model: PreTrainedModel,
         tokenizer: PreTrainedTokenizer,
         attribution_type: str = "lig",
+        custom_labels: Optional[List[str]] = None,
     ):
         """
         Args:
             model (PreTrainedModel): Pretrained huggingface Sequence Classification model.
             tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
             attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".
+            custom_labels (List[str], optional): Applies custom labels to label2id and id2label configs.
+                Labels must be the same length as the base model config's labels.
+                Labels and ids are applied index-wise. Defaults to None.

         Raises:
             AttributionTypeNotSupportedError:
@@ -68,14 +70,37 @@ def __init__(
             )
         self.attribution_type = attribution_type

-        self.label2id = model.config.label2id
-        self.id2label = model.config.id2label
+        if custom_labels is not None:
+            if len(custom_labels) != len(model.config.label2id):
+                raise ValueError(
+                    f"""`custom_labels` size '{len(custom_labels)}' should match pretrained model's label2id size
+                        '{len(model.config.label2id)}'"""
+                )
+
+            self.id2label, self.label2id = self._get_id2label_and_label2id_dict(
+                custom_labels
+            )
+        else:
+            self.label2id = model.config.label2id
+            self.id2label = model.config.id2label

         self.attributions: Union[None, LIGAttributions] = None
         self.input_ids: torch.Tensor = torch.Tensor()

         self._single_node_output = False

+    @staticmethod
+    def _get_id2label_and_label2id_dict(
+        labels: List[str],
+    ) -> Tuple[Dict[int, str], Dict[str, int]]:
+        id2label: Dict[int, str] = dict()
+        label2id: Dict[str, int] = dict()
+        for idx, label in enumerate(labels):
+            id2label[idx] = label
+            label2id[label] = idx
+
+        return id2label, label2id
+
     def encode(self, text: str = None) -> list:
         return self.tokenizer.encode(text, add_special_tokens=False)
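As a sanity check on the index-wise mapping, the new static helper can also be exercised directly; this is a minimal sketch whose expected values follow from the enumerate() loop in the diff above:

# Minimal check of the index-wise mapping built by
# _get_id2label_and_label2id_dict (a staticmethod, so no
# instance is needed); expected values follow from the diff above.
from transformers_interpret import SequenceClassificationExplainer

id2label, label2id = SequenceClassificationExplainer._get_id2label_and_label2id_dict(
    ["sad", "happy"]
)
assert id2label == {0: "sad", 1: "happy"}
assert label2id == {"sad": 0, "happy": 1}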