Skip to content

Commit 4f88d80

Browse files
authored
Merge pull request #1990 from Giskard-AI/GSK-3609-Avoid-redundant-questions-in-data-generation
GSK-3609 Avoid redundant questions in data generation
2 parents 9fe01f7 + 6498980 commit 4f88d80

File tree

8 files changed

+59
-8
lines changed

8 files changed

+59
-8
lines changed

giskard/rag/knowledge_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,17 @@ def get_failure_plot(self, question_evaluation: Sequence[dict] = None):
300300
def get_random_document(self):
301301
return self._rng.choice(self._documents)
302302

303+
def get_random_documents(self, n: int, with_replacement=False):
304+
if with_replacement:
305+
return list(self._rng.choice(self._documents, n, replace=True))
306+
307+
docs = list(self._rng.choice(self._documents, min(n, len(self._documents)), replace=False))
308+
309+
if len(docs) <= n:
310+
docs.extend(self._rng.choice(self._documents, n - len(docs), replace=True))
311+
312+
return docs
313+
303314
def get_neighbors(self, seed_document: Document, n_neighbors: int = 4, similarity_threshold: float = 0.2):
304315
seed_embedding = seed_document.embeddings
305316

giskard/rag/question_generators/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ class GenerateFromSingleQuestionMixin:
5252
_question_type: str
5353

5454
def generate_questions(self, knowledge_base: KnowledgeBase, num_questions: int, *args, **kwargs) -> Iterator[Dict]:
55-
for _ in range(num_questions):
55+
docs = knowledge_base.get_random_documents(num_questions)
56+
57+
for doc in docs:
5658
try:
57-
yield self.generate_single_question(knowledge_base, *args, **kwargs)
59+
yield self.generate_single_question(knowledge_base, *args, **kwargs, seed_document=doc)
5860
except Exception as e: # @TODO: specify exceptions
5961
logger.error(f"Encountered error in question generation: {e}. Skipping.")
6062
logger.exception(e)

giskard/rag/question_generators/double_questions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ class DoubleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
9292
_question_type = "double"
9393

9494
def generate_single_question(
95-
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
95+
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
9696
) -> QuestionSample:
97-
seed_document = knowledge_base.get_random_document()
97+
seed_document = seed_document or knowledge_base.get_random_document()
9898
context_documents = knowledge_base.get_neighbors(
9999
seed_document, self._context_neighbors, self._context_similarity_threshold
100100
)

giskard/rag/question_generators/oos_questions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class OutOfScopeGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestionGene
6868
_question_type = "out of scope"
6969

7070
def generate_single_question(
71-
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
71+
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
7272
) -> QuestionSample:
7373
"""
7474
Generate a question from a list of context documents.
@@ -87,7 +87,7 @@ def generate_single_question(
8787
Tuple[dict, dict]
8888
The generated question and the metadata of the question.
8989
"""
90-
seed_document = knowledge_base.get_random_document()
90+
seed_document = seed_document or knowledge_base.get_random_document()
9191

9292
context_documents = knowledge_base.get_neighbors(
9393
seed_document, self._context_neighbors, self._context_similarity_threshold

giskard/rag/question_generators/simple_questions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,13 @@ class SimpleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
6666

6767
_question_type = "simple"
6868

69-
def generate_single_question(self, knowledge_base: KnowledgeBase, agent_description: str, language: str) -> dict:
69+
def generate_single_question(
70+
self,
71+
knowledge_base: KnowledgeBase,
72+
agent_description: str,
73+
language: str,
74+
seed_document=None,
75+
) -> dict:
7076
"""
7177
Generate a question from a list of context documents.
7278
@@ -80,7 +86,8 @@ def generate_single_question(self, knowledge_base: KnowledgeBase, agent_descript
8086
QuestionSample
8187
The generated question and the metadata of the question.
8288
"""
83-
seed_document = knowledge_base.get_random_document()
89+
seed_document = seed_document or knowledge_base.get_random_document()
90+
8491
context_documents = knowledge_base.get_neighbors(
8592
seed_document, self._context_neighbors, self._context_similarity_threshold
8693
)

tests/rag/test_knowledge_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,32 @@
88
from giskard.rag.knowledge_base import KnowledgeBase
99

1010

11+
def test_knowledge_base_get_random_documents():
12+
llm_client = Mock()
13+
embeddings = Mock()
14+
embeddings.embed.side_effect = [np.random.rand(5, 10), np.random.rand(3, 10)]
15+
16+
kb = KnowledgeBase.from_pandas(
17+
df=pd.DataFrame({"text": ["This is a test string"] * 5}), llm_client=llm_client, embedding_model=embeddings
18+
)
19+
20+
# Test when k is smaller than the number of documents
21+
docs = kb.get_random_documents(3)
22+
assert len(docs) == 3
23+
# Check that all document IDs are unique
24+
assert len(set(doc.id for doc in docs)) == len(docs)
25+
26+
# Test when k is equal to the number of documents
27+
docs = kb.get_random_documents(5)
28+
assert len(docs) == 5
29+
assert all([doc == kb[doc.id] for doc in docs])
30+
31+
# Test when k is larger than the number of documents
32+
docs = kb.get_random_documents(10)
33+
assert len(docs) == 10
34+
assert all([doc == kb[doc.id] for doc in docs])
35+
36+
1137
def test_knowledge_base_creation_from_df():
1238
dimension = 8
1339
df = pd.DataFrame(["This is a test string"] * 5)

tests/rag/test_question_generators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_simple_question_generation():
3333
Document(dict(content="Milk is produced by cows, goats or sheep.")),
3434
]
3535
knowledge_base.get_random_document = Mock(return_value=documents[0])
36+
knowledge_base.get_random_documents = Mock(return_value=documents)
3637
knowledge_base.get_neighbors = Mock(return_value=documents)
3738

3839
question_generator = SimpleQuestionsGenerator(llm_client=llm_client)
@@ -212,6 +213,7 @@ def test_double_question_generation():
212213
Document(dict(content="Milk is produced by cows, goats or sheep.")),
213214
]
214215
knowledge_base.get_random_document = Mock(return_value=documents[0])
216+
knowledge_base.get_random_documents = Mock(return_value=documents)
215217
knowledge_base.get_neighbors = Mock(return_value=documents)
216218

217219
question_generator = DoubleQuestionsGenerator(llm_client=llm_client)
@@ -304,6 +306,7 @@ def test_oos_question_generation():
304306
dict(content="Paul Graham liked to buy a baguette every day at the local market."), doc_id="1"
305307
)
306308
)
309+
knowledge_base.get_random_documents = Mock(return_value=documents)
307310
knowledge_base.get_neighbors = Mock(return_value=documents)
308311

309312
question_generator = OutOfScopeGenerator(llm_client=llm_client)

tests/rag/test_testset_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ def test_question_generation_fail(caplog):
177177
knowledge_base.__getitem__ = lambda obj, idx: documents[0]
178178
knowledge_base.topics = ["Cheese", "Ski"]
179179

180+
knowledge_base.get_random_documents = Mock(return_value=documents)
181+
180182
simple_gen = Mock()
181183
simple_gen.generate_questions.return_value = [q1, q2]
182184
failing_gen = SimpleQuestionsGenerator(llm_client=Mock())

0 commit comments

Comments
 (0)