@@ -245,7 +245,7 @@ def samples_to_features_ner(
     return [feature_dict]


-def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=True):
+def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=True, masked_lm_prob=0.15):
     """
     Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
     IDs, LM labels, padding_mask, CLS and SEP tokens etc.
@@ -255,6 +255,8 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
     :param max_seq_len: Maximum length of sequence.
     :type max_seq_len: int
     :param tokenizer: Tokenizer
+    :param masked_lm_prob: Probability of masking a token.
+    :type masked_lm_prob: float
     :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
     """
 
@@ -264,10 +266,10 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
 
         # mask random words
         tokens_a, t1_label = mask_random_words(tokens_a, tokenizer.vocab,
-                                               token_groups=sample.tokenized["text_a"]["start_of_word"])
+                                               token_groups=sample.tokenized["text_a"]["start_of_word"], masked_lm_prob=masked_lm_prob)
 
         tokens_b, t2_label = mask_random_words(tokens_b, tokenizer.vocab,
-                                               token_groups=sample.tokenized["text_b"]["start_of_word"])
+                                               token_groups=sample.tokenized["text_b"]["start_of_word"], masked_lm_prob=masked_lm_prob)
 
         if tokenizer.is_fast:
             # Detokenize input as fast tokenizer can't handle tokenized input
@@ -290,7 +292,7 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
         tokens_a = sample.tokenized["text_a"]["tokens"]
         tokens_b = None
         tokens_a, t1_label = mask_random_words(tokens_a, tokenizer.vocab,
-                                               token_groups=sample.tokenized["text_a"]["start_of_word"])
+                                               token_groups=sample.tokenized["text_a"]["start_of_word"], masked_lm_prob=masked_lm_prob)
         if tokenizer.is_fast:
             # Detokenize input as fast tokenizer can't handle tokenized input
             tokens_a = " ".join(tokens_a)
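
For context on what the new knob controls: masked_lm_prob is the per-word probability that mask_random_words hides a token before masked-LM training (0.15 is BERT's classic default). The sketch below is a minimal, self-contained illustration of whole-word masking with a configurable probability, assuming BERT's standard 80/10/10 replacement scheme; mask_random_words_sketch and the toy tokens/vocab are hypothetical stand-ins, not necessarily this library's actual implementation.

import random

def mask_random_words_sketch(tokens, vocab, token_groups=None, masked_lm_prob=0.15):
    """Illustrative whole-word masking with a configurable probability.

    Each whole word (delimited by token_groups, a list of booleans marking
    which tokens start a word) is selected with probability masked_lm_prob.
    Selected tokens follow BERT's 80/10/10 rule: 80% become [MASK], 10% are
    swapped for a random vocab token, 10% stay unchanged. The returned labels
    hold the original token at masked positions and "" everywhere else.
    """
    vocab_tokens = list(vocab)
    output_tokens = list(tokens)
    labels = [""] * len(tokens)

    # Group subword tokens into whole words so a word is masked as a unit.
    words = []
    for i in range(len(tokens)):
        if token_groups is None or token_groups[i] or not words:
            words.append([i])
        else:
            words[-1].append(i)

    for word in words:
        if random.random() < masked_lm_prob:
            for i in word:
                labels[i] = tokens[i]
                dice = random.random()
                if dice < 0.8:
                    output_tokens[i] = "[MASK]"
                elif dice < 0.9:
                    output_tokens[i] = random.choice(vocab_tokens)
                # else: keep the original token (still predicted via labels)
    return output_tokens, labels

# Hypothetical usage: raise the masking probability from 0.15 to 0.3.
tokens = ["un", "##believ", "##able", "results", "today"]
starts = [True, False, False, True, True]
masked, labels = mask_random_words_sketch(
    tokens, vocab=["the", "cat", "runs"], token_groups=starts, masked_lm_prob=0.3
)
print(masked, labels)

Grouping subwords via token_groups mirrors the start_of_word information passed in the diff above: a word is masked as a unit, so unmasked sibling subwords cannot leak the answer.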