-
Notifications
You must be signed in to change notification settings - Fork 217
Adds audio querying to MultimodalQ&A gateway #974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ae5437a
e1e5fde
70c54e1
6a71843
615459b
1753473
dcafe8d
fa47959
37826be
40d34db
6f2a753
a665c3c
4a5c8ea
d9ab567
75b135f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,14 @@ | |
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import base64 | ||
| import json | ||
| import os | ||
| from io import BytesIO | ||
| from typing import List, Union | ||
|
|
||
| import requests | ||
| from fastapi import File, Request, UploadFile | ||
| from fastapi.responses import StreamingResponse | ||
| from fastapi.responses import JSONResponse, StreamingResponse | ||
| from PIL import Image | ||
|
|
||
| from ..proto.api_protocol import ( | ||
|
|
@@ -837,6 +838,9 @@ def parser_input(data, TypeClass, key): | |
|
|
||
|
|
||
| class MultimodalQnAGateway(Gateway): | ||
| asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001)) | ||
| asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port)) | ||
|
|
||
| def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999): | ||
| self.lvm_megaservice = lvm_megaservice | ||
| super().__init__( | ||
|
|
@@ -851,7 +855,10 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", | |
| # this overrides _handle_message method of Gateway | ||
| def _handle_message(self, messages): | ||
| images = [] | ||
| audios = [] | ||
| b64_types = {} | ||
| messages_dicts = [] | ||
| decoded_audio_input = "" | ||
| if isinstance(messages, str): | ||
| prompt = messages | ||
| else: | ||
|
|
@@ -865,16 +872,26 @@ def _handle_message(self, messages): | |
| system_prompt = message["content"] | ||
| elif msg_role == "user": | ||
| if type(message["content"]) == list: | ||
| # separate each media type and store accordingly | ||
| text = "" | ||
| text_list = [item["text"] for item in message["content"] if item["type"] == "text"] | ||
| text += "\n".join(text_list) | ||
| image_list = [ | ||
| item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" | ||
| ] | ||
| if image_list: | ||
| messages_dict[msg_role] = (text, image_list) | ||
| else: | ||
| audios = [item["audio"] for item in message["content"] if item["type"] == "audio"] | ||
| if audios: | ||
| # translate audio to text. From this point forward, audio is treated like text | ||
| decoded_audio_input = self.convert_audio_to_text(audios) | ||
| b64_types["audio"] = decoded_audio_input | ||
|
|
||
| if text and not audios and not image_list: | ||
| messages_dict[msg_role] = text | ||
| elif audios and not text and not image_list: | ||
| messages_dict[msg_role] = decoded_audio_input | ||
| else: | ||
| messages_dict[msg_role] = (text, decoded_audio_input, image_list) | ||
|
|
||
| else: | ||
| messages_dict[msg_role] = message["content"] | ||
| messages_dicts.append(messages_dict) | ||
|
|
@@ -886,55 +903,84 @@ def _handle_message(self, messages): | |
|
|
||
| if system_prompt: | ||
| prompt = system_prompt + "\n" | ||
| for messages_dict in messages_dicts: | ||
| for i, (role, message) in enumerate(messages_dict.items()): | ||
| for i, messages_dict in enumerate(messages_dicts): | ||
| for role, message in messages_dict.items(): | ||
| if isinstance(message, tuple): | ||
| text, image_list = message | ||
| text, decoded_audio_input, image_list = message | ||
| if i == 0: | ||
| # do not add role for the very first message. | ||
| # this will be added by llava_server | ||
| if text: | ||
| prompt += text + "\n" | ||
| elif decoded_audio_input: | ||
| prompt += decoded_audio_input + "\n" | ||
| else: | ||
| if text: | ||
| prompt += role.upper() + ": " + text + "\n" | ||
| elif decoded_audio_input: | ||
| prompt += role.upper() + ": " + decoded_audio_input + "\n" | ||
| else: | ||
| prompt += role.upper() + ":" | ||
| for img in image_list: | ||
| # URL | ||
| if img.startswith("http://") or img.startswith("https://"): | ||
| response = requests.get(img) | ||
| image = Image.open(BytesIO(response.content)).convert("RGBA") | ||
| image_bytes = BytesIO() | ||
| image.save(image_bytes, format="PNG") | ||
| img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() | ||
| # Local Path | ||
| elif os.path.exists(img): | ||
| image = Image.open(img).convert("RGBA") | ||
| image_bytes = BytesIO() | ||
| image.save(image_bytes, format="PNG") | ||
| img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() | ||
| # Bytes | ||
| else: | ||
| img_b64_str = img | ||
|
|
||
| images.append(img_b64_str) | ||
| else: | ||
| if image_list: | ||
| for img in image_list: | ||
| # URL | ||
| if img.startswith("http://") or img.startswith("https://"): | ||
| response = requests.get(img) | ||
| image = Image.open(BytesIO(response.content)).convert("RGBA") | ||
| image_bytes = BytesIO() | ||
| image.save(image_bytes, format="PNG") | ||
| img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() | ||
| # Local Path | ||
| elif os.path.exists(img): | ||
| image = Image.open(img).convert("RGBA") | ||
| image_bytes = BytesIO() | ||
| image.save(image_bytes, format="PNG") | ||
| img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() | ||
| # Bytes | ||
| else: | ||
| img_b64_str = img | ||
|
|
||
| images.append(img_b64_str) | ||
|
|
||
| elif isinstance(message, str): | ||
| if i == 0: | ||
| # do not add role for the very first message. | ||
| # this will be added by llava_server | ||
| if message: | ||
| prompt += role.upper() + ": " + message + "\n" | ||
| prompt += message + "\n" | ||
| else: | ||
| if message: | ||
| prompt += role.upper() + ": " + message + "\n" | ||
| else: | ||
| prompt += role.upper() + ":" | ||
|
|
||
| if images: | ||
| return prompt, images | ||
| b64_types["image"] = images | ||
|
|
||
| # If the query has multiple media types, return all types | ||
| if prompt and b64_types: | ||
| return prompt, b64_types | ||
| else: | ||
| return prompt | ||
|
|
||
| def convert_audio_to_text(self, audio): | ||
| # translate audio to text by passing in dictionary to ASR | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. comment quirky! dictionary is a data type here but can get mixed with the English word dictionary (word meanings) |
||
| if isinstance(audio, dict): | ||
| input_dict = {"byte_str": audio["audio"][0]} | ||
| else: | ||
| input_dict = {"byte_str": audio[0]} | ||
|
|
||
| response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None}) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should proxies be read from some environment variable for a more general solution?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this is setting proxies in the first place, shouldn't those be set well before this point? |
||
|
|
||
| if response.status_code != 200: | ||
| return JSONResponse( | ||
| status_code=503, content={"message": "Unable to convert audio to text. {}".format(response.text)} | ||
| ) | ||
|
|
||
| response = response.json() | ||
| return response["query"] | ||
|
|
||
| async def handle_request(self, request: Request): | ||
| data = await request.json() | ||
| stream_opt = bool(data.get("stream", False)) | ||
|
|
@@ -943,16 +989,35 @@ async def handle_request(self, request: Request): | |
| stream_opt = False | ||
| chat_request = ChatCompletionRequest.model_validate(data) | ||
| # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. | ||
| prompt_and_image = self._handle_message(chat_request.messages) | ||
| if isinstance(prompt_and_image, tuple): | ||
| # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice") | ||
| prompt, images = prompt_and_image | ||
| num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 | ||
| messages = self._handle_message(chat_request.messages) | ||
| decoded_audio_input = "" | ||
|
|
||
| if num_messages > 1: | ||
| # This is a follow up query, go to LVM | ||
| cur_megaservice = self.lvm_megaservice | ||
| initial_inputs = {"prompt": prompt, "image": images[0]} | ||
| if isinstance(messages, tuple): | ||
| prompt, b64_types = messages | ||
| if "audio" in b64_types: | ||
| # for metadata storage purposes | ||
| decoded_audio_input = b64_types["audio"] | ||
| if "image" in b64_types: | ||
| initial_inputs = {"prompt": prompt, "image": b64_types["image"][0]} | ||
| else: | ||
| initial_inputs = {"prompt": prompt, "image": ""} | ||
| else: | ||
| prompt = messages | ||
| initial_inputs = {"prompt": prompt, "image": ""} | ||
| else: | ||
| # print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice") | ||
| prompt = prompt_and_image | ||
| # This is the first query. Ignore image input | ||
| cur_megaservice = self.megaservice | ||
| if isinstance(messages, tuple): | ||
| prompt, b64_types = messages | ||
| if "audio" in b64_types: | ||
| # for metadata storage purposes | ||
| decoded_audio_input = b64_types["audio"] | ||
| else: | ||
| prompt = messages | ||
| initial_inputs = {"text": prompt} | ||
|
|
||
| parameters = LLMParams( | ||
|
|
@@ -985,18 +1050,24 @@ async def handle_request(self, request: Request): | |
| if "text" in result_dict[last_node].keys(): | ||
| response = result_dict[last_node]["text"] | ||
| else: | ||
| # text in not response message | ||
| # text is not in response message | ||
| # something wrong, for example due to empty retrieval results | ||
| if "detail" in result_dict[last_node].keys(): | ||
| response = result_dict[last_node]["detail"] | ||
| else: | ||
| response = "The server fail to generate answer to your query!" | ||
| response = "The server failed to generate an answer to your query!" | ||
| if "metadata" in result_dict[last_node].keys(): | ||
| # from retrieval results | ||
| metadata = result_dict[last_node]["metadata"] | ||
| if decoded_audio_input: | ||
| metadata["audio"] = decoded_audio_input | ||
| else: | ||
| # follow-up question, no retrieval | ||
| metadata = None | ||
| if decoded_audio_input: | ||
| metadata = {"audio": decoded_audio_input} | ||
| else: | ||
| metadata = None | ||
|
|
||
| choices = [] | ||
| usage = UsageInfo() | ||
| choices.append( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,15 +2,20 @@ | |
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import json | ||
| import os | ||
| import unittest | ||
| from typing import Union | ||
|
|
||
| import requests | ||
| from fastapi import Request | ||
|
|
||
| os.environ["ASR_SERVICE_PORT"] = "8086" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this overrides environment, instead of taking the value from environment? |
||
|
|
||
| from comps import ( | ||
| Base64ByteStrDoc, | ||
| EmbedDoc, | ||
| EmbedMultimodalDoc, | ||
| LLMParamsDoc, | ||
| LVMDoc, | ||
| LVMSearchedMultimodalDoc, | ||
| MultimodalDoc, | ||
|
|
@@ -72,15 +77,25 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: | |
| return res | ||
|
|
||
|
|
||
| @register_microservice(name="asr", host="0.0.0.0", port=8086, endpoint="/v1/audio/transcriptions") | ||
| async def asr_add(request: Base64ByteStrDoc) -> LLMParamsDoc: | ||
| req = request.model_dump_json() | ||
| res = {} | ||
| res["query"] = "you" | ||
| return res | ||
|
|
||
|
|
||
| class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): | ||
| @classmethod | ||
| def setUpClass(cls): | ||
| cls.mm_embedding = opea_microservices["mm_embedding"] | ||
| cls.mm_retriever = opea_microservices["mm_retriever"] | ||
| cls.lvm = opea_microservices["lvm"] | ||
| cls.asr = opea_microservices["asr"] | ||
| cls.mm_embedding.start() | ||
| cls.mm_retriever.start() | ||
| cls.lvm.start() | ||
| cls.asr.start() | ||
|
|
||
| cls.service_builder = ServiceOrchestrator() | ||
|
|
||
|
|
@@ -100,6 +115,7 @@ def tearDownClass(cls): | |
| cls.mm_embedding.stop() | ||
| cls.mm_retriever.stop() | ||
| cls.lvm.stop() | ||
| cls.asr.stop() | ||
| cls.gateway.stop() | ||
|
|
||
| async def test_service_builder_schedule(self): | ||
|
|
@@ -181,6 +197,21 @@ def test_handle_message_with_system_prompt(self): | |
| prompt, images = self.gateway._handle_message(messages) | ||
| self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") | ||
|
|
||
| def test_handle_message_with_audio(self): | ||
| messages = [ | ||
| {"role": "user", "content": [{"type": "text", "text": "hello, "}]}, | ||
| {"role": "assistant", "content": "opea project! "}, | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| {"type": "audio", "audio": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"} | ||
| ], | ||
| }, | ||
| ] | ||
| prompt, b64_types = self.gateway._handle_message(messages) | ||
| self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: you\n") | ||
| self.assertEqual(b64_types, {"audio": "you"}) | ||
|
|
||
| async def test_handle_request(self): | ||
| json_data = { | ||
| "messages": [ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.