Skip to content
This repository was archived by the owner on Apr 8, 2025. It is now read-only.

Commit 2fabc31

Browse files
authored
Add MultiGPU support for DPR Training via DDP (#619)
* WIP initial global sync for loss * rename vars * wip ddp * fix gathering of tensors for DDP * fix vocab_size check. fix example script for DDP. fix check of rank in PH. * fix typo. fix deprecation warning
1 parent 9a910ff commit 2fabc31

7 files changed

Lines changed: 166 additions & 24 deletions

File tree

examples/dpr_encoder.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import os
44
import pprint
55
from pathlib import Path
6+
import argparse
7+
68

79
from farm.data_handler.data_silo import DataSilo
810
from farm.data_handler.processor import TextSimilarityProcessor
@@ -15,6 +17,16 @@
1517
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings
1618
from farm.eval import Evaluator
1719

20+
def parse_arguments():
21+
parser = argparse.ArgumentParser()
22+
parser.add_argument("--local_rank",
23+
type=int,
24+
default=-1,
25+
help="local_rank for distributed training on GPUs")
26+
args = parser.parse_args()
27+
return args
28+
29+
1830
def dense_passage_retrieval():
1931
logging.basicConfig(
2032
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -29,9 +41,9 @@ def dense_passage_retrieval():
2941
########## Settings
3042
##########################
3143
set_all_seeds(seed=42)
32-
device, n_gpu = initialize_device_settings(use_cuda=True)
33-
batch_size = 2
44+
batch_size = 4
3445
n_epochs = 3
46+
distributed = False # enable for multi GPU training via DDP
3547
evaluate_every = 1000
3648
question_lang_model = "facebook/dpr-question_encoder-single-nq-base"
3749
passage_lang_model = "facebook/dpr-ctx_encoder-single-nq-base"
@@ -43,7 +55,11 @@ def dense_passage_retrieval():
4355
train_filename = "nq-train.json"
4456
dev_filename = "nq-dev.json"
4557
test_filename = "nq-dev.json"
46-
max_samples = None #load a smaller dataset (e.g. for debugging)
58+
max_samples = None # load a smaller dataset (e.g. for debugging)
59+
60+
# For multi GPU Training via DDP we need to get the local rank
61+
args = parse_arguments()
62+
device, n_gpu = initialize_device_settings(use_cuda=True, local_rank=args.local_rank)
4763

4864
# 1.Create question and passage tokenizers
4965
query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model,
@@ -58,11 +74,11 @@ def dense_passage_retrieval():
5874
metric = "text_similarity_metric"
5975
processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
6076
passage_tokenizer=passage_tokenizer,
61-
max_seq_len_query=256,
77+
max_seq_len_query=64,
6278
max_seq_len_passage=256,
6379
label_list=label_list,
6480
metric=metric,
65-
data_dir="data/retriever",
81+
data_dir="../data/retriever",
6682
train_filename=train_filename,
6783
dev_filename=dev_filename,
6884
test_filename=test_filename,
@@ -72,7 +88,7 @@ def dense_passage_retrieval():
7288

7389
# 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
7490
# NOTE: In FARM, the dev set metrics differ from test set metrics in that they are calculated on a token level instead of a word level
75-
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
91+
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=distributed)
7692

7793

7894
# 4. Create an BiAdaptiveModel+
@@ -104,7 +120,8 @@ def dense_passage_retrieval():
104120
n_batches=len(data_silo.loaders["train"]),
105121
n_epochs=n_epochs,
106122
grad_acc_steps=1,
107-
device=device
123+
device=device,
124+
distributed=distributed
108125
)
109126

110127
# 6. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time

