Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 109 additions & 38 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import json
import os
from io import BytesIO
from typing import List, Union

import requests
from fastapi import File, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image

from ..proto.api_protocol import (
Expand Down Expand Up @@ -837,6 +838,9 @@ def parser_input(data, TypeClass, key):


class MultimodalQnAGateway(Gateway):
asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001))
asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port))

def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999):
self.lvm_megaservice = lvm_megaservice
super().__init__(
Expand All @@ -851,7 +855,10 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0",
# this overrides _handle_message method of Gateway
def _handle_message(self, messages):
images = []
audios = []
b64_types = {}
messages_dicts = []
decoded_audio_input = ""
if isinstance(messages, str):
prompt = messages
else:
Expand All @@ -865,16 +872,26 @@ def _handle_message(self, messages):
system_prompt = message["content"]
elif msg_role == "user":
if type(message["content"]) == list:
# separate each media type and store accordingly
text = ""
text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
text += "\n".join(text_list)
image_list = [
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
]
if image_list:
messages_dict[msg_role] = (text, image_list)
else:
audios = [item["audio"] for item in message["content"] if item["type"] == "audio"]
if audios:
# translate audio to text. From this point forward, audio is treated like text
decoded_audio_input = self.convert_audio_to_text(audios)
b64_types["audio"] = decoded_audio_input

if text and not audios and not image_list:
messages_dict[msg_role] = text
elif audios and not text and not image_list:
messages_dict[msg_role] = decoded_audio_input
else:
messages_dict[msg_role] = (text, decoded_audio_input, image_list)

else:
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
Expand All @@ -886,55 +903,84 @@ def _handle_message(self, messages):

if system_prompt:
prompt = system_prompt + "\n"
for messages_dict in messages_dicts:
for i, (role, message) in enumerate(messages_dict.items()):
for i, messages_dict in enumerate(messages_dicts):
for role, message in messages_dict.items():
if isinstance(message, tuple):
text, image_list = message
text, decoded_audio_input, image_list = message
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if text:
prompt += text + "\n"
elif decoded_audio_input:
prompt += decoded_audio_input + "\n"
else:
if text:
prompt += role.upper() + ": " + text + "\n"
elif decoded_audio_input:
prompt += role.upper() + ": " + decoded_audio_input + "\n"
else:
prompt += role.upper() + ":"
for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Bytes
else:
img_b64_str = img

images.append(img_b64_str)
else:
if image_list:
for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Bytes
else:
img_b64_str = img

images.append(img_b64_str)

elif isinstance(message, str):
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if message:
prompt += role.upper() + ": " + message + "\n"
prompt += message + "\n"
else:
if message:
prompt += role.upper() + ": " + message + "\n"
else:
prompt += role.upper() + ":"

if images:
return prompt, images
b64_types["image"] = images

# If the query has multiple media types, return all types
if prompt and b64_types:
return prompt, b64_types
else:
return prompt

def convert_audio_to_text(self, audio):
    """Transcribe a base64-encoded audio payload to text via the ASR microservice.

    Args:
        audio: Either a dict with an ``"audio"`` key holding a list of
            base64-encoded byte strings, or a bare list of such strings.
            Only the first entry is sent for transcription.

    Returns:
        str: The transcribed text on success; otherwise a 503 ``JSONResponse``
        describing the ASR failure (propagated to the caller as-is).
    """
    # Normalize both accepted input shapes into the ASR request payload.
    if isinstance(audio, dict):
        input_dict = {"byte_str": audio["audio"][0]}
    else:
        input_dict = {"byte_str": audio[0]}

    # NOTE(review): the http proxy is disabled explicitly here; consider
    # reading proxy settings from the environment for a more general solution.
    # A timeout is set so a hung ASR service cannot block the gateway forever.
    response = requests.post(
        self.asr_endpoint,
        data=json.dumps(input_dict),
        proxies={"http": None},
        timeout=60,
    )

    if response.status_code != 200:
        # Surface the ASR failure instead of raising, matching the gateway's
        # JSON error-response convention.
        return JSONResponse(
            status_code=503, content={"message": "Unable to convert audio to text. {}".format(response.text)}
        )

    response = response.json()
    return response["query"]

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = bool(data.get("stream", False))
Expand All @@ -943,16 +989,35 @@ async def handle_request(self, request: Request):
stream_opt = False
chat_request = ChatCompletionRequest.model_validate(data)
# Multimodal RAG QnA With Videos has not yet accepts image as input during QnA.
prompt_and_image = self._handle_message(chat_request.messages)
if isinstance(prompt_and_image, tuple):
# print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
prompt, images = prompt_and_image
num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1
messages = self._handle_message(chat_request.messages)
decoded_audio_input = ""

if num_messages > 1:
# This is a follow up query, go to LVM
cur_megaservice = self.lvm_megaservice
initial_inputs = {"prompt": prompt, "image": images[0]}
if isinstance(messages, tuple):
prompt, b64_types = messages
if "audio" in b64_types:
# for metadata storage purposes
decoded_audio_input = b64_types["audio"]
if "image" in b64_types:
initial_inputs = {"prompt": prompt, "image": b64_types["image"][0]}
else:
initial_inputs = {"prompt": prompt, "image": ""}
else:
prompt = messages
initial_inputs = {"prompt": prompt, "image": ""}
else:
# print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
prompt = prompt_and_image
# This is the first query. Ignore image input
cur_megaservice = self.megaservice
if isinstance(messages, tuple):
prompt, b64_types = messages
if "audio" in b64_types:
# for metadata storage purposes
decoded_audio_input = b64_types["audio"]
else:
prompt = messages
initial_inputs = {"text": prompt}

parameters = LLMParams(
Expand Down Expand Up @@ -985,18 +1050,24 @@ async def handle_request(self, request: Request):
if "text" in result_dict[last_node].keys():
response = result_dict[last_node]["text"]
else:
# text in not response message
# text is not in response message
# something wrong, for example due to empty retrieval results
if "detail" in result_dict[last_node].keys():
response = result_dict[last_node]["detail"]
else:
response = "The server fail to generate answer to your query!"
response = "The server failed to generate an answer to your query!"
if "metadata" in result_dict[last_node].keys():
# from retrieval results
metadata = result_dict[last_node]["metadata"]
if decoded_audio_input:
metadata["audio"] = decoded_audio_input
else:
# follow-up question, no retrieval
metadata = None
if decoded_audio_input:
metadata = {"audio": decoded_audio_input}
else:
metadata = None

choices = []
usage = UsageInfo()
choices.append(
Expand Down
31 changes: 31 additions & 0 deletions tests/cores/mega/test_multimodalqna_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@
# SPDX-License-Identifier: Apache-2.0

import json
import os
import unittest
from typing import Union

import requests
from fastapi import Request

os.environ["ASR_SERVICE_PORT"] = "8086"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this overrides environment, instead of taking the value from environment?


from comps import (
Base64ByteStrDoc,
EmbedDoc,
EmbedMultimodalDoc,
LLMParamsDoc,
LVMDoc,
LVMSearchedMultimodalDoc,
MultimodalDoc,
Expand Down Expand Up @@ -72,15 +77,25 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
return res


@register_microservice(name="asr", host="0.0.0.0", port=8086, endpoint="/v1/audio/transcriptions")
async def asr_add(request: Base64ByteStrDoc) -> LLMParamsDoc:
    """Mock ASR service: return the fixed transcription "you" for any audio input.

    The incoming request payload is intentionally ignored; this stub only needs
    to produce a deterministic transcription for the gateway tests.
    """
    return {"query": "you"}


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
cls.mm_embedding = opea_microservices["mm_embedding"]
cls.mm_retriever = opea_microservices["mm_retriever"]
cls.lvm = opea_microservices["lvm"]
cls.asr = opea_microservices["asr"]
cls.mm_embedding.start()
cls.mm_retriever.start()
cls.lvm.start()
cls.asr.start()

cls.service_builder = ServiceOrchestrator()

Expand All @@ -100,6 +115,7 @@ def tearDownClass(cls):
cls.mm_embedding.stop()
cls.mm_retriever.stop()
cls.lvm.stop()
cls.asr.stop()
cls.gateway.stop()

async def test_service_builder_schedule(self):
Expand Down Expand Up @@ -181,6 +197,21 @@ def test_handle_message_with_system_prompt(self):
prompt, images = self.gateway._handle_message(messages)
self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n")

def test_handle_message_with_audio(self):
    """Audio content items are transcribed (by the mock ASR service, which
    always returns "you") and then treated like text: the transcription is
    spliced into the assembled prompt and returned in b64_types["audio"]."""
    messages = [
        {"role": "user", "content": [{"type": "text", "text": "hello, "}]},
        {"role": "assistant", "content": "opea project! "},
        {
            "role": "user",
            "content": [
                # Base64-encoded WAV payload; its actual contents are irrelevant
                # because the mock ASR endpoint returns a fixed transcription.
                {"type": "audio", "audio": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"}
            ],
        },
    ]
    prompt, b64_types = self.gateway._handle_message(messages)
    # The first user turn carries no role prefix; later turns are prefixed.
    self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: you\n")
    self.assertEqual(b64_types, {"audio": "you"})

async def test_handle_request(self):
json_data = {
"messages": [
Expand Down