78 changes: 65 additions & 13 deletions backend/Generator/llm_generator.py

@@ -1,6 +1,7 @@
 import json
 import re
 import threading
+import hashlib
 from llama_cpp import Llama


@@ -39,9 +40,17 @@ def _prepare_text(self, input_text, max_words=3000):
         input_text = " ".join(words[:max_words])
         return input_text

-    def generate_short_questions(self, input_text, max_questions=4):
+    def generate_short_questions(self, input_text, max_questions=4, deterministic=False):
         """Generate short-answer questions from the given text."""
+        # Input validation
+        if not input_text or not isinstance(input_text, str):
+            return []
+
         self._load_model()
+
+        # Compute seed BEFORE text truncation to ensure different inputs produce different seeds
+        seed_value = int(hashlib.sha256(input_text.encode()).hexdigest()[:8], 16) if deterministic else None

         input_text = self._prepare_text(input_text)

         prompt = (
@@ -52,6 +61,16 @@ def generate_short_questions(self, input_text, max_questions=4):
             f"/no_think"
         )

+        params = {
+            "max_tokens": 512,
+            "temperature": 0.7,
+        }
+
+        if deterministic:
+            params["temperature"] = 0.0
+            params["top_p"] = 1.0
+            params["seed"] = seed_value
+
         response = self.llm.create_chat_completion(
             messages=[
                 {
@@ -63,8 +82,7 @@ def generate_short_questions(self, input_text, max_questions=4):
                     "content": prompt,
                 },
             ],
-            max_tokens=512,
-            temperature=0.7,
+            **params
         )

         try:
@@ -78,9 +96,17 @@ def generate_short_questions(self, input_text, max_questions=4):
         except (AttributeError, TypeError, ValueError):
             return []

-    def generate_mcq_questions(self, input_text, max_questions=4):
+    def generate_mcq_questions(self, input_text, max_questions=4, deterministic=False):
         """Generate multiple-choice questions from the given text."""
+        # Input validation
+        if not input_text or not isinstance(input_text, str):
+            return []
+
         self._load_model()
+
+        # Compute seed BEFORE text truncation to ensure different inputs produce different seeds
+        seed_value = int(hashlib.sha256(input_text.encode()).hexdigest()[:8], 16) if deterministic else None

         input_text = self._prepare_text(input_text)

         prompt = (
@@ -92,6 +118,16 @@ def generate_mcq_questions(self, input_text, max_questions=4):
             f"/no_think"
         )

+        params = {
+            "max_tokens": 1024,
+            "temperature": 0.7,
+        }
+
+        if deterministic:
+            params["temperature"] = 0.0
+            params["top_p"] = 1.0
+            params["seed"] = seed_value
+
         response = self.llm.create_chat_completion(
             messages=[
                 {
@@ -103,8 +139,7 @@ def generate_mcq_questions(self, input_text, max_questions=4):
                     "content": prompt,
                 },
             ],
-            max_tokens=1024,
-            temperature=0.7,
+            **params
         )

         try:
@@ -118,9 +153,17 @@ def generate_mcq_questions(self, input_text, max_questions=4):
         except (AttributeError, TypeError, ValueError):
             return []

-    def generate_boolean_questions(self, input_text, max_questions=4):
+    def generate_boolean_questions(self, input_text, max_questions=4, deterministic=False):
         """Generate true/false questions from the given text."""
+        # Input validation
+        if not input_text or not isinstance(input_text, str):
+            return []
+
         self._load_model()
+
+        # Compute seed BEFORE text truncation to ensure different inputs produce different seeds
+        seed_value = int(hashlib.sha256(input_text.encode()).hexdigest()[:8], 16) if deterministic else None

         input_text = self._prepare_text(input_text)

         prompt = (
@@ -131,6 +174,16 @@ def generate_boolean_questions(self, input_text, max_questions=4):
             f"/no_think"
         )

+        params = {
+            "max_tokens": 512,
+            "temperature": 0.7,
+        }
+
+        if deterministic:
+            params["temperature"] = 0.0
+            params["top_p"] = 1.0
+            params["seed"] = seed_value
+
         response = self.llm.create_chat_completion(
             messages=[
                 {
@@ -142,8 +195,7 @@ def generate_boolean_questions(self, input_text, max_questions=4):
                     "content": prompt,
                 },
             ],
-            max_tokens=512,
-            temperature=0.7,
+            **params
         )

         try:
@@ -157,12 +209,12 @@ def generate_boolean_questions(self, input_text, max_questions=4):
         except (AttributeError, TypeError, ValueError):
             return []

-    def generate_all_questions(self, input_text, mcq_count=2, bool_count=2, short_count=2):
+    def generate_all_questions(self, input_text, mcq_count=2, bool_count=2, short_count=2, deterministic=False):
         """Generate a mix of all question types."""
         questions = []

         # Generate MCQs
-        mcqs = self.generate_mcq_questions(input_text, mcq_count)
+        mcqs = self.generate_mcq_questions(input_text, mcq_count, deterministic)
         for mcq in mcqs:
             questions.append({
                 "type": "mcq",
@@ -172,7 +224,7 @@ def generate_all_questions(self, input_text, mcq_count=2, bool_count=2, short_count=2):
             })

         # Generate Boolean questions
-        bool_qs = self.generate_boolean_questions(input_text, bool_count)
+        bool_qs = self.generate_boolean_questions(input_text, bool_count, deterministic)
         for bool_q in bool_qs:
             questions.append({
                 "type": "boolean",
@@ -181,7 +233,7 @@ def generate_all_questions(self, input_text, mcq_count=2, bool_count=2, short_count=2):
             })

         # Generate Short questions
-        short_qs = self.generate_short_questions(input_text, short_count)
+        short_qs = self.generate_short_questions(input_text, short_count, deterministic)
         for short_q in short_qs:
             questions.append({
                 "type": "short_answer",
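
All three `generate_*` methods now share the same pattern: hash the full input text to a seed, then override the sampling parameters when `deterministic=True`. Below is a minimal standalone sketch of that pattern for reference; `derive_seed` and `build_sampling_params` are illustrative names, not functions in this diff.

```python
import hashlib

def derive_seed(input_text: str) -> int:
    # First 8 hex chars of the SHA-256 digest -> a stable 32-bit integer seed.
    return int(hashlib.sha256(input_text.encode()).hexdigest()[:8], 16)

def build_sampling_params(max_tokens: int, deterministic: bool, input_text: str) -> dict:
    params = {"max_tokens": max_tokens, "temperature": 0.7}
    if deterministic:
        # Greedy-style decoding plus a text-derived seed for reproducible output.
        params["temperature"] = 0.0
        params["top_p"] = 1.0
        params["seed"] = derive_seed(input_text)
    return params

# Identical inputs map to identical seeds, so repeated deterministic requests
# should reproduce the same questions.
assert derive_seed("cell biology notes") == derive_seed("cell biology notes")
print(build_sampling_params(512, True, "cell biology notes"))
```

Note the ordering choice the inline comments call out: the seed is derived from the full text before `_prepare_text` truncates it to 3,000 words, so two inputs that differ only beyond the cutoff still get distinct seeds. With `temperature=0.0` decoding is effectively greedy, so the seed is largely a safeguard for backends that still sample at zero temperature.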
28 changes: 24 additions & 4 deletions backend/server.py

@@ -101,8 +101,13 @@ def get_shortq_llm():
         input_text = data.get("input_text", "")
         use_mediawiki = data.get("use_mediawiki", 0)
         max_questions = data.get("max_questions", 4)
+        deterministic = data.get("deterministic", False)
+
+        if not isinstance(deterministic, bool):
+            return jsonify({"error": "deterministic must be a boolean"}), 400
+
         input_text = process_input_text(input_text, use_mediawiki)
-        questions = llm_generator.generate_short_questions(input_text, max_questions)
+        questions = llm_generator.generate_short_questions(input_text, max_questions, deterministic)
         return jsonify({"output": questions})
     except Exception as e:
         app.logger.exception("Error in /get_shortq_llm: %s", e)
@@ -116,8 +121,13 @@ def get_mcq_llm():
         input_text = data.get("input_text", "")
         use_mediawiki = data.get("use_mediawiki", 0)
         max_questions = data.get("max_questions", 4)
+        deterministic = data.get("deterministic", False)
+
+        if not isinstance(deterministic, bool):
+            return jsonify({"error": "deterministic must be a boolean"}), 400
+
         input_text = process_input_text(input_text, use_mediawiki)
-        questions = llm_generator.generate_mcq_questions(input_text, max_questions)
+        questions = llm_generator.generate_mcq_questions(input_text, max_questions, deterministic)
         return jsonify({"output": questions})
     except Exception as e:
         app.logger.exception("Error in /get_mcq_llm: %s", e)
@@ -131,8 +141,13 @@ def get_boolq_llm():
         input_text = data.get("input_text", "")
         use_mediawiki = data.get("use_mediawiki", 0)
         max_questions = data.get("max_questions", 4)
+        deterministic = data.get("deterministic", False)
+
+        if not isinstance(deterministic, bool):
+            return jsonify({"error": "deterministic must be a boolean"}), 400
+
         input_text = process_input_text(input_text, use_mediawiki)
-        questions = llm_generator.generate_boolean_questions(input_text, max_questions)
+        questions = llm_generator.generate_boolean_questions(input_text, max_questions, deterministic)
         return jsonify({"output": questions})
     except Exception as e:
         app.logger.exception("Error in /get_boolq_llm: %s", e)
@@ -148,8 +163,13 @@ def get_problems_llm():
         mcq_count = data.get("max_questions_mcq", 2)
         bool_count = data.get("max_questions_boolq", 2)
         short_count = data.get("max_questions_shortq", 2)
+        deterministic = data.get("deterministic", False)
+
+        if not isinstance(deterministic, bool):
+            return jsonify({"error": "deterministic must be a boolean"}), 400
+
         input_text = process_input_text(input_text, use_mediawiki)
-        questions = llm_generator.generate_all_questions(input_text, mcq_count, bool_count, short_count)
+        questions = llm_generator.generate_all_questions(input_text, mcq_count, bool_count, short_count, deterministic)
         return jsonify({"output": questions})
     except Exception as e:
         app.logger.exception("Error in /get_problems_llm: %s", e)
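
Each endpoint reads an optional `deterministic` flag from the JSON body, rejecting non-boolean values with a 400 before any generation work happens. A sketch of a client call, assuming the Flask app listens on localhost:5000 (host and port are assumptions, not part of this diff; the endpoint name and payload keys come from the code above):

```python
import requests

resp = requests.post(
    "http://localhost:5000/get_shortq_llm",
    json={
        "input_text": "Photosynthesis converts light energy into chemical energy...",
        "max_questions": 4,
        # Must be a JSON boolean: the string "true" would be rejected with a 400.
        "deterministic": True,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["output"])
```

Sending the same `input_text` twice with `deterministic: true` should return the same questions, since the generator derives its seed from the text itself.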