feat(core): Support RAG chat flow (#1185)

This commit is contained in:
Fangyin Cheng
2024-02-23 11:44:44 +08:00
committed by GitHub
parent 21682575f5
commit e0986198a6
9 changed files with 134 additions and 54 deletions

View File

@@ -370,22 +370,16 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return FlowCategory.COMMON
def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
try:
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
except ImportError:
OpenAIStreamingOutputOperator = None
def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
if is_class:
return (
obj == str
or obj == CommonLLMHttpResponseBody
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
output_obj == str
or output_obj == CommonLLMHttpResponseBody
or output_obj == ModelOutput
)
else:
chat_types = (str, CommonLLMHttpResponseBody)
if OpenAIStreamingOutputOperator:
chat_types += (OpenAIStreamingOutputOperator,)
return isinstance(obj, chat_types)
return isinstance(output_obj, chat_types)
async def _chat_with_dag_task(
@@ -439,29 +433,50 @@ async def _chat_with_dag_task(
yield f"data:{full_text}\n\n"
else:
async for output in await task.call_stream(request):
str_msg = ""
should_return = False
if isinstance(output, str):
if output.strip():
yield output
str_msg = output
elif isinstance(output, ModelOutput):
if output.error_code != 0:
str_msg = f"[SERVER_ERROR]{output.text}"
should_return = True
else:
str_msg = output.text
else:
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
return
str_msg = (
f"[SERVER_ERROR]The output is not a valid format"
f"({type(output)})"
)
should_return = True
if str_msg:
str_msg = str_msg.replace("\n", "\\n")
yield f"data:{str_msg}\n\n"
if should_return:
return
else:
result = await task.call(request)
str_msg = ""
if result is None:
yield "data:[SERVER_ERROR]The result is None\n\n"
str_msg = "[SERVER_ERROR]The result is None!"
elif isinstance(result, str):
yield f"data:{result}\n\n"
str_msg = result
elif isinstance(result, ModelOutput):
if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n"
str_msg = f"[SERVER_ERROR]{result.text}"
else:
yield f"data:{result.text}\n\n"
str_msg = result.text
elif isinstance(result, CommonLLMHttpResponseBody):
if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n"
str_msg = f"[SERVER_ERROR]{result.text}"
else:
yield f"data:{result.text}\n\n"
str_msg = result.text
elif isinstance(result, dict):
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
str_msg = json.dumps(result, ensure_ascii=False)
else:
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"
if str_msg:
str_msg = str_msg.replace("\n", "\\n")
yield f"data:{str_msg}\n\n"