15 changes: 13 additions & 2 deletions farm/data_handler/input_features.py
@@ -369,7 +369,7 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
return [feature_dict]


def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks_mid,
def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks_mid, sp_toks_end,
answer_type_list=None, max_answers=6):
""" Prepares data for processing by the model. Supports cases where there are
multiple answers for the one question/document pair. max_answers is by default set to 6 since
@@ -472,6 +472,15 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks
# tokens are attended to.
padding_mask = [1] * len(input_ids)

# The span mask has 1 for tokens that are valid starts or ends of QA spans.
# 0s are assigned to question tokens, mid special tokens, end special tokens and padding.
# Note that start special tokens are assigned 1 since they can be chosen for a no_answer prediction.
span_mask = [1] * sp_toks_start
span_mask += [0] * question_len_t
span_mask += [0] * sp_toks_mid
span_mask += [1] * passage_len_t
span_mask += [0] * sp_toks_end

# Pad up to the sequence length. For certain models, the pad token id is not 0 (e.g. Roberta where it is 1)
pad_idx = tokenizer.pad_token_id
padding = [pad_idx] * (max_seq_len - len(input_ids))
@@ -481,6 +490,7 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks
padding_mask += zero_padding
segment_ids += zero_padding
start_of_word += zero_padding
span_mask += zero_padding

# The XLM-Roberta tokenizer generates a segment_ids vector that separates the first sequence from the second.
# However, when this is passed in to the forward fn of the Roberta model, it throws an error since
@@ -500,7 +510,8 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks
"start_of_word": start_of_word,
"labels": labels,
"id": sample_id,
"seq_2_start_t": seq_2_start_t}
"seq_2_start_t": seq_2_start_t,
"span_mask": span_mask}
return [feature_dict]
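For illustration, here is a minimal standalone sketch of how the new span_mask lines up with the rest of the input sequence. It is not part of the diff; the token counts (one start token, two mid tokens, one end token, a 4-token question, an 8-token passage, max_seq_len 20) are invented and simply mimic a RoBERTa-style layout.

sp_toks_start, sp_toks_mid, sp_toks_end = 1, 2, 1   # hypothetical special-token counts, e.g. <s> / </s></s> / </s>
question_len_t, passage_len_t = 4, 8                # hypothetical token counts
max_seq_len = 20

# Same layout as in sample_to_features_qa: 1 marks positions that may be picked
# as answer starts/ends (start special token + passage tokens); 0 marks question
# tokens, mid/end special tokens and padding.
span_mask = [1] * sp_toks_start
span_mask += [0] * question_len_t
span_mask += [0] * sp_toks_mid
span_mask += [1] * passage_len_t
span_mask += [0] * sp_toks_end
span_mask += [0] * (max_seq_len - len(span_mask))   # zero padding

assert len(span_mask) == max_seq_len
print(span_mask)
# [1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]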


2 changes: 2 additions & 0 deletions farm/data_handler/processor.py
@@ -1268,6 +1268,7 @@ def _sample_to_features(self, sample) -> dict:
max_seq_len=self.max_seq_len,
sp_toks_start=self.sp_toks_start,
sp_toks_mid=self.sp_toks_mid,
sp_toks_end=self.sp_toks_end,
max_answers=self.max_answers)
return features

@@ -1572,6 +1573,7 @@ def _sample_to_features(self, sample: Sample) -> dict:
max_seq_len=self.max_seq_len,
sp_toks_start=self.sp_toks_start,
sp_toks_mid=self.sp_toks_mid,
sp_toks_end=self.sp_toks_end,
answer_type_list=self.answer_type_list,
max_answers=self.max_answers)
return features
96 changes: 30 additions & 66 deletions farm/modeling/prediction_head.py
@@ -1056,7 +1056,8 @@ def logits_to_loss(self, logits, labels, **kwargs):
per_sample_loss = (start_loss + end_loss) / 2
return per_sample_loss

def logits_to_preds(self, logits, padding_mask, start_of_word, seq_2_start_t, max_answer_length=1000, **kwargs):
def logits_to_preds(self, logits, span_mask, start_of_word,
seq_2_start_t, max_answer_length=1000, **kwargs):
"""
Get the predicted index of start and end token of the answer. Note that the output is at token level
and not word level. Note also that these logits correspond to the tokens of a sample
@@ -1077,7 +1078,6 @@ def logits_to_preds(self, logits, padding_mask, start_of_word, seq_2_start_t, ma
# Calculate a few useful variables
batch_size = start_logits.size()[0]
max_seq_len = start_logits.shape[1] # target dim
n_non_padding = torch.sum(padding_mask, dim=1)

# get scores for all combinations of start and end logits => candidate answers
start_matrix = start_logits.unsqueeze(2).expand(-1, -1, max_seq_len)
@@ -1087,20 +1087,26 @@
# disqualify answers where end < start
# (set the lower triangular matrix to low value, excluding diagonal)
indices = torch.tril_indices(max_seq_len, max_seq_len, offset=-1, device=start_end_matrix.device)
start_end_matrix[:, indices[0][:], indices[1][:]] = -999
start_end_matrix[:, indices[0][:], indices[1][:]] = -888

# disqualify answers where start=0, but end != 0
start_end_matrix[:, 0, 1:] = -999

# TODO continue vectorization of valid_answer_idxs
# # disqualify where answers < seq_2_start_t and idx != 0
# # disqualify where answer falls into padding
# # seq_2_start_t can be different when 2 different questions are handled within one batch
# # n_non_padding can be different on sample level, too
# for i in range(batch_size):
# start_end_matrix[i, 1:seq_2_start_t[i], 1:seq_2_start_t[i]] = -888
# start_end_matrix[i, n_non_padding[i]-1:, n_non_padding[i]-1:] = -777
# disqualify answers where answer span is greater than max_answer_length
# (set the upper triangular matrix to low value, excluding diagonal)
indices_long_span = torch.triu_indices(max_seq_len, max_seq_len, offset=max_answer_length, device=start_end_matrix.device)
start_end_matrix[:, indices_long_span[0][:], indices_long_span[1][:]] = -777

# disqualify answers where start=0, but end != 0
start_end_matrix[:, 0, 1:] = -666

# Turn 1d span_mask vectors into 2d span_mask along 2 different axes
# span mask has:
# 0 for every position that is never a valid start or end index (question tokens, mid and end special tokens, padding)
# 1 everywhere else
span_mask_start = span_mask.unsqueeze(2).expand(-1, -1, max_seq_len)
span_mask_end = span_mask.unsqueeze(1).expand(-1, max_seq_len, -1)
span_mask_2d = span_mask_start + span_mask_end
# disqualify spans where either start or end is on an invalid token
invalid_indices = torch.nonzero((span_mask_2d != 2), as_tuple=True)
start_end_matrix[invalid_indices[0][:], invalid_indices[1][:], invalid_indices[2][:]] = -999

# Sort the candidate answers by their score. Sorting happens on the flattened matrix.
# flat_sorted_indices.shape: (batch_size, max_seq_len^2, 1)
@@ -1114,20 +1120,16 @@
end_indices = flat_sorted_indices % max_seq_len
sorted_candidates = torch.cat((start_indices, end_indices), dim=2)

# Get the n_best candidate answers for each sample that are valid (via some heuristic checks)
# Get the n_best candidate answers for each sample
for sample_idx in range(batch_size):
sample_top_n = self.get_top_candidates(sorted_candidates[sample_idx],
start_end_matrix[sample_idx],
n_non_padding[sample_idx].item(),
max_answer_length,
seq_2_start_t[sample_idx].item(),
sample_idx)
all_top_n.append(sample_top_n)

return all_top_n

def get_top_candidates(self, sorted_candidates, start_end_matrix,
n_non_padding, max_answer_length, seq_2_start_t, sample_idx):
def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx):
""" Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
@@ -1147,16 +1149,14 @@
# Ignore no_answer scores which will be extracted later in this method
if start_idx == 0 and end_idx == 0:
continue
# Check that the candidate's indices are valid and save them if they are
if self.valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_2_start_t):
score = start_end_matrix[start_idx, end_idx].item()
top_candidates.append(QACandidate(offset_answer_start=start_idx,
offset_answer_end=end_idx,
score=score,
answer_type="span",
offset_unit="token",
aggregation_level="passage",
passage_id=sample_idx))
score = start_end_matrix[start_idx, end_idx].item()
top_candidates.append(QACandidate(offset_answer_start=start_idx,
offset_answer_end=end_idx,
score=score,
answer_type="span",
offset_unit="token",
aggregation_level="passage",
passage_id=sample_idx))

no_answer_score = start_end_matrix[0, 0].item()
top_candidates.append(QACandidate(offset_answer_start=0,
@@ -1169,42 +1169,6 @@

return top_candidates

@staticmethod
def valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_2_start_t):
""" Returns True if the supplied index span is a valid prediction. The indices being provided
should be on sample/passage level (special tokens + question_tokens + passage_tokens)
and not document level"""

# This function can seriously slow down inferencing and eval. In the future this function will be completely vectorized
# Continue if start or end label points to a padding token
if start_idx < seq_2_start_t and start_idx != 0:
return False
if end_idx < seq_2_start_t and end_idx != 0:
return False
# The -1 is to stop the idx falling on a final special token
# TODO: this makes the assumption that there is a special token that comes at the end of the second sequence
if start_idx >= n_non_padding - 1:
return False
if end_idx >= n_non_padding - 1:
return False

# # Check if start comes after end
# # Handled on matrix level by: start_end_matrix[:, indices[0][1:], indices[1][1:]] = -999
# if end_idx < start_idx:
# return False

# # If one of the two indices is 0, the other must also be 0
# # Handled on matrix level by setting: start_end_matrix[:, 0, 1:] = -999
# if start_idx == 0 and end_idx != 0:
# return False
# if start_idx != 0 and end_idx == 0:
# return False

length = end_idx - start_idx + 1
if length > max_answer_length:
return False
return True

def formatted_preds(self, logits=None, preds=None, baskets=None, **kwargs):
""" Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
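The vectorized filtering introduced in logits_to_preds above replaces the per-candidate valid_answer_idxs checks that this PR removes. Below is a small self-contained sketch of the same masking pipeline on a toy batch; it is written for this write-up rather than taken from FARM, the logits are random, and the sentinel values -888/-777/-666/-999 simply mirror the ones used in the diff.

import torch

batch_size, max_seq_len, max_answer_length = 1, 6, 3
start_logits = torch.randn(batch_size, max_seq_len)
end_logits = torch.randn(batch_size, max_seq_len)
# Toy span_mask: start special token = 1, two question tokens = 0,
# two passage tokens = 1, end special token = 0
span_mask = torch.tensor([[1, 0, 0, 1, 1, 0]])

# Score every (start, end) combination
start_matrix = start_logits.unsqueeze(2).expand(-1, -1, max_seq_len)
end_matrix = end_logits.unsqueeze(1).expand(-1, max_seq_len, -1)
start_end_matrix = start_matrix + end_matrix

# Disqualify end < start (strict lower triangle)
tril = torch.tril_indices(max_seq_len, max_seq_len, offset=-1)
start_end_matrix[:, tril[0], tril[1]] = -888

# Disqualify spans longer than max_answer_length (upper triangle beyond the limit)
triu = torch.triu_indices(max_seq_len, max_seq_len, offset=max_answer_length)
start_end_matrix[:, triu[0], triu[1]] = -777

# Disqualify start=0 with end != 0; only the (0, 0) cell may express no_answer
start_end_matrix[:, 0, 1:] = -666

# Disqualify spans whose start or end falls on an invalid token
span_mask_start = span_mask.unsqueeze(2).expand(-1, -1, max_seq_len)
span_mask_end = span_mask.unsqueeze(1).expand(-1, max_seq_len, -1)
invalid = torch.nonzero(span_mask_start + span_mask_end != 2, as_tuple=True)
start_end_matrix[invalid[0], invalid[1], invalid[2]] = -999

# The best remaining cell gives the top candidate span (flattened argmax)
best = torch.argmax(start_end_matrix.view(batch_size, -1), dim=1)
print(best // max_seq_len, best % max_seq_len)  # predicted (start, end) token indices

Expressing all four rules as tensor writes keeps candidate scoring fully batched, which is what allows the PR to drop the Python-level loop over heuristic validity checks.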
39 changes: 39 additions & 0 deletions test/benchmarks/question_answering_components.html
@@ -0,0 +1,39 @@

<html>
<head>
<script type="text/javascript" src="https://www.gstatic.com/charts/loader.js"></script>
<script type="text/javascript">
google.charts.load('current', {'packages':['bar']});
google.charts.setOnLoadCallback(drawChart);

function drawChart() {
var data = google.visualization.arrayToDataTable(
[
["Name", "preproc","language_model","prediction_head"],
['deepset/minilm-uncased-squad2', 12.277034912109375, 5.79623876953125, 1.5562604980468748], ['deepset/roberta-base-squad2', 12.380782958984376, 13.71148828125, 1.5372104492187502], ['deepset/bert-base-cased-squad2', 9.938722900390625, 15.864041992187499, 1.6085009765625005], ['deepset/bert-large-uncased-whole-word-masking-squad2', 9.692403808593749, 45.28969921875, 1.785435546875], ['deepset/xlm-roberta-large-squad2', 8.079997680664063, 48.489154296875, 1.974138671875]
]);

var options = {
chart: {
title: 'QA Model Speed Comparison',
subtitle: 'Time per Component',
},
bars: 'horizontal', // Required for Material Bar Charts.
isStacked: true,
height: 300,
legend: {position: 'top', maxLines: 3},
hAxis: {minValue: 0}

};

var chart = new google.charts.Bar(document.getElementById('barchart_material'));

chart.draw(data, google.charts.Bar.convertOptions(options));
}
</script>
</head>
<body>
<div id="barchart_material" style="width: 900px; height: 500px;"></div>
</body>
</html>

6 changes: 4 additions & 2 deletions test/benchmarks/question_answering_components.py
@@ -10,16 +10,18 @@
from tqdm import tqdm
import logging
import json
from datetime import date

logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)

task_type = "question_answering"
sample_file = "samples/question_answering_sample.txt"
questions_file = "samples/question_answering_questions.txt"
num_processes = 1
passages_per_char = 2400 / 1000000 # numerator is the number of passages produced when 1 million chars are paired with one of the questions (max_seq_len 384, doc stride 128)
output_file = "results_component_test_24_09_20.csv"
date_str = date.today().strftime("%d_%m_%Y")
output_file = f"results_component_test_{date_str}.csv"

params = {
"modelname": ["deepset/bert-base-cased-squad2", "deepset/minilm-uncased-squad2", "deepset/roberta-base-squad2", "deepset/bert-large-uncased-whole-word-masking-squad2", "deepset/xlm-roberta-large-squad2"],
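As a quick sanity check of the passages_per_char constant in the benchmark script above (the 2,400-passages-per-million-characters figure comes from the comment in the script; the corpus size below is invented for illustration):

passages_per_char = 2400 / 1000000       # passages generated per input character
corpus_chars = 5_000_000                 # hypothetical corpus size
print(corpus_chars * passages_per_char)  # 12000.0 passages expected for one question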
3 changes: 2 additions & 1 deletion test/test_input_features.py
@@ -9,6 +9,7 @@
MODEL = "roberta-base"
SP_TOKENS_START = 1
SP_TOKENS_MID = 2
SP_TOKENS_END = 1

def to_list(x):
try:
@@ -32,7 +33,7 @@ def test_sample_to_features_qa(caplog):
curr_id = "-".join([str(x) for x in features_gold["id"]])

s = Sample(id=curr_id, clear_text=clear_text, tokenized=tokenized)
features = sample_to_features_qa(s, tokenizer, max_seq_len, SP_TOKENS_START, SP_TOKENS_MID)[0]
features = sample_to_features_qa(s, tokenizer, max_seq_len, SP_TOKENS_START, SP_TOKENS_MID, SP_TOKENS_END)[0]
features = to_list(features)

keys = features_gold.keys()