feat(core): Support more chat flows (#1180)

Author: Fangyin Cheng
Date: 2024-02-22 12:19:04 +08:00 (committed by GitHub)
Parent: 16fa68d4f2
Commit: ab5e1c7ea1
10 changed files with 175 additions and 55 deletions
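
The commit standardizes flow chat output on server-sent-events style frames: `data:<payload>\n\n` for content and `data:[SERVER_ERROR]<message>\n\n` for in-band failures. A minimal client sketch for consuming such a stream follows; the endpoint path, port, and request body here are illustrative assumptions, not taken from this commit.

import asyncio

import httpx  # any streaming-capable HTTP client works


async def consume_flow_chat(base_url: str, flow_uid: str) -> None:
    # NOTE: the endpoint path and body below are hypothetical examples.
    body = {"messages": "hello", "stream": True}
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream(
            "POST", f"/api/v1/awel/trigger/chat/{flow_uid}", json=body
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data:"):
                    continue  # skip blank separators between frames
                payload = line[len("data:"):].strip()
                if payload.startswith("[SERVER_ERROR]"):
                    raise RuntimeError(payload[len("[SERVER_ERROR]"):])
                # Full-text frames escape newlines as "\n" (see the diff below).
                print(payload.replace("\\n", "\n"))


if __name__ == "__main__":
    asyncio.run(consume_flow_chat("http://localhost:5670", "<flow_uid>"))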


@@ -1,6 +1,7 @@
 import json
 import logging
 import traceback
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast
 
 from fastapi import HTTPException
@@ -14,6 +15,7 @@ from dbgpt.core.awel import (
 from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
+from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.serve.core import BaseService
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
@@ -276,12 +278,39 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         return self.dao.get_list_page(request, page, page_size)
 
-    async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
+    async def chat_flow(
+        self,
+        flow_uid: str,
+        request: CommonLLMHttpRequestBody,
+        incremental: bool = False,
+    ):
         """Chat with the AWEL flow.
 
         Args:
             flow_uid (str): The flow uid
             request (CommonLLMHttpRequestBody): The request
+            incremental (bool): Whether to return the result incrementally
         """
+        try:
+            async for output in self._call_chat_flow(flow_uid, request, incremental):
+                yield output
+        except HTTPException as e:
+            yield f"data:[SERVER_ERROR]{e.detail}\n\n"
+        except Exception as e:
+            yield f"data:[SERVER_ERROR]{str(e)}\n\n"
+
+    async def _call_chat_flow(
+        self,
+        flow_uid: str,
+        request: CommonLLMHttpRequestBody,
+        incremental: bool = False,
+    ):
+        """Chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+            incremental (bool): Whether to return the result incrementally
+        """
         flow = self.get({"uid": flow_uid})
         if not flow:
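
The `chat_flow` wrapper added above exists because a streaming HTTP response cannot switch to an error status once the first frame has been flushed, so failures must travel in-band. The same pattern in isolation (hypothetical names, not part of the diff):

from typing import AsyncIterator


async def _risky_stream() -> AsyncIterator[str]:
    # Hypothetical inner generator standing in for _call_chat_flow.
    yield "data:hello\n\n"
    raise ValueError("boom")


async def safe_stream() -> AsyncIterator[str]:
    try:
        async for frame in _risky_stream():
            yield frame
    except Exception as e:
        # The status line is already on the wire; the error becomes a
        # regular frame the client can recognize by its prefix.
        yield f"data:[SERVER_ERROR]{str(e)}\n\n"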
@@ -291,18 +320,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
             raise HTTPException(
                 status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
             )
-        if flow.flow_category != FlowCategory.CHAT_FLOW:
-            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         dag = self.dag_manager.dag_map[dag_id]
+        if (
+            flow.flow_category != FlowCategory.CHAT_FLOW
+            and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
+        ):
+            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
         end_node = cast(BaseOperator, leaf_nodes[0])
-        if request.stream:
-            async for output in await end_node.call_stream(request):
-                yield output
-        else:
-            yield await end_node.call(request)
+        async for output in _chat_with_dag_task(end_node, request, incremental):
+            yield output
 
     def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         """Parse the flow category
@@ -335,9 +364,104 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         output = leaf_node.metadata.outputs[0]
         try:
             real_class = _get_type_cls(output.type_cls)
-            if common_http_trigger and (
-                real_class == str or real_class == CommonLLMHttpResponseBody
-            ):
+            if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
                 return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
+
+
+def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
+    try:
+        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+    except ImportError:
+        OpenAIStreamingOutputOperator = None
+    if is_class:
+        return (
+            obj == str
+            or obj == CommonLLMHttpResponseBody
+            or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
+        )
+    else:
+        chat_types = (str, CommonLLMHttpResponseBody)
+        if OpenAIStreamingOutputOperator:
+            chat_types += (OpenAIStreamingOutputOperator,)
+        return isinstance(obj, chat_types)
+
+
+async def _chat_with_dag_task(
+    task: BaseOperator,
+    request: CommonLLMHttpRequestBody,
+    incremental: bool = False,
+):
+    """Chat with the DAG task.
+
+    Args:
+        task (BaseOperator): The task
+        request (CommonLLMHttpRequestBody): The request
+    """
+    if request.stream and task.streaming_operator:
+        try:
+            from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+        except ImportError:
+            OpenAIStreamingOutputOperator = None
+        if incremental:
+            async for output in await task.call_stream(request):
+                yield output
+        else:
+            if OpenAIStreamingOutputOperator and isinstance(
+                task, OpenAIStreamingOutputOperator
+            ):
+                from fastchat.protocol.openai_api_protocol import (
+                    ChatCompletionResponseStreamChoice,
+                )
+
+                previous_text = ""
+                async for output in await task.call_stream(request):
+                    if not isinstance(output, str):
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+                    if output == "data: [DONE]\n\n":
+                        return
+                    json_data = "".join(output.split("data: ")[1:])
+                    dict_data = json.loads(json_data)
+                    if "choices" not in dict_data:
+                        error_msg = dict_data.get("text", "Unknown error")
+                        yield f"data:[SERVER_ERROR]{error_msg}\n\n"
+                        return
+                    choices = dict_data["choices"]
+                    if choices:
+                        choice = choices[0]
+                        delta_data = ChatCompletionResponseStreamChoice(**choice)
+                        if delta_data.delta.content:
+                            previous_text += delta_data.delta.content
+                        if previous_text:
+                            full_text = previous_text.replace("\n", "\\n")
+                            yield f"data:{full_text}\n\n"
+            else:
+                async for output in await task.call_stream(request):
+                    if isinstance(output, str):
+                        if output.strip():
+                            yield output
+                    else:
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+    else:
+        result = await task.call(request)
+        if result is None:
+            yield "data:[SERVER_ERROR]The result is None\n\n"
+        elif isinstance(result, str):
+            yield f"data:{result}\n\n"
+        elif isinstance(result, ModelOutput):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, CommonLLMHttpResponseBody):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, dict):
+            yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
+        else:
+            yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"