
Commit 5aba3b2

Support Long context for DocSum (#981)
* docsum four
* support 4 modes for docsum
* fix
* fix bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refine for docsum tgi
* add docsum for ut and vllm
* fix bug
* fix bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix ut bug
* fix ut bug
* set default value

Signed-off-by: Xinyao Wang <xinyao.wang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f3aaaeb commit 5aba3b2

10 files changed: +439 additions, -76 deletions

comps/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -38,6 +38,7 @@
     PIIResponseDoc,
     Audio2text,
     DocSumDoc,
+    DocSumLLMParams,
 )

 # Constants
```

comps/cores/proto/docarray.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -212,6 +212,12 @@ def chat_template_must_contain_variables(cls, v):
         return v


+class DocSumLLMParams(LLMParamsDoc):
+    summary_type: str = "stuff"  # can be "truncate", "map_reduce", "refine"
+    chunk_size: int = -1
+    chunk_overlap: int = -1
+
+
 class LLMParams(BaseDoc):
     model: Optional[str] = None
     max_tokens: int = 1024
```
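The semantics of the new fields can be sketched standalone (a minimal dataclass mirroring the diff; `DocSumLLMParamsSketch` is a hypothetical stand-in — the real `DocSumLLMParams` extends `LLMParamsDoc`, a docarray model carrying the usual LLM parameters as well):

```python
from dataclasses import dataclass


@dataclass
class DocSumLLMParamsSketch:
    # Mirrors the fields added in DocSumLLMParams above.
    summary_type: str = "stuff"  # or "truncate", "map_reduce", "refine"
    chunk_size: int = -1         # -1 -> service computes a default from the token limits
    chunk_overlap: int = -1      # -1 -> service uses int(0.1 * chunk_size)


# Callers only override what they need; -1 sentinels keep the service in charge of defaults.
params = DocSumLLMParamsSketch(summary_type="map_reduce", chunk_size=2000)
print(params)
```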

comps/llms/summarization/tgi/langchain/README.md

Lines changed: 61 additions & 2 deletions

````diff
@@ -48,8 +48,12 @@ In order to start TGI and LLM services, you need to setup the following environm
 export HF_TOKEN=${your_hf_api_token}
 export TGI_LLM_ENDPOINT="http://${your_ip}:8008"
 export LLM_MODEL_ID=${your_hf_llm_model}
+export MAX_INPUT_TOKENS=2048
+export MAX_TOTAL_TOKENS=4096
 ```

+Please make sure MAX_TOTAL_TOKENS is larger than (MAX_INPUT_TOKENS + max_new_tokens + 50); 50 tokens are reserved for the prompt.
+
 ### 2.2 Build Docker Image

 ```bash
@@ -67,7 +71,7 @@ You can choose one as needed.
 ### 2.3 Run Docker with CLI (Option A)

 ```bash
-docker run -d --name="llm-docsum-tgi-server" -p 9000:9000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT -e HF_TOKEN=$HF_TOKEN opea/llm-docsum-tgi:latest
+docker run -d --name="llm-docsum-tgi-server" -p 9000:9000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT -e HF_TOKEN=$HF_TOKEN -e MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS} -e MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS} opea/llm-docsum-tgi:latest
 ```

 ### 2.4 Run Docker with Docker Compose (Option B)
@@ -88,6 +92,18 @@ curl http://${your_ip}:9000/v1/health_check\

 ### 3.2 Consume LLM Service

+In the DocSum microservice, in addition to the basic LLM parameters, several optimization parameters are supported:
+
+- "language": specify the language; can be "auto", "en", "zh"; default is "auto"
+
+If you need to handle long context, select a suitable summary type; see section 3.2.2 for details.
+
+- "summary_type": can be "stuff", "truncate", "map_reduce", "refine"; default is "stuff"
+- "chunk_size": max token length for each chunk; the default value depends on "summary_type"
+- "chunk_overlap": overlap token length between chunks; default is 0.1\*chunk_size
+
+#### 3.2.1 Basic usage
+
 ```bash
 # Enable streaming to receive a streaming response. By default, this is set to True.
 curl http://${your_ip}:9000/v1/chat/docsum \
@@ -101,9 +117,52 @@ curl http://${your_ip}:9000/v1/chat/docsum \
   -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "streaming":false}' \
   -H 'Content-Type: application/json'

-# Use Chinese mode. By default, language is set to "en"
+# Use Chinese mode
 curl http://${your_ip}:9000/v1/chat/docsum \
   -X POST \
   -d '{"query":"2024年9月26日,北京——今日,英特尔正式发布英特尔® 至强® 6性能核处理器(代号Granite Rapids),为AI、数据分析、科学计算等计算密集型业务提供卓越性能。", "max_tokens":32, "language":"zh", "streaming":false}' \
   -H 'Content-Type: application/json'
 ```
+
+#### 3.2.2 Long context summarization with "summary_type"
+
+"summary_type" is "stuff" by default, which lets the LLM generate the summary from the complete input text. In this case, please set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` carefully according to your model and device memory; otherwise long inputs may exceed the LLM context limit and raise an error.
+
+When dealing with long context, you can set "summary_type" to one of "truncate", "map_reduce", or "refine" for better performance.
+
+**summary_type=truncate**
+
+Truncate mode truncates the input text and keeps only the first chunk, whose length is `min(MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS)`.
+
+```bash
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "summary_type": "truncate", "chunk_size": 2000}' \
+  -H 'Content-Type: application/json'
+```
+
+**summary_type=map_reduce**
+
+Map_reduce mode splits the input into multiple chunks, maps each chunk to an individual summary, then consolidates those summaries into a single global summary. `streaming=True` is not allowed here.
+
+In this mode, the default `chunk_size` is `min(MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS)`.
+
+```bash
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "summary_type": "map_reduce", "chunk_size": 2000, "streaming":false}' \
+  -H 'Content-Type: application/json'
+```
+
+**summary_type=refine**
+
+Refine mode splits the input into multiple chunks, generates a summary for the first chunk, combines it with the second chunk, and loops over every remaining chunk to produce the final summary.
+
+In this mode, the default `chunk_size` is `min(MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS)`.
+
+```bash
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "summary_type": "refine", "chunk_size": 2000}' \
+  -H 'Content-Type: application/json'
+```
````
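The curl examples above translate directly to a small Python client. A minimal sketch using only the standard library; the request-body fields come from the README, while the helper name and the `localhost:9000` host are assumptions for illustration:

```python
import json


def docsum_request_body(query, summary_type="stuff", chunk_size=-1, chunk_overlap=-1):
    """Build the JSON body for POST /v1/chat/docsum, mirroring the curl examples."""
    return json.dumps({
        "query": query,
        "max_tokens": 32,
        "language": "en",
        "summary_type": summary_type,
        "chunk_size": chunk_size,       # -1 lets the service pick its default
        "chunk_overlap": chunk_overlap,  # -1 -> 0.1 * chunk_size on the service side
        "streaming": False,
    })


body = docsum_request_body("TEI is a toolkit ...", summary_type="map_reduce", chunk_size=2000)
print(body)
# To send it (assumes the service from section 2.3 is up on localhost:9000):
#   urllib.request.Request("http://localhost:9000/v1/chat/docsum", data=body.encode(),
#                          headers={"Content-Type": "application/json"})
```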

comps/llms/summarization/tgi/langchain/docker_compose_llm.yaml

Lines changed: 4 additions & 1 deletion

```diff
@@ -14,7 +14,7 @@ services:
     environment:
       HF_TOKEN: ${HF_TOKEN}
     shm_size: 1g
-    command: --model-id ${LLM_MODEL_ID}
+    command: --model-id ${LLM_MODEL_ID} --max-input-length ${MAX_INPUT_TOKENS} --max-total-tokens ${MAX_TOTAL_TOKENS}
   llm:
     image: opea/llm-docsum-tgi:latest
     container_name: llm-docsum-tgi-server
@@ -27,6 +27,9 @@ services:
       https_proxy: ${https_proxy}
       TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT}
       HUGGINGFACEHUB_API_TOKEN: ${HF_TOKEN}
+      MAX_INPUT_TOKENS: ${MAX_INPUT_TOKENS}
+      MAX_TOTAL_TOKENS: ${MAX_TOTAL_TOKENS}
+      LLM_MODEL_ID: ${LLM_MODEL_ID}
     restart: unless-stopped

 networks:
```
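The README's constraint on these two environment variables (MAX_TOTAL_TOKENS > MAX_INPUT_TOKENS + max_new_tokens + 50) can be checked before launching the stack. A small sketch; the helper name is hypothetical, not part of the service:

```python
def check_token_budget(max_input_tokens, max_total_tokens, max_new_tokens, reserved_prompt=50):
    """True if TGI's limits leave room for input + generation + reserved prompt tokens,
    per the rule stated in the README."""
    return max_total_tokens > max_input_tokens + max_new_tokens + reserved_prompt


# Defaults from docker_compose_llm.yaml: 2048 + 1024 + 50 = 3122 < 4096 -> OK
print(check_token_budget(2048, 4096, 1024))  # → True
```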

comps/llms/summarization/tgi/langchain/llm.py

Lines changed: 132 additions & 26 deletions

```diff
@@ -4,10 +4,14 @@
 import os

 from fastapi.responses import StreamingResponse
-from huggingface_hub import AsyncInferenceClient
+from langchain.chains.summarize import load_summarize_chain
+from langchain.docstore.document import Document
 from langchain.prompts import PromptTemplate
+from langchain_community.llms import HuggingFaceEndpoint
+from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
+from transformers import AutoTokenizer

-from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
+from comps import CustomLogger, DocSumLLMParams, GeneratedDoc, ServiceType, opea_microservices, register_microservice
 from comps.cores.mega.utils import get_access_token

 logger = CustomLogger("llm_docsum")
@@ -17,6 +21,9 @@
 TOKEN_URL = os.getenv("TOKEN_URL")
 CLIENTID = os.getenv("CLIENTID")
 CLIENT_SECRET = os.getenv("CLIENT_SECRET")
+MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", 2048))
+MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", 4096))
+LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "Intel/neural-chat-7b-v3-3")

 templ_en = """Write a concise summary of the following:

@@ -35,70 +42,169 @@
 概况:"""


+templ_refine_en = """\
+Your job is to produce a final summary.
+We have provided an existing summary up to a certain point: {existing_answer}
+We have the opportunity to refine the existing summary (only if needed) with some more context below.
+------------
+{text}
+------------
+Given the new context, refine the original summary.
+If the context isn't useful, return the original summary.\
+"""
+
+templ_refine_zh = """\
+你的任务是生成一个最终摘要。
+我们已经提供了部分摘要:{existing_answer}
+如果有需要的话,可以通过以下更多上下文来完善现有摘要。
+------------
+{text}
+------------
+根据新上下文,完善原始摘要。
+如果上下文无用,则返回原始摘要。\
+"""
+
+
 @register_microservice(
     name="opea_service@llm_docsum",
     service_type=ServiceType.LLM,
     endpoint="/v1/chat/docsum",
     host="0.0.0.0",
     port=9000,
 )
-async def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: DocSumLLMParams):
     if logflag:
         logger.info(input)
+
     if input.language in ["en", "auto"]:
         templ = templ_en
+        templ_refine = templ_refine_en
     elif input.language in ["zh"]:
         templ = templ_zh
+        templ_refine = templ_refine_zh
     else:
         raise NotImplementedError('Please specify the input language in "en", "zh", "auto"')

-    prompt_template = PromptTemplate.from_template(templ)
-    prompt = prompt_template.format(text=input.query)
-
+    ## Prompt
+    PROMPT = PromptTemplate.from_template(templ)
+    if input.summary_type == "refine":
+        PROMPT_REFINE = PromptTemplate.from_template(templ_refine)
     if logflag:
         logger.info("After prompting:")
-        logger.info(prompt)
+        logger.info(PROMPT)
+        if input.summary_type == "refine":
+            logger.info(PROMPT_REFINE)
+
+    ## Split text
+    if input.summary_type == "stuff":
+        text_splitter = CharacterTextSplitter()
+    elif input.summary_type in ["truncate", "map_reduce", "refine"]:
+        if input.summary_type == "refine":
+            if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
+                raise RuntimeError("In refine mode, please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)")
+            max_input_tokens = min(
+                MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS
+            )  # 128 is the reserved token length for the prompt
+        else:
+            if MAX_TOTAL_TOKENS <= input.max_tokens + 50:
+                raise RuntimeError("Please set MAX_TOTAL_TOKENS larger than (max_tokens + 50)")
+            max_input_tokens = min(
+                MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS
+            )  # 50 is the reserved token length for the prompt
+        chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
+        chunk_overlap = input.chunk_overlap if input.chunk_overlap > 0 else int(0.1 * chunk_size)
+        text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+            tokenizer=tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+        )
+        if logflag:
+            logger.info(f"set chunk size to: {chunk_size}")
+            logger.info(f"set chunk overlap to: {chunk_overlap}")
+    else:
+        raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')
+    texts = text_splitter.split_text(input.query)
+    docs = [Document(page_content=t) for t in texts]
+    if logflag:
+        logger.info(f"Split input query into {len(docs)} chunks")
+        logger.info(f"The character length of the first chunk is {len(texts[0])}")

+    ## Access auth
     access_token = (
         get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
     )
-    headers = {}
+    server_kwargs = {}
     if access_token:
-        headers = {"Authorization": f"Bearer {access_token}"}
-    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
-    llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers)
+        server_kwargs["headers"] = {"Authorization": f"Bearer {access_token}"}

-    text_generation = await llm.text_generation(
-        prompt=prompt,
-        stream=input.streaming,
+    ## LLM
+    if input.streaming and input.summary_type == "map_reduce":
+        logger.info("map_reduce mode does not support streaming=True, setting streaming=False")
+        input.streaming = False
+    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
+    llm = HuggingFaceEndpoint(
+        endpoint_url=llm_endpoint,
         max_new_tokens=input.max_tokens,
-        repetition_penalty=input.repetition_penalty,
-        temperature=input.temperature,
         top_k=input.top_k,
         top_p=input.top_p,
         typical_p=input.typical_p,
+        temperature=input.temperature,
+        repetition_penalty=input.repetition_penalty,
+        streaming=input.streaming,
+        server_kwargs=server_kwargs,
     )

+    ## LLM chain
+    summary_type = input.summary_type
+    if summary_type == "stuff":
+        llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
+    elif summary_type == "truncate":
+        docs = [docs[0]]
+        llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
+    elif summary_type == "map_reduce":
+        llm_chain = load_summarize_chain(
+            llm=llm, map_prompt=PROMPT, combine_prompt=PROMPT, chain_type="map_reduce", return_intermediate_steps=True
+        )
+    elif summary_type == "refine":
+        llm_chain = load_summarize_chain(
+            llm=llm,
+            question_prompt=PROMPT,
+            refine_prompt=PROMPT_REFINE,
+            chain_type="refine",
+            return_intermediate_steps=True,
+        )
+    else:
+        raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')
+
     if input.streaming:

         async def stream_generator():
-            chat_response = ""
-            async for text in text_generation:
-                chat_response += text
-                chunk_repr = repr(text.encode("utf-8"))
+            from langserve.serialization import WellKnownLCSerializer
+
+            _serializer = WellKnownLCSerializer()
+            async for chunk in llm_chain.astream_log(docs):
+                data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
                 if logflag:
-                    logger.info(f"[ docsum - text_summarize ] chunk:{chunk_repr}")
-                yield f"data: {chunk_repr}\n\n"
-            if logflag:
-                logger.info(f"[ docsum - text_summarize ] stream response: {chat_response}")
+                    logger.info(data)
+                yield f"data: {data}\n\n"
             yield "data: [DONE]\n\n"

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
+        response = await llm_chain.ainvoke(docs)
+
+        if input.summary_type in ["map_reduce", "refine"]:
+            intermediate_steps = response["intermediate_steps"]
+            if logflag:
+                logger.info("intermediate_steps:")
+                logger.info(intermediate_steps)
+
+        output_text = response["output_text"]
         if logflag:
-            logger.info(text_generation)
-            return GeneratedDoc(text=text_generation, prompt=input.query)
+            logger.info("\n\noutput_text:")
+            logger.info(output_text)
+
+        return GeneratedDoc(text=output_text, prompt=input.query)


 if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
     opea_microservices["opea_service@llm_docsum"].start()
```
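The chunk-sizing arithmetic in llm.py can be isolated and checked on its own. A minimal sketch of the same logic; the function name and the pure-function shape are illustrative, not part of the service:

```python
def resolve_chunking(summary_type, max_tokens, max_input_tokens_env, max_total_tokens_env,
                     chunk_size=-1, chunk_overlap=-1):
    """Replicate llm.py's chunk sizing for the long-context modes.

    Returns (chunk_size, chunk_overlap); -1 inputs mean "use the computed default".
    """
    if summary_type == "refine":
        # refine holds prompt + an intermediate summary, so it reserves more budget
        reserved = 2 * max_tokens + 128
    else:
        reserved = max_tokens + 50  # reserved prompt length
    if max_total_tokens_env <= reserved:
        raise RuntimeError("MAX_TOTAL_TOKENS too small for this summary_type")
    max_input = min(max_total_tokens_env - reserved, max_input_tokens_env)
    size = min(chunk_size, max_input) if chunk_size > 0 else max_input
    overlap = chunk_overlap if chunk_overlap > 0 else int(0.1 * size)
    return size, overlap


# With the README defaults (MAX_INPUT_TOKENS=2048, MAX_TOTAL_TOKENS=4096, max_tokens=32):
print(resolve_chunking("truncate", 32, 2048, 4096))  # → (2048, 204)
```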

comps/llms/summarization/tgi/langchain/requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,5 +1,6 @@
 docarray[full]
 fastapi
+httpx==0.27.2
 huggingface_hub
 langchain #==0.1.12
 langchain-huggingface
```
