|
4 | 4 | import os |
5 | 5 |
|
6 | 6 | from fastapi.responses import StreamingResponse |
7 | | -from huggingface_hub import AsyncInferenceClient |
| 7 | +from langchain.chains.summarize import load_summarize_chain |
| 8 | +from langchain.docstore.document import Document |
8 | 9 | from langchain.prompts import PromptTemplate |
| 10 | +from langchain_community.llms import HuggingFaceEndpoint |
| 11 | +from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter |
| 12 | +from transformers import AutoTokenizer |
9 | 13 |
|
10 | | -from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice |
| 14 | +from comps import CustomLogger, DocSumLLMParams, GeneratedDoc, ServiceType, opea_microservices, register_microservice |
11 | 15 | from comps.cores.mega.utils import get_access_token |
12 | 16 |
|
logger = CustomLogger("llm_docsum")

# OAuth client credentials for obtaining an access token; all optional —
# an Authorization header is only attached when all three are set.
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
# Serving-side token budget: max tokens accepted as prompt, and max
# prompt + generated tokens the backend allows in total.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", 2048))
MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", 4096))
# Model id whose tokenizer drives token-aware chunking of the input text.
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "Intel/neural-chat-7b-v3-3")
|
21 | 28 | templ_en = """Write a concise summary of the following: |
22 | 29 |
|
|
35 | 42 | 概况:""" |
36 | 43 |
|
37 | 44 |
|
# Refine-mode prompt (English): instructs the LLM to merge an existing partial
# summary ({existing_answer}) with one more chunk of context ({text}).
# Trailing backslashes keep the template free of leading/trailing newlines.
templ_refine_en = """\
Your job is to produce a final summary.
We have provided an existing summary up to a certain point: {existing_answer}
We have the opportunity to refine the existing summary (only if needed) with some more context below.
------------
{text}
------------
Given the new context, refine the original summary.
If the context isn't useful, return the original summary.\
"""