farm/data_handler/data_silo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def __init__(
5555
:type batch_size: int
5656
:param eval_batch_size: The size of batch that should be returned by the DataLoaders for the dev and test set.
5757
:type eval_batch_size: int
58-
:param distributed: Set to True if the program is running in a distributed setting.
58+
:param distributed: Set to True if you are running in a distributed env, e.g. using DistributedDataParallel.
59+
The DataSilo will init the DataLoader with a DistributedSampler() to distribute batches.
5960
:type distributed: bool
6061
:param automatic_loading: Set to False, if you don't want to automatically load data at initialization.
6162
:type automatic_loading: bool

farm/evaluation/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def text_similarity_avg_ranks(preds, labels):
240240
241241
:param preds: list of numpy arrays of dimension n1 x n2 containing n2 predicted ranks for n1 sequences/queries
242242
:type preds: List of numpy array containing similarity scores for each sequence in batch
243-
:param labels: list of arrays of dimension n1 x n2 where each array contains n2 labels(0/1) dindicating whether the sequence/passage is a positive(1) passage or hard_negative(0) passage
243+
:param labels: list of arrays of dimension n1 x n2 where each array contains n2 labels(0/1) indicating whether the sequence/passage is a positive(1) passage or hard_negative(0) passage
244244
:type labels: List of list containing values(0/1)
245245
246246
:return: average predicted ranks of positive sequence/passage for each sample/query

farm/modeling/biadaptive_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def forward(self, **kwargs):
353353
:return: all logits as torch.tensor or multiple tensors.
354354
"""
355355

356-
# Run forward pass of language model
356+
# Run forward pass of both language models
357357
pooled_output = self.forward_lm(**kwargs)
358358

359359
# Run forward pass of (multiple) prediction heads using the output from above

farm/modeling/prediction_head.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
import torch
1212
from torch import nn
1313
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, NLLLoss
14-
14+
from torch.distributed import all_gather
1515
from farm.data_handler.utils import is_json
16-
from farm.utils import convert_iob_to_simple_tags, try_get
16+
from farm.utils import convert_iob_to_simple_tags, try_get, all_gather_list
1717
from farm.modeling.predictions import QACandidate, QAPred
1818

1919
logger = logging.getLogger(__name__)
@@ -1524,15 +1524,26 @@ class TextSimilarityHead(PredictionHead):
15241524
"""
15251525
Trains a head on predicting the similarity of two texts like in Dense Passage Retrieval.
15261526
"""
1527-
def __init__(self, similarity_function="dot_product", **kwargs):
1527+
def __init__(self, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, **kwargs):
1528+
"""
1529+
Init the TextSimilarityHead.
1530+
1531+
:param similarity_function: Function to calculate similarity between queries and passage embeddings.
1532+
Choose either "dot_product" (Default) or "cosine".
1533+
:param global_loss_buffer_size: Buffer size for all_gather() in DDP.
1534+
Increase if errors like "encoded data exceeds max_size ..." come up
1535+
1536+
:param kwargs:
1537+
"""
1538+
15281539
super(TextSimilarityHead, self).__init__()
15291540

15301541
self.similarity_function = similarity_function
15311542
self.loss_fct = NLLLoss(reduction="mean")
15321543
self.task_name = "text_similarity"
15331544
self.model_type = "text_similarity"
15341545
self.ph_output_type = "per_sequence"
1535-
1546+
self.global_loss_buffer_size = global_loss_buffer_size
15361547
self.generate_config()
15371548

15381549
@classmethod
@@ -1627,15 +1638,56 @@ def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
16271638
16281639
:return: negative log likelihood loss from similarity scores
16291640
"""
1641+
1642+
# Check if DDP is initialized
1643+
try:
1644+
rank = torch.distributed.get_rank()
1645+
except AssertionError:
1646+
rank = -1
1647+
16301648
# Prepare predicted scores
16311649
query_vectors, passage_vectors = logits
1632-
softmax_scores = self._embeddings_to_scores(query_vectors, passage_vectors)
16331650

16341651
# Prepare Labels
16351652
lm_label_ids = kwargs.get(self.label_tensor_name)
16361653
positive_idx_per_question = torch.nonzero((lm_label_ids.view(-1) == 1), as_tuple=False)
1637-
#TODO gather global tensors from all nodes for DDP
1638-
global_positive_idx_per_question = positive_idx_per_question
1654+
1655+
# Gather global embeddings from all distributed nodes (DDP)
1656+
if rank != -1:
1657+
q_vector_to_send = torch.empty_like(query_vectors).cpu().copy_(query_vectors).detach_()
1658+
p_vector_to_send = torch.empty_like(passage_vectors).cpu().copy_(passage_vectors).detach_()
1659+
1660+
global_question_passage_vectors = all_gather_list(
1661+
[q_vector_to_send, p_vector_to_send, positive_idx_per_question],
1662+
max_size=self.global_loss_buffer_size)
1663+
1664+
global_query_vectors = []
1665+
global_passage_vectors = []
1666+
global_positive_idx_per_question = []
1667+
total_passages = 0
1668+
for i, item in enumerate(global_question_passage_vectors):
1669+
q_vector, p_vectors, positive_idx = item
1670+
1671+
if i != rank:
1672+
global_query_vectors.append(q_vector.to(query_vectors.device))
1673+
global_passage_vectors.append(p_vectors.to(passage_vectors.device))
1674+
global_positive_idx_per_question.extend([v + total_passages for v in positive_idx])
1675+
else:
1676+
global_query_vectors.append(query_vectors)
1677+
global_passage_vectors.append(passage_vectors)
1678+
global_positive_idx_per_question.extend([v + total_passages for v in positive_idx_per_question])
1679+
total_passages += p_vectors.size(0)
1680+
1681+
global_query_vectors = torch.cat(global_query_vectors, dim=0)
1682+
global_passage_vectors = torch.cat(global_passage_vectors, dim=0)
1683+
global_positive_idx_per_question = torch.LongTensor(global_positive_idx_per_question)
1684+
else:
1685+
global_query_vectors = query_vectors
1686+
global_passage_vectors = passage_vectors
1687+
global_positive_idx_per_question = positive_idx_per_question
1688+
1689+
# Get similarity scores
1690+
softmax_scores = self._embeddings_to_scores(global_query_vectors, global_passage_vectors)
16391691
targets = global_positive_idx_per_question.squeeze(-1).to(softmax_scores.device)
16401692

16411693
# Calculate loss
@@ -1664,7 +1716,9 @@ def prepare_labels(self, **kwargs):
16641716
"""
16651717
label_ids = kwargs.get(self.label_tensor_name)
16661718
labels = torch.zeros(label_ids.size(0), label_ids.numel())
1667-
positive_indices = (label_ids.view(-1) == 1).nonzero()
1719+
1720+
positive_indices = torch.nonzero(label_ids.view(-1) == 1, as_tuple=False)
1721+
16681722
for i, indx in enumerate(positive_indices):
16691723
labels[i, indx.item()] = 1
16701724
return labels

farm/train.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,8 @@ def train(self):
250250

251251
# connect the prediction heads with the right output from processor
252252
self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
253-
# Check that the tokenizer fits the language model
254-
#TODO: make this compliant for DP / DDP where the model class is wrapped
255-
if self.model._get_name() == 'BiAdaptiveModel':
253+
# Check that the tokenizer(s) fits the language model(s)
254+
if hasattr(self.model, "language_model2"):
256255
self.model.verify_vocab_size(vocab_size1=len(self.data_silo.processor.tokenizer),
257256
vocab_size2=len(self.data_silo.processor.passage_tokenizer))
258257
else:
@@ -297,7 +296,6 @@ def train(self):
297296

298297
# Move batch of samples to device
299298
batch = {key: batch[key].to(self.device) for key in batch}
300-
301299
# Forward & backward pass through model
302300
logits = self.model.forward(**batch)
303301
per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
@@ -367,7 +365,7 @@ def train(self):
367365
self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
368366

369367
# Eval on test set
370-
if self.evaluator_test:
368+
if self.evaluator_test and self.local_rank in [0, -1]:
371369
test_data_loader = self.data_silo.get_data_loader("test")
372370
if test_data_loader is not None:
373371
evaluator_test = Evaluator(

farm/utils.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import signal
77
import numpy as np
88
import torch
9+
import torch.distributed as dist
910
from requests.exceptions import ConnectionError
1011
from torch import multiprocessing as mp
1112
import mlflow
1213
from copy import deepcopy
1314
import pandas as pd
1415
from tqdm import tqdm
1516
import time
16-
17+
import pickle
1718

1819
from farm.visual.ascii.images import WELCOME_BARN, WORKER_M, WORKER_F, WORKER_X
1920

@@ -475,3 +476,74 @@ def calc_duration(self, start, end):
475476
return start.elapsed_time(end) / 1000
476477
else:
477478
return end - start
479+
480+
481+
# DDP utils
482+
483+
def all_reduce(tensor, group=None):
484+
if group is None:
485+
group = dist.group.WORLD
486+
return dist.all_reduce(tensor, group=group)
487+
488+
489+
def all_gather_list(data, group=None, max_size=16384):
490+
"""Gathers arbitrary data from all nodes into a list.
491+
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
492+
data. Note that *data* must be picklable.
493+
Args:
494+
data (Any): data from the local worker to be gathered on other workers
495+
group (optional): group of the collective
496+
"""
497+
SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size
498+
499+
enc = pickle.dumps(data)
500+
enc_size = len(enc)
501+
502+
if enc_size + SIZE_STORAGE_BYTES > max_size:
503+
raise ValueError(
504+
'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size))
505+
506+
rank = dist.get_rank()
507+
world_size = dist.get_world_size()
508+
buffer_size = max_size * world_size
509+
510+
if not hasattr(all_gather_list, '_buffer') or \
511+
all_gather_list._buffer.numel() < buffer_size:
512+
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
513+
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
514+
515+
buffer = all_gather_list._buffer
516+
buffer.zero_()
517+
cpu_buffer = all_gather_list._cpu_buffer
518+
519+
assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format(
520+
256 ** SIZE_STORAGE_BYTES)
521+
522+
size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big')
523+
524+
cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes))
525+
cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc))
526+
527+
start = rank * max_size
528+
size = enc_size + SIZE_STORAGE_BYTES
529+
buffer[start: start + size].copy_(cpu_buffer[:size])
530+
531+
all_reduce(buffer, group=group)
532+
533+
try:
534+
result = []
535+
for i in range(world_size):
536+
out_buffer = buffer[i * max_size: (i + 1) * max_size]
537+
size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big')
538+
if size > 0:
539+
result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist())))
540+
return result
541+
except pickle.UnpicklingError:
542+
raise Exception(
543+
'Unable to unpickle data from other workers. all_gather_list requires all '
544+
'workers to enter the function together, so this error usually indicates '
545+
'that the workers have fallen out of sync somehow. Workers can fall out of '
546+
'sync if one of them runs out of memory, or if there are other conditions '
547+
'in your training script that can cause one worker to finish an epoch '
548+
'while other workers are still iterating over their portions of the data.'
549+
)

0 commit comments

Comments
 (0)