Skip to content

Commit 9ae96f4

Browse files
feat: Implement LLM call caching, fast-path classification, and concurrency control for message classification.
1 parent af351a4 commit 9ae96f4

2 files changed

Lines changed: 272 additions & 5 deletions

File tree

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import re
2+
import time
3+
import asyncio
4+
import logging
5+
from typing import Optional, Dict, Any
6+
from cachetools import TTLCache
7+
from collections import Counter
8+
from langchain_core.messages import HumanMessage
9+
import xxhash
10+
import hashlib
11+
import json
12+
13+
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
14+
15+
try:
16+
_HAS_XXHASH = True
17+
except Exception:
18+
xxhash = None
19+
_HAS_XXHASH = False
20+
21+
# Config
CACHE_MAXSIZE = 4096  # max distinct cached classification results (LRU-evicted beyond this)
CACHE_TTL_SECONDS = 60 * 60  # cached entries expire after one hour
MAX_MESSAGE_LENGTH = 10000  # Max message length to process (prevents DoS via large payloads)
25+
26+
# Patterns for fast-path classification (concise to reduce memory)
# Merge related intents into fewer regexes and add common Discord patterns
# NOTE(review): dict insertion order is significant — the first pattern that
# fires wins in is_simple_message, so broader buckets placed earlier shadow
# the more specific ones below.
# NOTE(review): some patterns carry a leading ".*" to match anywhere under
# anchored testing, while others (url, code_block, pr_issue_ref, mentions)
# have no anchor or prefix — confirm the intended anchoring is consistent.
_PATTERNS = {
    # common salutations
    "greeting": re.compile(r"^\s*(?:hi|hello|hey|good\s+morning|good\s+afternoon|good\s+evening)\b", re.I),
    # explicit help / action requests
    "action_request": re.compile(r".*\b(?:help|please\s+help|plz\s+help|need\s+help|support|assist|request)\b", re.I),
    # bug / error reports
    "bug_report": re.compile(r".*\b(?:bug|error|exception|stack\s*trace|crash|failed|traceback|segfault)\b", re.I),
    # thanks and short acknowledgements (shared fast-path)
    "thanks_ack": re.compile(r"^\s*(?:thanks|thank\s+you|thx|ty|ok|okay|got\s+it|roger|ack)\b", re.I),
    # modern short responses / slang that are non-actionable
    "slang": re.compile(r"^\s*(?:brb|lol|lmao|rofl|omg|wtf|smh|idk|np|yw|pls|plz|bump|ping|fyi|imo|idc)\b", re.I),
    # general intent bucket for optimization/performance/docs/feature keywords
    "intent_general": re.compile(
        r".*\b(?:optimi[sz]e|improve|speed\s*up|performance|memory|resource|efficient|documentation|docs|guide|tutorial|example|feature|suggest|idea)\b",
        re.I,
    ),
    # Discord-specific: user mentions (@user)
    "discord_mention": re.compile(r"(?:<@!?\d+>|@\w+)\b"),
    # Channel mentions (#channel or <#123456>)
    "channel_mention": re.compile(r"(?:<#\d+>|#\w+)\b"),
    # Bot/CLI-like commands commonly used on Discord (prefix-based)
    "command": re.compile(r"^\s*(?:/|!|\?|\.|\$)[A-Za-z0-9_\-]+"),
    # Code snippets or blocks (inline or triple backticks)
    "code_block": re.compile(r"```[\s\S]*?```|`[^`]+`", re.S),
    # URLs (simple detection)
    "url": re.compile(r"https?://\S+|www\.\S+"),
    # GitHub/issue/PR references (#123, owner/repo#123, PR #123)
    "pr_issue_ref": re.compile(r"(?:\b#\d+\b|\b[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+#\d+\b|\bPR\s*#\d+\b)", re.I),
    # Emoji shortname like :emoji:
    "emoji_short": re.compile(r":[a-zA-Z0-9_+\-]+:"),
}
59+
60+
# Simple deterministic classifications for the patterns
# Keep mapping concise and reflect combined pattern keys
# NOTE(review): "integration" and "architecture" have no corresponding entry
# in _PATTERNS, so those two classifications are currently unreachable —
# presumably reserved for future patterns; confirm or drop.
_PATTERN_CLASSIFICATION = {
    "greeting": {"needs_devrel": False, "priority": "low", "reasoning": "greeting"},
    "thanks_ack": {"needs_devrel": False, "priority": "low", "reasoning": "thanks/acknowledgement"},
    "slang": {"needs_devrel": False, "priority": "low", "reasoning": "short/slang response"},
    "action_request": {"needs_devrel": True, "priority": "high", "reasoning": "explicit help/request keywords"},
    "bug_report": {"needs_devrel": True, "priority": "high", "reasoning": "error or bug report"},
    "integration": {"needs_devrel": True, "priority": "high", "reasoning": "Discord/GitHub/integration requests (OAuth, commands, threads, repo ops)"},
    "architecture": {"needs_devrel": True, "priority": "medium", "reasoning": "architecture/infra mentions (queues, DBs, LLMs)"},
    "intent_general": {"needs_devrel": True, "priority": "medium", "reasoning": "optimization/docs/feature requests"},

    # Discord/GitHub specific quick classifications
    "discord_mention": {"needs_devrel": False, "priority": "low", "reasoning": "user mention"},
    "channel_mention": {"needs_devrel": False, "priority": "low", "reasoning": "channel mention"},
    "command": {"needs_devrel": False, "priority": "medium", "reasoning": "bot/CLI command invocation"},
    "code_block": {"needs_devrel": False, "priority": "low", "reasoning": "code snippet or block"},
    "url": {"needs_devrel": False, "priority": "low", "reasoning": "contains URL"},
    "pr_issue_ref": {"needs_devrel": True, "priority": "medium", "reasoning": "reference to issue or PR"},
    "emoji_short": {"needs_devrel": False, "priority": "low", "reasoning": "emoji shortname"},
}
81+
82+
# Result cache: bounded (LRU) and time-bounded so stale classifications age out.
_cache = TTLCache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL_SECONDS)
# In-flight calls to dedupe concurrent identical requests (bounded with TTL to prevent leaks)
_inflight: TTLCache = TTLCache(maxsize=1000, ttl=120)  # Max 1000 concurrent, 2min timeout

