mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 14:11:14 +00:00
feat(core): Support RAG chat flow (#1185)
This commit is contained in:
@@ -370,22 +370,16 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
return FlowCategory.COMMON
|
||||
|
||||
|
||||
def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
|
||||
try:
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
||||
except ImportError:
|
||||
OpenAIStreamingOutputOperator = None
|
||||
def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
|
||||
if is_class:
|
||||
return (
|
||||
obj == str
|
||||
or obj == CommonLLMHttpResponseBody
|
||||
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
|
||||
output_obj == str
|
||||
or output_obj == CommonLLMHttpResponseBody
|
||||
or output_obj == ModelOutput
|
||||
)
|
||||
else:
|
||||
chat_types = (str, CommonLLMHttpResponseBody)
|
||||
if OpenAIStreamingOutputOperator:
|
||||
chat_types += (OpenAIStreamingOutputOperator,)
|
||||
return isinstance(obj, chat_types)
|
||||
return isinstance(output_obj, chat_types)
|
||||
|
||||
|
||||
async def _chat_with_dag_task(
|
||||
@@ -439,29 +433,50 @@ async def _chat_with_dag_task(
|
||||
yield f"data:{full_text}\n\n"
|
||||
else:
|
||||
async for output in await task.call_stream(request):
|
||||
str_msg = ""
|
||||
should_return = False
|
||||
if isinstance(output, str):
|
||||
if output.strip():
|
||||
yield output
|
||||
str_msg = output
|
||||
elif isinstance(output, ModelOutput):
|
||||
if output.error_code != 0:
|
||||
str_msg = f"[SERVER_ERROR]{output.text}"
|
||||
should_return = True
|
||||
else:
|
||||
str_msg = output.text
|
||||
else:
|
||||
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
|
||||
return
|
||||
str_msg = (
|
||||
f"[SERVER_ERROR]The output is not a valid format"
|
||||
f"({type(output)})"
|
||||
)
|
||||
should_return = True
|
||||
if str_msg:
|
||||
str_msg = str_msg.replace("\n", "\\n")
|
||||
yield f"data:{str_msg}\n\n"
|
||||
if should_return:
|
||||
return
|
||||
else:
|
||||
result = await task.call(request)
|
||||
str_msg = ""
|
||||
if result is None:
|
||||
yield "data:[SERVER_ERROR]The result is None\n\n"
|
||||
str_msg = "[SERVER_ERROR]The result is None!"
|
||||
elif isinstance(result, str):
|
||||
yield f"data:{result}\n\n"
|
||||
str_msg = result
|
||||
elif isinstance(result, ModelOutput):
|
||||
if result.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
||||
str_msg = f"[SERVER_ERROR]{result.text}"
|
||||
else:
|
||||
yield f"data:{result.text}\n\n"
|
||||
str_msg = result.text
|
||||
elif isinstance(result, CommonLLMHttpResponseBody):
|
||||
if result.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
||||
str_msg = f"[SERVER_ERROR]{result.text}"
|
||||
else:
|
||||
yield f"data:{result.text}\n\n"
|
||||
str_msg = result.text
|
||||
elif isinstance(result, dict):
|
||||
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
str_msg = json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
|
||||
str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"
|
||||
|
||||
if str_msg:
|
||||
str_msg = str_msg.replace("\n", "\\n")
|
||||
yield f"data:{str_msg}\n\n"
|
||||
|
Reference in New Issue
Block a user