mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 09:37:03 +00:00
feat(core): Support more chat flows (#1180)
This commit is contained in:
@@ -1 +1 @@
|
||||
version = "0.4.7"
|
||||
version = "0.5.0"
|
||||
|
@@ -366,11 +366,7 @@ async def chat_completions(
|
||||
context=flow_ctx,
|
||||
)
|
||||
return StreamingResponse(
|
||||
flow_stream_generator(
|
||||
flow_service.chat_flow(dialogue.select_param, flow_req),
|
||||
dialogue.incremental,
|
||||
dialogue.model_name,
|
||||
),
|
||||
flow_service.chat_flow(dialogue.select_param, flow_req),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -426,32 +422,6 @@ async def no_stream_generator(chat):
|
||||
yield f"data: {msg}\n\n"
|
||||
|
||||
|
||||
async def flow_stream_generator(func, incremental: bool, model_name: str):
|
||||
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
|
||||
previous_response = ""
|
||||
async for chunk in func:
|
||||
if chunk:
|
||||
msg = chunk.replace("\ufffd", "")
|
||||
if incremental:
|
||||
incremental_output = msg[len(previous_response) :]
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant", content=incremental_output),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=stream_id, choices=[choice_data], model=model_name
|
||||
)
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
# TODO generate an openai-compatible streaming responses
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
previous_response = msg
|
||||
await asyncio.sleep(0.02)
|
||||
if incremental:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def stream_generator(chat, incremental: bool, model_name: str):
|
||||
"""Generate streaming responses
|
||||
|
||||
|
@@ -632,16 +632,36 @@ class BaseMetadata(BaseResource):
|
||||
runnable_parameters: Dict[str, Any] = {}
|
||||
if not self.parameters or not view_parameters:
|
||||
return runnable_parameters
|
||||
if len(self.parameters) != len(view_parameters):
|
||||
view_required_parameters = {
|
||||
parameter.name: parameter
|
||||
for parameter in view_parameters
|
||||
if not parameter.optional
|
||||
}
|
||||
current_required_parameters = {
|
||||
parameter.name: parameter
|
||||
for parameter in self.parameters
|
||||
if not parameter.optional
|
||||
}
|
||||
current_parameters = {
|
||||
parameter.name: parameter for parameter in self.parameters
|
||||
}
|
||||
if len(view_required_parameters) < len(current_required_parameters):
|
||||
# TODO, skip the optional parameters.
|
||||
raise FlowParameterMetadataException(
|
||||
f"Parameters count not match. Expected {len(self.parameters)}, "
|
||||
f"Parameters count not match(current key: {self.id}). "
|
||||
f"Expected {len(self.parameters)}, "
|
||||
f"but got {len(view_parameters)} from JSON metadata."
|
||||
f"Required parameters: {current_required_parameters.keys()}, "
|
||||
f"but got {view_required_parameters.keys()}."
|
||||
)
|
||||
for i, parameter in enumerate(self.parameters):
|
||||
view_param = view_parameters[i]
|
||||
for view_param in view_parameters:
|
||||
view_param_key = view_param.name
|
||||
if view_param_key not in current_parameters:
|
||||
raise FlowParameterMetadataException(
|
||||
f"Parameter {view_param_key} not found in the metadata."
|
||||
)
|
||||
runnable_parameters.update(
|
||||
parameter.to_runnable_parameter(
|
||||
current_parameters[view_param_key].to_runnable_parameter(
|
||||
view_param.get_typed_value(), resources, key_to_resource_instance
|
||||
)
|
||||
)
|
||||
|
@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
This class extends DAGNode by adding execution capabilities.
|
||||
"""
|
||||
|
||||
streaming_operator: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
|
@@ -10,6 +10,8 @@ from .base import BaseOperator
|
||||
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
|
||||
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
|
||||
|
||||
streaming_operator = True
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
call_data = curr_task_ctx.call_data
|
||||
@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
AsyncIterator[IN] to another AsyncIterator[OUT].
|
||||
"""
|
||||
|
||||
streaming_operator = True
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
|
||||
|
@@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the messages."""
|
||||
if "system_message" not in values:
|
||||
raise ValueError("No system message")
|
||||
values["system_message"] = "You are a helpful AI Assistant."
|
||||
if "human_message" not in values:
|
||||
raise ValueError("No human message")
|
||||
values["human_message"] = "{user_input}"
|
||||
if "message_placeholder" not in values:
|
||||
raise ValueError("No message placeholder")
|
||||
values["message_placeholder"] = "chat_history"
|
||||
system_message = values.pop("system_message")
|
||||
human_message = values.pop("human_message")
|
||||
message_placeholder = values.pop("message_placeholder")
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -14,6 +15,7 @@ from dbgpt.core.awel import (
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
|
||||
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
|
||||
from dbgpt.core.interface.llm import ModelOutput
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
@@ -276,12 +278,39 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""
|
||||
return self.dao.get_list_page(request, page, page_size)
|
||||
|
||||
async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
|
||||
async def chat_flow(
|
||||
self,
|
||||
flow_uid: str,
|
||||
request: CommonLLMHttpRequestBody,
|
||||
incremental: bool = False,
|
||||
):
|
||||
"""Chat with the AWEL flow.
|
||||
|
||||
Args:
|
||||
flow_uid (str): The flow uid
|
||||
request (CommonLLMHttpRequestBody): The request
|
||||
incremental (bool): Whether to return the result incrementally
|
||||
"""
|
||||
try:
|
||||
async for output in self._call_chat_flow(flow_uid, request, incremental):
|
||||
yield output
|
||||
except HTTPException as e:
|
||||
yield f"data:[SERVER_ERROR]{e.detail}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
|
||||
|
||||
async def _call_chat_flow(
|
||||
self,
|
||||
flow_uid: str,
|
||||
request: CommonLLMHttpRequestBody,
|
||||
incremental: bool = False,
|
||||
):
|
||||
"""Chat with the AWEL flow.
|
||||
|
||||
Args:
|
||||
flow_uid (str): The flow uid
|
||||
request (CommonLLMHttpRequestBody): The request
|
||||
incremental (bool): Whether to return the result incrementally
|
||||
"""
|
||||
flow = self.get({"uid": flow_uid})
|
||||
if not flow:
|
||||
@@ -291,18 +320,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
|
||||
)
|
||||
if flow.flow_category != FlowCategory.CHAT_FLOW:
|
||||
raise ValueError(f"Flow {flow_uid} is not a chat flow")
|
||||
dag = self.dag_manager.dag_map[dag_id]
|
||||
if (
|
||||
flow.flow_category != FlowCategory.CHAT_FLOW
|
||||
and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
|
||||
):
|
||||
raise ValueError(f"Flow {flow_uid} is not a chat flow")
|
||||
leaf_nodes = dag.leaf_nodes
|
||||
if len(leaf_nodes) != 1:
|
||||
raise ValueError("Chat Flow just support one leaf node in dag")
|
||||
end_node = cast(BaseOperator, leaf_nodes[0])
|
||||
if request.stream:
|
||||
async for output in await end_node.call_stream(request):
|
||||
yield output
|
||||
else:
|
||||
yield await end_node.call(request)
|
||||
async for output in _chat_with_dag_task(end_node, request, incremental):
|
||||
yield output
|
||||
|
||||
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
|
||||
"""Parse the flow category
|
||||
@@ -335,9 +364,104 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
output = leaf_node.metadata.outputs[0]
|
||||
try:
|
||||
real_class = _get_type_cls(output.type_cls)
|
||||
if common_http_trigger and (
|
||||
real_class == str or real_class == CommonLLMHttpResponseBody
|
||||
):
|
||||
if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
|
||||
return FlowCategory.CHAT_FLOW
|
||||
except Exception:
|
||||
return FlowCategory.COMMON
|
||||
|
||||
|
||||
def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
|
||||
try:
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
||||
except ImportError:
|
||||
OpenAIStreamingOutputOperator = None
|
||||
if is_class:
|
||||
return (
|
||||
obj == str
|
||||
or obj == CommonLLMHttpResponseBody
|
||||
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
|
||||
)
|
||||
else:
|
||||
chat_types = (str, CommonLLMHttpResponseBody)
|
||||
if OpenAIStreamingOutputOperator:
|
||||
chat_types += (OpenAIStreamingOutputOperator,)
|
||||
return isinstance(obj, chat_types)
|
||||
|
||||
|
||||
async def _chat_with_dag_task(
|
||||
task: BaseOperator,
|
||||
request: CommonLLMHttpRequestBody,
|
||||
incremental: bool = False,
|
||||
):
|
||||
"""Chat with the DAG task.
|
||||
|
||||
Args:
|
||||
task (BaseOperator): The task
|
||||
request (CommonLLMHttpRequestBody): The request
|
||||
"""
|
||||
if request.stream and task.streaming_operator:
|
||||
try:
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
||||
except ImportError:
|
||||
OpenAIStreamingOutputOperator = None
|
||||
if incremental:
|
||||
async for output in await task.call_stream(request):
|
||||
yield output
|
||||
else:
|
||||
if OpenAIStreamingOutputOperator and isinstance(
|
||||
task, OpenAIStreamingOutputOperator
|
||||
):
|
||||
from fastchat.protocol.openai_api_protocol import (
|
||||
ChatCompletionResponseStreamChoice,
|
||||
)
|
||||
|
||||
previous_text = ""
|
||||
async for output in await task.call_stream(request):
|
||||
if not isinstance(output, str):
|
||||
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
|
||||
return
|
||||
if output == "data: [DONE]\n\n":
|
||||
return
|
||||
json_data = "".join(output.split("data: ")[1:])
|
||||
dict_data = json.loads(json_data)
|
||||
if "choices" not in dict_data:
|
||||
error_msg = dict_data.get("text", "Unknown error")
|
||||
yield f"data:[SERVER_ERROR]{error_msg}\n\n"
|
||||
return
|
||||
choices = dict_data["choices"]
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
delta_data = ChatCompletionResponseStreamChoice(**choice)
|
||||
if delta_data.delta.content:
|
||||
previous_text += delta_data.delta.content
|
||||
if previous_text:
|
||||
full_text = previous_text.replace("\n", "\\n")
|
||||
yield f"data:{full_text}\n\n"
|
||||
else:
|
||||
async for output in await task.call_stream(request):
|
||||
if isinstance(output, str):
|
||||
if output.strip():
|
||||
yield output
|
||||
else:
|
||||
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
|
||||
return
|
||||
else:
|
||||
result = await task.call(request)
|
||||
if result is None:
|
||||
yield "data:[SERVER_ERROR]The result is None\n\n"
|
||||
elif isinstance(result, str):
|
||||
yield f"data:{result}\n\n"
|
||||
elif isinstance(result, ModelOutput):
|
||||
if result.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
||||
else:
|
||||
yield f"data:{result.text}\n\n"
|
||||
elif isinstance(result, CommonLLMHttpResponseBody):
|
||||
if result.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{result.text}\n\n"
|
||||
else:
|
||||
yield f"data:{result.text}\n\n"
|
||||
elif isinstance(result, dict):
|
||||
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
|
||||
|
@@ -140,7 +140,7 @@ def update_repo(repo: str):
|
||||
logger.info(f"Repo '{repo}' is not a git repository.")
|
||||
return
|
||||
logger.info(f"Updating repo '{repo}'...")
|
||||
subprocess.run(["git", "pull"], check=True)
|
||||
subprocess.run(["git", "pull"], check=False)
|
||||
|
||||
|
||||
def install(
|
||||
|
@@ -1,4 +1,4 @@
|
||||
# Upgrade To v0.5.0(Draft)
|
||||
# Upgrade To v0.5.0
|
||||
|
||||
## Overview
|
||||
|
||||
|
2
setup.py
2
setup.py
@@ -18,7 +18,7 @@ with open("README.md", mode="r", encoding="utf-8") as fh:
|
||||
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
|
||||
# If you modify the version, please modify the version in the following files:
|
||||
# dbgpt/_version.py
|
||||
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
|
||||
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0")
|
||||
|
||||
BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
|
||||
LLAMA_CPP_GPU_ACCELERATION = (
|
||||
|
Reference in New Issue
Block a user