This repository was archived by the owner on Apr 8, 2025. It is now read-only.

Commit d3eef8c
Adding probability of masking a token parameter for LM task (#630)
1 parent: 74aa08d
3 files changed: 14 additions & 7 deletions

farm/data_handler/input_features.py (6 additions & 4 deletions)
@@ -245,7 +245,7 @@ def samples_to_features_ner(
     return [feature_dict]
 
 
-def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=True):
+def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=True, masked_lm_prob=0.15):
     """
     Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
     IDs, LM labels, padding_mask, CLS and SEP tokens etc.
@@ -255,6 +255,8 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
     :param max_seq_len: Maximum length of sequence.
     :type max_seq_len: int
     :param tokenizer: Tokenizer
+    :param masked_lm_prob: probability of masking a token
+    :type masked_lm_prob: float
     :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
     """
 
@@ -264,10 +266,10 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
 
         # mask random words
         tokens_a, t1_label = mask_random_words(tokens_a, tokenizer.vocab,
-                                               token_groups=sample.tokenized["text_a"]["start_of_word"])
+                                               token_groups=sample.tokenized["text_a"]["start_of_word"], masked_lm_prob=masked_lm_prob)
 
         tokens_b, t2_label = mask_random_words(tokens_b, tokenizer.vocab,
-                                               token_groups=sample.tokenized["text_b"]["start_of_word"])
+                                               token_groups=sample.tokenized["text_b"]["start_of_word"], masked_lm_prob=masked_lm_prob)
 
         if tokenizer.is_fast:
             # Detokenize input as fast tokenizer can't handle tokenized input
@@ -290,7 +292,7 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
         tokens_a = sample.tokenized["text_a"]["tokens"]
         tokens_b = None
         tokens_a, t1_label = mask_random_words(tokens_a, tokenizer.vocab,
-                                               token_groups=sample.tokenized["text_a"]["start_of_word"])
+                                               token_groups=sample.tokenized["text_a"]["start_of_word"], masked_lm_prob=masked_lm_prob)
         if tokenizer.is_fast:
             # Detokenize input as fast tokenizer can't handle tokenized input
             tokens_a = " ".join(tokens_a)

farm/data_handler/processor.py (7 additions & 2 deletions)
@@ -845,6 +845,8 @@ def __init__(
         next_sent_pred_style="sentence",
         max_docs=None,
         proxies=None,
+        masked_lm_prob=0.15,
+
         **kwargs
     ):
         """
@@ -885,6 +887,8 @@ def __init__(
         :param proxies: proxy configuration to allow downloads of remote datasets.
                         Format as in "requests" library: https://2.python-requests.org//en/latest/user/advanced/#proxies
         :type proxies: dict
+        :param masked_lm_prob: probability of masking a token
+        :type masked_lm_prob: float
         :param kwargs: placeholder for passing generic parameters
         :type kwargs: object
         """
@@ -910,6 +914,8 @@ def __init__(
         self.add_task("lm", "acc", list(self.tokenizer.vocab) + added_tokens)
         if self.next_sent_pred:
             self.add_task("nextsentence", "acc", ["False", "True"])
+        self.masked_lm_prob = masked_lm_prob
+
 
     def get_added_tokens(self):
         dictionary = self.tokenizer.added_tokens_encoder
@@ -1064,7 +1070,7 @@ def _dict_to_samples_no_next_sent(self, doc):
     def _sample_to_features(self, sample) -> dict:
         features = samples_to_features_bert_lm(
             sample=sample, max_seq_len=self.max_seq_len, tokenizer=self.tokenizer,
-            next_sent_pred=self.next_sent_pred
+            next_sent_pred=self.next_sent_pred, masked_lm_prob=self.masked_lm_prob
         )
         return features
 
@@ -1205,7 +1211,6 @@ def __init__(
             tasks={},
             proxies=proxies
         )
-
        if metric and label_list:
            self.add_task("question_answering", metric, label_list)
        else:
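
Judging by the __init__ signature being patched (next_sent_pred_style, max_docs, proxies), the processor changed here appears to be FARM's BertStyleLMProcessor; the stored self.masked_lm_prob is then threaded through _sample_to_features into samples_to_features_bert_lm. Assuming that class, a caller could now tune the masking rate roughly as below; the data_dir path and parameter values are illustrative, not taken from this commit.

from farm.modeling.tokenization import Tokenizer
from farm.data_handler.processor import BertStyleLMProcessor

tokenizer = Tokenizer.load(pretrained_model_name_or_path="bert-base-uncased")

# Mask 20% of tokens instead of the 0.15 default.
processor = BertStyleLMProcessor(
    tokenizer=tokenizer,
    max_seq_len=128,
    data_dir="data/lm_finetune_sample",  # illustrative path
    masked_lm_prob=0.2,
)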

test/test_lm_finetuning.py (1 addition & 1 deletion)
@@ -251,4 +251,4 @@ def test_lm_finetuning_custom_vocab(caplog):
     assert isinstance(result[0]["vec"][0], np.float32)
 
 if(__name__=="__main__"):
-    test_lm_finetuning()
+    test_lm_finetuning()
