Skip to content

Commit 241f1c2

Browse files
feat: context compress (#4322)
* feat: context compressor Co-authored-by: kawayiYokami <289104862@qq.com> * Add comprehensive tests for ContextManager and ContextTruncator - Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling. - Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving. - Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages. * feat: add MockProvider for LLM compression tests * chore: remove lock * ruff fix * fix * perf * feat: enhance context compression with token tracking and logging * feat: update logging for context compression trigger * feat: implement context compression logic with dynamic threshold and token tracking * fix: reorder import statements for consistency * feat: add token_usage tracking to conversations and update related processing logic --------- Co-authored-by: kawayiYokami <289104862@qq.com>
1 parent 3615b7d commit 241f1c2

21 files changed

Lines changed: 2184 additions & 100 deletions

File tree

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
from typing import TYPE_CHECKING, Protocol, runtime_checkable

from ..message import Message

if TYPE_CHECKING:
    from astrbot import logger
else:
    try:
        from astrbot import logger
    except ImportError:
        # Fallback: when running outside the astrbot runtime (e.g. isolated
        # tests), use a stdlib logger registered under the same name.
        import logging

        logger = logging.getLogger("astrbot")

if TYPE_CHECKING:
    # Import-cycle-only annotation dependency.
    from astrbot.core.provider.provider import Provider

from ..context.truncator import ContextTruncator
21+
@runtime_checkable
class ContextCompressor(Protocol):
    """Structural interface for context compressors.

    Any object that exposes ``should_compress`` and an async ``__call__``
    satisfies this protocol; conformance can be checked at runtime with
    ``isinstance`` thanks to ``@runtime_checkable``.
    """

    def should_compress(
        self, messages: list[Message], current_tokens: int, max_tokens: int
    ) -> bool:
        """Decide whether the message list should be compressed.

        Args:
            messages: Messages under evaluation.
            current_tokens: Token count of the current context.
            max_tokens: Token budget of the target model.

        Returns:
            True when compression should run, False otherwise.
        """
        ...

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Produce a compressed version of ``messages``.

        Args:
            messages: The original message list.

        Returns:
            A (possibly shorter) message list.
        """
        ...
53+
54+
55+
class TruncateByTurnsCompressor:
    """Compressor that shrinks context by dropping whole conversation turns.

    When the token usage ratio exceeds the configured threshold, the oldest
    turns are removed from the message list.
    """

    def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
        """Set up the compressor.

        Args:
            truncate_turns: How many turns to drop per truncation (default: 1).
            compression_threshold: Usage ratio above which compression
                triggers (default: 0.82).
        """
        self.truncate_turns = truncate_turns
        self.compression_threshold = compression_threshold

    def should_compress(
        self, messages: list[Message], current_tokens: int, max_tokens: int
    ) -> bool:
        """Return True when token usage exceeds the threshold.

        Args:
            messages: Messages under evaluation (unused by this strategy).
            current_tokens: Token count of the current context.
            max_tokens: Token budget of the target model.

        Returns:
            True if ``current_tokens / max_tokens`` exceeds the threshold;
            False when either count is non-positive.
        """
        if current_tokens <= 0 or max_tokens <= 0:
            return False
        return (current_tokens / max_tokens) > self.compression_threshold

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Drop the oldest ``truncate_turns`` turns from ``messages``."""
        return ContextTruncator().truncate_by_dropping_oldest_turns(
            messages,
            drop_turns=self.truncate_turns,
        )
95+
96+
97+
def split_history(
    messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
    """Split messages into (system, to-summarize, recent) partitions.

    The split point is moved backward, if necessary, so that the recent
    partition starts with a user message and conversation turns stay whole.

    Args:
        messages: The original message list; leading system messages are
            preserved as their own partition.
        keep_recent: Minimum number of trailing non-system messages to keep.

    Returns:
        tuple: (system_messages, messages_to_summarize, recent_messages).
        When no safe split exists, ``messages_to_summarize`` is empty and
        all non-system messages are returned as recent.
    """
    # Leading system messages are never summarized; find where they end.
    # Default to len(messages) so an all-system history is classified as
    # system messages rather than leaking into the non-system partition
    # (the previous default of 0 mis-split that edge case).
    first_non_system = len(messages)
    for i, msg in enumerate(messages):
        if msg.role != "system":
            first_non_system = i
            break

    system_messages = messages[:first_non_system]
    non_system_messages = messages[first_non_system:]

    if len(non_system_messages) <= keep_recent:
        return system_messages, [], non_system_messages

    # Tentative split keeps exactly keep_recent messages ...
    split_index = len(non_system_messages) - keep_recent

    # ... then walk backward until the recent partition starts with a user
    # message, so a user/assistant turn is never cut in half.
    # TODO: +=1 or -=1 ? calculate by tokens
    while split_index > 0 and non_system_messages[split_index].role != "user":
        split_index -= 1

    # No user message found before the split point: nothing safe to summarize.
    if split_index == 0:
        return system_messages, [], non_system_messages

    return (
        system_messages,
        non_system_messages[:split_index],
        non_system_messages[split_index:],
    )
142+
143+
144+
class LLMSummaryCompressor:
    """Compressor that condenses old history into an LLM-written summary.

    Older messages are replaced by a single summary exchange while the most
    recent messages are kept verbatim.
    """

    def __init__(
        self,
        provider: "Provider",
        keep_recent: int = 4,
        instruction_text: str | None = None,
        compression_threshold: float = 0.82,
    ):
        """Set up the summary compressor.

        Args:
            provider: LLM provider used to generate the summary.
            keep_recent: Number of trailing messages kept as-is (default: 4).
            instruction_text: Override for the summary prompt; a built-in
                prompt is used when None.
            compression_threshold: Usage ratio above which compression
                triggers (default: 0.82).
        """
        self.provider = provider
        self.keep_recent = keep_recent
        self.compression_threshold = compression_threshold

        self.instruction_text = instruction_text or (
            "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
            "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
            "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
            "3. If there was an initial user goal, state it first and describe the current progress/status.\n"
            "4. Write the summary in the user's language.\n"
        )

    def should_compress(
        self, messages: list[Message], current_tokens: int, max_tokens: int
    ) -> bool:
        """Return True when token usage exceeds the threshold.

        Args:
            messages: Messages under evaluation (unused by this strategy).
            current_tokens: Token count of the current context.
            max_tokens: Token budget of the target model.

        Returns:
            True if ``current_tokens / max_tokens`` exceeds the threshold;
            False when either count is non-positive.
        """
        if current_tokens <= 0 or max_tokens <= 0:
            return False
        return (current_tokens / max_tokens) > self.compression_threshold

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Summarize older history and splice the summary into the context.

        Steps:
        1. Partition into system / old / recent messages.
        2. Ask the provider to summarize the old messages.
        3. Rebuild as [system..., summary exchange, recent...].

        On any provider failure the original list is returned unchanged.
        """
        if len(messages) <= self.keep_recent + 1:
            return messages

        system_messages, to_summarize, recent_messages = split_history(
            messages, self.keep_recent
        )
        if not to_summarize:
            return messages

        # Send the old slice plus the summarization instruction to the LLM.
        payload = [*to_summarize, Message(role="user", content=self.instruction_text)]
        try:
            llm_response = await self.provider.text_chat(contexts=payload)
            summary_content = llm_response.completion_text
        except Exception as e:
            logger.error(f"Failed to generate summary: {e}")
            return messages

        # Replace the summarized slice with a synthetic user/assistant
        # exchange carrying the summary.
        summary_exchange = [
            Message(
                role="user",
                content=f"Our previous history conversation summary: {summary_content}",
            ),
            Message(
                role="assistant",
                content="Acknowledged the summary of our previous conversation history.",
            ),
        ]
        return [*system_messages, *summary_exchange, *recent_messages]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from dataclasses import dataclass
2+
from typing import TYPE_CHECKING
3+
4+
from .compressor import ContextCompressor
5+
from .token_counter import TokenCounter
6+
7+
if TYPE_CHECKING:
8+
from astrbot.core.provider.provider import Provider
9+
10+
11+
@dataclass
class ContextConfig:
    """Configuration for context management (truncation and compression).

    A default-constructed instance applies no token limit, no turn limit,
    and uses the built-in token counter and compression strategies.
    """

    max_context_tokens: int = 0
    """Maximum number of context tokens. <= 0 means no limit."""
    enforce_max_turns: int = -1  # -1 means no limit
    """Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
    truncate_turns: int = 1
    """Number of conversation turns to discard at once when truncation is triggered.
    Two processes will use this value:

    1. Enforce max turns truncation.
    2. Truncation by turns compression strategy.
    """
    llm_compress_instruction: str | None = None
    """Instruction prompt for LLM-based compression."""
    llm_compress_keep_recent: int = 0
    """Number of recent messages to keep during LLM-based compression."""
    llm_compress_provider: "Provider | None" = None
    """LLM provider used for compression tasks. If None, truncation strategy is used."""
    custom_token_counter: TokenCounter | None = None
    """Custom token counting method. If None, the default method is used."""
    custom_compressor: ContextCompressor | None = None
    """Custom context compression method. If None, the default method is used."""

0 commit comments

Comments
 (0)