|
11 | 11 | import torch |
12 | 12 | from torch import nn |
13 | 13 | from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, NLLLoss |
14 | | - |
| 14 | +from torch.distributed import all_gather |
15 | 15 | from farm.data_handler.utils import is_json |
16 | | -from farm.utils import convert_iob_to_simple_tags, try_get |
| 16 | +from farm.utils import convert_iob_to_simple_tags, try_get, all_gather_list |
17 | 17 | from farm.modeling.predictions import QACandidate, QAPred |
18 | 18 |
|
19 | 19 | logger = logging.getLogger(__name__) |
@@ -1524,15 +1524,26 @@ class TextSimilarityHead(PredictionHead): |
1524 | 1524 | """ |
1525 | 1525 | Trains a head on predicting the similarity of two texts like in Dense Passage Retrieval. |
1526 | 1526 | """ |
1527 | | - def __init__(self, similarity_function="dot_product", **kwargs): |
| 1527 | + def __init__(self, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, **kwargs): |
| 1528 | + """ |
| 1529 | + Init the TextSimilarityHead. |
| 1530 | +
|
| 1531 | + :param similarity_function: Function to calculate similarity between queries and passage embeddings. |
| 1532 | + Choose either "dot_product" (Default) or "cosine". |
| 1533 | + :param global_loss_buffer_size: Buffer size for all_gather() in DDP. |
| 1534 | + Increase if errors like "encoded data exceeds max_size ..." come up |
| 1535 | +
|
| 1536 | + :param kwargs: |
| 1537 | + """ |
| 1538 | + |
1528 | 1539 | super(TextSimilarityHead, self).__init__() |
1529 | 1540 |
|
1530 | 1541 | self.similarity_function = similarity_function |
1531 | 1542 | self.loss_fct = NLLLoss(reduction="mean") |
1532 | 1543 | self.task_name = "text_similarity" |
1533 | 1544 | self.model_type = "text_similarity" |
1534 | 1545 | self.ph_output_type = "per_sequence" |
1535 | | - |
| 1546 | + self.global_loss_buffer_size = global_loss_buffer_size |
1536 | 1547 | self.generate_config() |
1537 | 1548 |
|
1538 | 1549 | @classmethod |
@@ -1627,15 +1638,56 @@ def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs): |
1627 | 1638 |
|
1628 | 1639 | :return: negative log likelihood loss from similarity scores |
1629 | 1640 | """ |
| 1641 | + |
| 1642 | + # Check if DDP is initialized |
| 1643 | + try: |
| 1644 | + rank = torch.distributed.get_rank() |
| 1645 | + except AssertionError: |
| 1646 | + rank = -1 |
| 1647 | + |
1630 | 1648 | # Prepare predicted scores |
1631 | 1649 | query_vectors, passage_vectors = logits |
1632 | | - softmax_scores = self._embeddings_to_scores(query_vectors, passage_vectors) |
1633 | 1650 |
|
1634 | 1651 | # Prepare Labels |
1635 | 1652 | lm_label_ids = kwargs.get(self.label_tensor_name) |
1636 | 1653 | positive_idx_per_question = torch.nonzero((lm_label_ids.view(-1) == 1), as_tuple=False) |
1637 | | - #TODO gather global tensors from all nodes for DDP |
1638 | | - global_positive_idx_per_question = positive_idx_per_question |
| 1654 | + |
| 1655 | + # Gather global embeddings from all distributed nodes (DDP) |
| 1656 | + if rank != -1: |
| 1657 | + q_vector_to_send = torch.empty_like(query_vectors).cpu().copy_(query_vectors).detach_() |
| 1658 | + p_vector_to_send = torch.empty_like(passage_vectors).cpu().copy_(passage_vectors).detach_() |
| 1659 | + |
| 1660 | + global_question_passage_vectors = all_gather_list( |
| 1661 | + [q_vector_to_send, p_vector_to_send, positive_idx_per_question], |
| 1662 | + max_size=self.global_loss_buffer_size) |
| 1663 | + |
| 1664 | + global_query_vectors = [] |
| 1665 | + global_passage_vectors = [] |
| 1666 | + global_positive_idx_per_question = [] |
| 1667 | + total_passages = 0 |
| 1668 | + for i, item in enumerate(global_question_passage_vectors): |
| 1669 | + q_vector, p_vectors, positive_idx = item |
| 1670 | + |
| 1671 | + if i != rank: |
| 1672 | + global_query_vectors.append(q_vector.to(query_vectors.device)) |
| 1673 | + global_passage_vectors.append(p_vectors.to(passage_vectors.device)) |
| 1674 | + global_positive_idx_per_question.extend([v + total_passages for v in positive_idx]) |
| 1675 | + else: |
| 1676 | + global_query_vectors.append(query_vectors) |
| 1677 | + global_passage_vectors.append(passage_vectors) |
| 1678 | + global_positive_idx_per_question.extend([v + total_passages for v in positive_idx_per_question]) |
| 1679 | + total_passages += p_vectors.size(0) |
| 1680 | + |
| 1681 | + global_query_vectors = torch.cat(global_query_vectors, dim=0) |
| 1682 | + global_passage_vectors = torch.cat(global_passage_vectors, dim=0) |
| 1683 | + global_positive_idx_per_question = torch.LongTensor(global_positive_idx_per_question) |
| 1684 | + else: |
| 1685 | + global_query_vectors = query_vectors |
| 1686 | + global_passage_vectors = passage_vectors |
| 1687 | + global_positive_idx_per_question = positive_idx_per_question |
| 1688 | + |
| 1689 | + # Get similarity scores |
| 1690 | + softmax_scores = self._embeddings_to_scores(global_query_vectors, global_passage_vectors) |
1639 | 1691 | targets = global_positive_idx_per_question.squeeze(-1).to(softmax_scores.device) |
1640 | 1692 |
|
1641 | 1693 | # Calculate loss |
@@ -1664,7 +1716,9 @@ def prepare_labels(self, **kwargs): |
1664 | 1716 | """ |
1665 | 1717 | label_ids = kwargs.get(self.label_tensor_name) |
1666 | 1718 | labels = torch.zeros(label_ids.size(0), label_ids.numel()) |
1667 | | - positive_indices = (label_ids.view(-1) == 1).nonzero() |
| 1719 | + |
| 1720 | + positive_indices = torch.nonzero(label_ids.view(-1) == 1, as_tuple=False) |
| 1721 | + |
1668 | 1722 | for i, indx in enumerate(positive_indices): |
1669 | 1723 | labels[i, indx.item()] = 1 |
1670 | 1724 | return labels |
|
0 commit comments