# Simple metrics
# "skipped_llm" counts fast-path pattern hits; hit/miss counters track the cache.
metrics = Counter({"total": 0, "cache_hits": 0, "cache_misses": 0, "skipped_llm": 0})
88+
89+
90+
# Simple cache key generation
def make_key(model: str, prompt: str, params: Dict[str, Any]) -> str:
    """Build a deterministic cache key for (model, prompt, params).

    The prompt is normalized first so trivially different texts share a key;
    params are serialized with sorted keys for stability. Hashing uses
    xxh3_128 when available and falls back to blake2b otherwise.
    """
    normalized = normalize_message(prompt)

    # Best-effort serialization: non-JSON-serializable values are stringified,
    # and a totally unserializable params object degrades to str(params).
    try:
        serialized = json.dumps(params or {}, sort_keys=True, separators=(",", ":"), default=str)
        params_blob = serialized.encode("utf-8")
    except Exception:
        params_blob = str(params).encode("utf-8")

    payload = b"|".join((model.encode("utf-8"), normalized.encode("utf-8"), params_blob))

    if not _HAS_XXHASH:
        return hashlib.blake2b(payload, digest_size=16).hexdigest()
    return xxhash.xxh3_128_hexdigest(payload)
113+
114+
115+
def _compose_prompt_with_context(normalized: str, context_id: Optional[str]) -> str:
116+
if context_id:
117+
return f"{normalized}|ctx:{context_id}"
118+
return normalized
119+
120+
121+
def key_for_normalized(normalized: str, context_id: Optional[str], model: str, params: Dict[str, Any]) -> str:
    """Cache key for an already-normalized message, scoped by an optional context id."""
    return make_key(model, _compose_prompt_with_context(normalized, context_id), params)