# Refine-mode prompt (Chinese): same contract and placeholders as templ_refine_en.
templ_refine_zh = """\
你的任务是生成一个最终摘要。
我们已经提供了部分摘要:{existing_answer}
如果有需要的话,可以通过以下更多上下文来完善现有摘要。
------------
{text}
------------
根据新上下文,完善原始摘要。
如果上下文无用,则返回原始摘要。\
"""
| 66 | + |
| 67 | + |
@register_microservice(
    name="opea_service@llm_docsum",
    service_type=ServiceType.LLM,
    endpoint="/v1/chat/docsum",
    host="0.0.0.0",
    port=9000,
)
async def llm_generate(input: DocSumLLMParams):
    """Summarize ``input.query`` with a LangChain summarize chain over a TGI endpoint.

    Supports four ``summary_type`` strategies:
      - "stuff":      whole text in one prompt (simple character splitting);
      - "truncate":   summarize only the first token-budgeted chunk;
      - "map_reduce": summarize chunks independently, then combine (no streaming);
      - "refine":     iteratively refine a running summary chunk by chunk.

    Returns a ``StreamingResponse`` (SSE of serialized run-log ops) when
    ``input.streaming`` is true, otherwise a ``GeneratedDoc``.

    Raises:
        NotImplementedError: for an unsupported ``language`` or ``summary_type``.
        RuntimeError: when MAX_TOTAL_TOKENS leaves no room for generation.
    """
    if logflag:
        logger.info(input)

    # Pick the prompt templates for the requested language.
    if input.language in ["en", "auto"]:
        templ = templ_en
        templ_refine = templ_refine_en
    elif input.language in ["zh"]:
        templ = templ_zh
        templ_refine = templ_refine_zh
    else:
        raise NotImplementedError('Please specify the input language in "en", "zh", "auto"')

    ## Prompt
    PROMPT = PromptTemplate.from_template(templ)
    if input.summary_type == "refine":
        PROMPT_REFINE = PromptTemplate.from_template(templ_refine)
    if logflag:
        logger.info("After prompting:")
        logger.info(PROMPT)
        if input.summary_type == "refine":
            logger.info(PROMPT_REFINE)

    ## Split text
    if input.summary_type == "stuff":
        text_splitter = CharacterTextSplitter()
    elif input.summary_type in ["truncate", "map_reduce", "refine"]:
        if input.summary_type == "refine":
            # Refine sends prompt + previous summary, so reserve room for two
            # generations plus prompt scaffolding.
            if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
                raise RuntimeError("In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)")
            max_input_tokens = min(
                MAX_TOTAL_TOKENS - 2 * input.max_tokens - 128, MAX_INPUT_TOKENS
            )  # 128 is reserved token length for prompt
        else:
            if MAX_TOTAL_TOKENS <= input.max_tokens + 50:
                # fixed: message had an unbalanced ')' — "max_tokens + 50)"
                raise RuntimeError("Please set MAX_TOTAL_TOKENS larger than (max_tokens + 50)")
            max_input_tokens = min(
                MAX_TOTAL_TOKENS - input.max_tokens - 50, MAX_INPUT_TOKENS
            )  # 50 is reserved token length for prompt
        # User-provided chunk settings win when positive, clamped to the budget.
        chunk_size = min(input.chunk_size, max_input_tokens) if input.chunk_size > 0 else max_input_tokens
        chunk_overlap = input.chunk_overlap if input.chunk_overlap > 0 else int(0.1 * chunk_size)
        # NOTE(review): `tokenizer` is bound only under the __main__ guard below;
        # importing this module and calling llm_generate directly would raise
        # NameError — consider moving tokenizer loading to module scope.
        text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer=tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if logflag:
            logger.info(f"set chunk size to: {chunk_size}")
            logger.info(f"set chunk overlap to: {chunk_overlap}")
    else:
        raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')
    texts = text_splitter.split_text(input.query)
    docs = [Document(page_content=t) for t in texts]
    if logflag:
        logger.info(f"Split input query into {len(docs)} chunks")
        if texts:  # guard: avoid IndexError when the query splits to nothing
            logger.info(f"The character length of the first chunk is {len(texts[0])}")

    ## Access auth
    access_token = (
        get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
    )
    server_kwargs = {}
    if access_token:
        server_kwargs["headers"] = {"Authorization": f"Bearer {access_token}"}

    ## LLM
    if input.streaming and input.summary_type == "map_reduce":
        # map_reduce aggregates intermediate results, so token streaming is not possible.
        logger.info("Map Reduce mode doesn't support streaming=True, set to streaming=False")
        input.streaming = False
    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
    llm = HuggingFaceEndpoint(
        endpoint_url=llm_endpoint,
        max_new_tokens=input.max_tokens,
        top_k=input.top_k,
        top_p=input.top_p,
        typical_p=input.typical_p,
        temperature=input.temperature,
        repetition_penalty=input.repetition_penalty,
        streaming=input.streaming,
        server_kwargs=server_kwargs,
    )

    ## LLM chain
    summary_type = input.summary_type
    if summary_type == "stuff":
        llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
    elif summary_type == "truncate":
        # Keep only the first (token-budgeted) chunk and summarize it alone.
        docs = [docs[0]]
        llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
    elif summary_type == "map_reduce":
        llm_chain = load_summarize_chain(
            llm=llm, map_prompt=PROMPT, combine_prompt=PROMPT, chain_type="map_reduce", return_intermediate_steps=True
        )
    elif summary_type == "refine":
        llm_chain = load_summarize_chain(
            llm=llm,
            question_prompt=PROMPT,
            refine_prompt=PROMPT_REFINE,
            chain_type="refine",
            return_intermediate_steps=True,
        )
    else:
        raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')

    if input.streaming:

        async def stream_generator():
            # Local import: langserve is only needed on the streaming path.
            from langserve.serialization import WellKnownLCSerializer

            _serializer = WellKnownLCSerializer()
            # Forward the chain's run-log patches as server-sent events.
            async for chunk in llm_chain.astream_log(docs):
                data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
                if logflag:
                    logger.info(data)
                yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        response = await llm_chain.ainvoke(docs)

        if input.summary_type in ["map_reduce", "refine"]:
            intermediate_steps = response["intermediate_steps"]
            if logflag:
                logger.info("intermediate_steps:")
                logger.info(intermediate_steps)

        output_text = response["output_text"]
        if logflag:
            logger.info("\n\noutput_text:")
            logger.info(output_text)

        return GeneratedDoc(text=output_text, prompt=input.query)
101 | 206 |
|
102 | 207 |
|
if __name__ == "__main__":
    # Load the chunking tokenizer once at startup.
    # NOTE(review): `tokenizer` is only defined when run as a script; if this
    # module is imported and llm_generate is invoked, the name is unbound —
    # consider moving this load to module scope.
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
    opea_microservices["opea_service@llm_docsum"].start()
0 commit comments