mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
feat(core): Support more chat flows (#1180)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -14,6 +15,7 @@ from dbgpt.core.awel import (
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
|
||||
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
|
||||
from dbgpt.core.interface.llm import ModelOutput
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
@@ -276,12 +278,39 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""
|
||||
return self.dao.get_list_page(request, page, page_size)
|
||||
|
||||
async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
|
||||
async def chat_flow(
|
||||
self,
|
||||
flow_uid: str,
|
||||
request: CommonLLMHttpRequestBody,
|
||||
incremental: bool = False,
|
||||
):
|
||||
"""Chat with the AWEL flow.
|
||||
|
||||
Args:
|
||||
flow_uid (str): The flow uid
|
||||
request (CommonLLMHttpRequestBody): The request
|
||||
incremental (bool): Whether to return the result incrementally
|
||||
"""
|
||||
try:
|
||||
async for output in self._call_chat_flow(flow_uid, request, incremental):
|
||||
yield output
|
||||
except HTTPException as e:
|
||||
yield f"data:[SERVER_ERROR]{e.detail}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
|
||||
|
||||
async def _call_chat_flow(
|
||||
self,
|
||||
flow_uid: str,
|
||||
request: CommonLLMHttpRequestBody,
|
||||
incremental: bool = False,
|
||||
):
|
||||
"""Chat with the AWEL flow.
|
||||
|
||||
Args:
|
||||
flow_uid (str): The flow uid
|
||||
request (CommonLLMHttpRequestBody): The request
|
||||
incremental (bool): Whether to return the result incrementally
|
||||
"""
|
||||
flow = self.get({"uid": flow_uid})
|
||||
if not flow:
|
||||
@@ -291,18 +320,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
|
||||
)
|
||||
if flow.flow_category != FlowCategory.CHAT_FLOW:
|
||||
raise ValueError(f"Flow {flow_uid} is not a chat flow")
|
||||
dag = self.dag_manager.dag_map[dag_id]
|
||||
if (
|
||||
flow.flow_category != FlowCategory.CHAT_FLOW
|
||||
and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
|
||||
):
|
||||
raise ValueError(f"Flow {flow_uid} is not a chat flow")
|
||||
leaf_nodes = dag.leaf_nodes
|
||||
if len(leaf_nodes) != 1:
|
||||
raise ValueError("Chat Flow just support one leaf node in dag")
|
||||
end_node = cast(BaseOperator, leaf_nodes[0])
|
||||
if request.stream:
|
||||
async for output in await end_node.call_stream(request):
|
||||
yield output
|
||||
else:
|
||||
yield await end_node.call(request)
|
||||
async for output in _chat_with_dag_task(end_node, request, incremental):
|
||||
yield output
|
||||
|
||||
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
|
||||
"""Parse the flow category
|
||||
@@ -335,9 +364,104 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
output = leaf_node.metadata.outputs[0]
|
||||
try:
|
||||
real_class = _get_type_cls(output.type_cls)
|
||||
if common_http_trigger and (
|
||||
real_class == str or real_class == CommonLLMHttpResponseBody
|
||||
):
|
||||
if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
|
||||
return FlowCategory.CHAT_FLOW
|
||||
except Exception:
|
||||
return FlowCategory.COMMON
|
||||
|
||||
|
||||
def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
|
||||
try:
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
||||
except ImportError:
|
||||
OpenAIStreamingOutputOperator = None
|
||||
if is_class:
|
||||
return (
|
||||
obj == str
|
||||
or obj == CommonLLMHttpResponseBody
|
||||
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
|
||||
)
|
||||
else:
|
||||
chat_types = (str, CommonLLMHttpResponseBody)
|
||||
if OpenAIStreamingOutputOperator:
|
||||
chat_types += (OpenAIStreamingOutputOperator,)
|
||||
return isinstance(obj, chat_types)
|
||||
|
||||
|
||||
async def _chat_with_dag_task(
|
||||
task: BaseOperator,
|
||||
request: CommonLLMHttpRequestBody,
|
||||
incremental: bool = False,
|
||||
):
|
||||
"""Chat with the DAG task.
|
||||
|
||||
Args:
|
||||
task (BaseOperator): The task
|
||||
request (CommonLLMHttpRequestBody): The request
|
||||
"""
|
||||
if request.stream and task.streaming_operator:
|
||||
try:
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
||||
except ImportError:
|
||||
OpenAIStreamingOutputOperator = None
|
||||
if incremental:
|
||||
async for output in await task.call_stream(request):
|
||||
yield output
|
||||
else:
|
||||
if OpenAIStreamingOutputOperator and isinstance(
|
||||
task, OpenAIStreamingOutputOperator
|
||||
):
|
||||
from fastchat.protocol.openai_api_protocol import (
|
||||
ChatCompletionResponseStreamChoice,
|
||||
)
|
||||
|
||||
previous_text = ""
|
||||
async for output in await task.call_stream(request):
|
||||
if not isinstance(output, str):
|
||||
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
|
||||
return
|
||||
if output == "data: [DONE]\n\n":
|
||||
return
|
||||
json_data = "".join(output.split("data: ")[1:])
|
||||
dict_data = json.loads(json_data)
|
||||
if "choices" not in dict_data:
|
||||
error_msg = dict_data.get("text", "Unknown error")
|
||||
yield f"data:[SERVER_ERROR]{error_msg}\n\n"
|
||||
return
|
||||
choices = dict_data["choices"]
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
delta_data = ChatCompletionResponseStreamChoice(**choice)
|
||||
if delta_data.delta.content:
|
||||
previous_text += delta_data.delta.content
|
||||
if previous_text:
|
||||
full_text = previous_text.replace("\n", "\\n")
|
||||
yield f"data:{full_text}\n\n"
|
||||
else:
|
||||
async for output in await task.call_stream(request):
|
||||
if isinstance(output, str):
|
||||
if output.strip():
|
||||
yield output
|
||||
else:
|
||||
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
|
||||
return
|
||||
else:
|
||||
result = await task.call(request)
|
||||
if result is None:
|
||||
yield "data:[SERVER_ERROR]The result is None\n\n"
|
||||
elif isinstance(result, str):
|
||||
yield f"data:{result}\n\n"
|
||||
elif isinstance(result, ModelOutput):
|
||||
if result.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
||||
else:
|
||||
yield f"data:{result.text}\n\n"
|
||||
elif isinstance(result, CommonLLMHttpResponseBody):
|
||||
if result.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
||||
else:
|
||||
yield f"data:{result.text}\n\n"
|
||||
elif isinstance(result, dict):
|
||||
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
|
||||
|
Reference in New Issue
Block a user