127+
128+
129+
def get_cached_by_normalized(normalized: str, context_id: Optional[str], model: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Look up the cached payload for a normalized message + context; None on miss."""
    return cache_get(key_for_normalized(normalized, context_id, model, params))
133+
134+
135+
def set_cached_by_normalized(normalized: str, context_id: Optional[str], model: str, params: Dict[str, Any], payload: Dict[str, Any]) -> None:
    """Cache *payload* under the key derived from normalized message + context."""
    cache_set(key_for_normalized(normalized, context_id, model, params), payload)
139+
140+
141+
# Cache wrapper for LLM calls (async - uses llm.ainvoke)
142+
async def cached_llm_call(prompt: str, model: str, params: Dict[str, Any], llm):
143+
"""
144+
Cached wrapper for async LLM calls with:
145+
- fast-path simple pattern classification to avoid LLM cost
146+
- cache hit/miss metrics
147+
- in-flight deduplication so concurrent identical requests share one LLM call
148+
"""
149+
# Fast-path: simple deterministic classification (avoid LLM)
150+
normalized = normalize_message(prompt)
151+
simple = is_simple_message(normalized)
152+
if simple is not None:
153+
metrics["skipped_llm"] += 1
154+
return simple
155+
156+
metrics["total"] += 1
157+
key = make_key(model, prompt, params)
158+
159+
# Quick cache check
160+
cached = cache_get(key)
161+
if cached is not None:
162+
metrics["cache_hits"] += 1
163+
return cached
164+
165+
metrics["cache_misses"] += 1
166+
167+
# Deduplicate in-flight identical calls so only one LLM request is made
168+
loop = asyncio.get_running_loop()
169+
# Attempt to install a future atomically to dedupe concurrent callers
170+
future = loop.create_future()
171+
prev = _inflight.setdefault(key, future)
172+
if prev is not future:
173+
# another caller is in-flight; await its result/failure
174+
return await prev
175+
176+
# we are the owner; perform the fetch and set the future result/exception
177+
async def _owner_fetch():
178+
try:
179+
start = time.time()
180+
response = await llm.ainvoke([HumanMessage(content=prompt)])
181+
elapsed = time.time() - start
182+
# store response content or small payload rather than full object
183+
result = response.content if hasattr(response, "content") else response
184+
_cache[key] = result
185+
future.set_result(result)
186+
return result
187+
except asyncio.CancelledError:
188+
future.cancel()
189+
raise
190+
except Exception as e:
191+
future.set_exception(e)
192+
raise
193+
finally:
194+
# ensure inflight entry removed
195+
_inflight.pop(key, None)
196+
197+
# schedule owner fetch and await its result
198+
loop.create_task(_owner_fetch())
199+
return await future
200+
201+
def normalize_message(msg: str) -> str:
    """Normalize a message for cache keying.

    Truncates to MAX_MESSAGE_LENGTH (DoS guard), lowercases, strips, and
    collapses internal whitespace runs to single spaces. None/empty input
    yields "".
    """
    text = (msg or "")[:MAX_MESSAGE_LENGTH]
    return re.sub(r"\s+", " ", text.strip().lower())
206+
207+
def is_simple_message(normalized: str) -> Optional[Dict[str, Any]]:
    """Fast-path classification: return a deterministic payload for trivially
    recognizable messages, or None when the LLM should decide.

    Fix: uses pattern.search() instead of pattern.match(). Several patterns
    (url, code_block, pr_issue_ref, mentions) have no leading ".*", so under
    match() they only fired at the very start of the message, unlike the
    ".*"-prefixed intent patterns. search() makes anchoring uniform; patterns
    with an explicit "^" anchor (greeting, thanks_ack, slang, command) are
    unaffected. First matching pattern wins, so _PATTERNS order is significant.
    """
    for name, pattern in _PATTERNS.items():
        if pattern.search(normalized):
            # Copy the classification and attach the (normalized) message.
            return dict(_PATTERN_CLASSIFICATION[name], original_message=normalized)
    return None
212+
213+
def cache_get(key: str) -> Optional[Dict[str, Any]]:
    """Return the cached value for *key*, or None on a miss (including TTL expiry)."""
    return _cache.get(key)
218+
219+
220+
def cache_set(key: str, value: Dict[str, Any]) -> None:
    """Store value in cache."""
    # TTLCache handles both LRU eviction (CACHE_MAXSIZE) and expiry (CACHE_TTL_SECONDS).
    _cache[key] = value

backend/app/classification/classification_router.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1+
import asyncio
12
import logging
3+
import json
24
from typing import Dict, Any
35
from langchain_google_genai import ChatGoogleGenerativeAI
46
from langchain_core.messages import HumanMessage
57
from app.core.config import settings
68
from .prompt import DEVREL_TRIAGE_PROMPT
9+
from app.classification.cache_helpers import (
10+
normalize_message,
11+
is_simple_message,
12+
get_cached_by_normalized,
13+
set_cached_by_normalized,
14+
metrics,
15+
MAX_MESSAGE_LENGTH,
16+
)
717

818
logger = logging.getLogger(__name__)
919

20+
# Limit concurrent LLM calls to prevent rate limiting and cost explosions
# NOTE(review): the Semaphore is created at import time; on Python < 3.10 an
# asyncio.Semaphore binds to the event loop current at creation — confirm the
# deployment runs 3.10+ or that creation happens under the serving loop.
_LLM_SEMAPHORE = asyncio.Semaphore(10)
22+
23+
1024
class ClassificationRouter:
1125
"""Simple DevRel triage - determines if message needs DevRel assistance"""
1226

@@ -20,28 +34,59 @@ def __init__(self, llm_client=None):
2034
async def should_process_message(self, message: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
2135
"""Simple triage: Does this message need DevRel assistance?"""
2236
try:
37+
# Early return for oversized messages to prevent DoS
38+
if len(message) > MAX_MESSAGE_LENGTH:
39+
logger.warning(f"Message exceeds max length ({len(message)} > {MAX_MESSAGE_LENGTH}), using fallback")
40+
return self._fallback_triage(message[:MAX_MESSAGE_LENGTH])
41+
42+
metrics["total"] += 1
43+
normalized = normalize_message(message)
44+
45+
# fast-path: simple pattern match (no LLM)
46+
simple = is_simple_message(normalized)
47+
48+
if simple is not None:
49+
metrics["skipped_llm"] += 1
50+
return simple
51+
52+
# cache lookup (include a light context fingerprint if present)
53+
ctx_id = None
54+
if context:
55+
ctx_id = context.get("channel_id") or context.get("thread_id") or ""
56+
if not ctx_id:
57+
ctx_id = None
58+
59+
cached = get_cached_by_normalized(normalized, ctx_id, settings.classification_agent_model, {"temperature": 0.1})
60+
if cached is not None:
61+
metrics["cache_hits"] += 1
62+
return cached
63+
64+
metrics["cache_misses"] += 1
65+
2366
triage_prompt = DEVREL_TRIAGE_PROMPT.format(
2467
message=message,
2568
context=context or 'No additional context'
2669
)
2770

28-
response = await self.llm.ainvoke([HumanMessage(content=triage_prompt)])
29-
71+
# Use semaphore to limit concurrent LLM calls
72+
async with _LLM_SEMAPHORE:
73+
response = await self.llm.ainvoke([HumanMessage(content=triage_prompt)])
3074
response_text = response.content.strip()
75+
3176
if '{' in response_text:
3277
json_start = response_text.find('{')
3378
json_end = response_text.rfind('}') + 1
3479
json_str = response_text[json_start:json_end]
35-
36-
import json
3780
result = json.loads(json_str)
3881

39-
return {
82+
payload = {
4083
"needs_devrel": result.get("needs_devrel", True),
4184
"priority": result.get("priority", "medium"),
4285
"reasoning": result.get("reasoning", "LLM classification"),
4386
"original_message": message
4487
}
88+
set_cached_by_normalized(normalized, ctx_id, settings.classification_agent_model, {"temperature": 0.1}, payload)
89+
return payload
4590

4691
return self._fallback_triage(message)
4792

0 commit comments

Comments
 (0)