Mirror of https://github.com/csunny/DB-GPT.git
feat(core): Support more chat flows (#1180)
@@ -1 +1 @@
-version = "0.4.7"
+version = "0.5.0"

@@ -366,11 +366,7 @@ async def chat_completions(
             context=flow_ctx,
         )
         return StreamingResponse(
-            flow_stream_generator(
-                flow_service.chat_flow(dialogue.select_param, flow_req),
-                dialogue.incremental,
-                dialogue.model_name,
-            ),
+            flow_service.chat_flow(dialogue.select_param, flow_req),
            headers=headers,
            media_type="text/event-stream",
        )

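Note: with this change the endpoint streams the `chat_flow` async generator directly, so the `flow_stream_generator` wrapper (deleted below) is no longer needed. A minimal client-side sketch of consuming such an SSE response; the URL, port, and payload fields here are illustrative assumptions, not taken from this commit:

```python
import asyncio

import httpx  # assumed available; any SSE-capable HTTP client works


async def read_chat_flow_sse() -> None:
    # Hypothetical endpoint and payload, for illustration only.
    payload = {"chat_mode": "chat_flow", "select_param": "<flow-uid>", "stream": True}
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST", "http://localhost:5000/api/v1/chat/completions", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data:"):
                    print(line[len("data:") :].strip())


asyncio.run(read_chat_flow_sse())
```
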
@@ -426,32 +422,6 @@ async def no_stream_generator(chat):
        yield f"data: {msg}\n\n"


-async def flow_stream_generator(func, incremental: bool, model_name: str):
-    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
-    previous_response = ""
-    async for chunk in func:
-        if chunk:
-            msg = chunk.replace("\ufffd", "")
-            if incremental:
-                incremental_output = msg[len(previous_response) :]
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=DeltaMessage(role="assistant", content=incremental_output),
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=stream_id, choices=[choice_data], model=model_name
-                )
-                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-            else:
-                # TODO generate an openai-compatible streaming responses
-                msg = msg.replace("\n", "\\n")
-                yield f"data:{msg}\n\n"
-            previous_response = msg
-            await asyncio.sleep(0.02)
-    if incremental:
-        yield "data: [DONE]\n\n"
-
-
 async def stream_generator(chat, incremental: bool, model_name: str):
     """Generate streaming responses

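The deleted `flow_stream_generator` produced OpenAI-style incremental deltas by slicing each cumulative message against the text already sent. That slicing technique, as a standalone sketch:

```python
def extract_delta(full_text: str, previous: str) -> str:
    """Return only the newly generated suffix of a cumulative stream message."""
    return full_text[len(previous) :]


previous = ""
for chunk in ["Hel", "Hello", "Hello, world"]:
    print(extract_delta(chunk, previous))  # "Hel", then "lo", then ", world"
    previous = chunk
```
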
@@ -632,16 +632,36 @@ class BaseMetadata(BaseResource):
         runnable_parameters: Dict[str, Any] = {}
         if not self.parameters or not view_parameters:
             return runnable_parameters
-        if len(self.parameters) != len(view_parameters):
+        view_required_parameters = {
+            parameter.name: parameter
+            for parameter in view_parameters
+            if not parameter.optional
+        }
+        current_required_parameters = {
+            parameter.name: parameter
+            for parameter in self.parameters
+            if not parameter.optional
+        }
+        current_parameters = {
+            parameter.name: parameter for parameter in self.parameters
+        }
+        if len(view_required_parameters) < len(current_required_parameters):
             # TODO, skip the optional parameters.
             raise FlowParameterMetadataException(
-                f"Parameters count not match. Expected {len(self.parameters)}, "
+                f"Parameters count not match(current key: {self.id}). "
+                f"Expected {len(self.parameters)}, "
                 f"but got {len(view_parameters)} from JSON metadata."
+                f"Required parameters: {current_required_parameters.keys()}, "
+                f"but got {view_required_parameters.keys()}."
             )
-        for i, parameter in enumerate(self.parameters):
-            view_param = view_parameters[i]
+        for view_param in view_parameters:
+            view_param_key = view_param.name
+            if view_param_key not in current_parameters:
+                raise FlowParameterMetadataException(
+                    f"Parameter {view_param_key} not found in the metadata."
+                )
             runnable_parameters.update(
-                parameter.to_runnable_parameter(
+                current_parameters[view_param_key].to_runnable_parameter(
                     view_param.get_typed_value(), resources, key_to_resource_instance
                 )
             )

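The check above now matches view parameters to the stored metadata by name instead of by position, and only the required (non-optional) parameters must all be present. A self-contained sketch of that matching rule, with hypothetical names:

```python
from dataclasses import dataclass


@dataclass
class Param:
    name: str
    optional: bool = False


def match_by_name(current: list, view: list) -> dict:
    """Pair view parameters with stored metadata by name, checking required ones."""
    current_by_name = {p.name: p for p in current}
    required_current = [p for p in current if not p.optional]
    required_view = [p for p in view if not p.optional]
    if len(required_view) < len(required_current):
        raise ValueError("Parameters count not match")
    matched = {}
    for vp in view:
        if vp.name not in current_by_name:
            raise ValueError(f"Parameter {vp.name} not found in the metadata")
        matched[vp.name] = current_by_name[vp.name]
    return matched


# An optional parameter may now be omitted from the view without an error.
print(match_by_name([Param("a"), Param("b", optional=True)], [Param("a")]))
```
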
@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     This class extends DAGNode by adding execution capabilities.
     """

+    streaming_operator: bool = False
+
     def __init__(
         self,
         task_id: Optional[str] = None,

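The new `streaming_operator` class attribute lets callers decide, without executing a task, whether to consume it via `call_stream` or `call`. The pattern in isolation:

```python
class BaseOp:
    # Mirrors the new flag on BaseOperator: False unless a subclass opts in.
    streaming_operator: bool = False


class MyStreamOp(BaseOp):
    streaming_operator = True


def choose_call_style(op: BaseOp) -> str:
    return "call_stream" if op.streaming_operator else "call"


print(choose_call_style(MyStreamOp()))  # call_stream
print(choose_call_style(BaseOp()))  # call
```
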
@@ -10,6 +10,8 @@ from .base import BaseOperator
 class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
     """An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""

+    streaming_operator = True
+
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         call_data = curr_task_ctx.call_data

@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
     AsyncIterator[IN] to another AsyncIterator[OUT].
     """

+    streaming_operator = True
+
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[

@@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the messages."""
         if "system_message" not in values:
-            raise ValueError("No system message")
+            values["system_message"] = "You are a helpful AI Assistant."
         if "human_message" not in values:
-            raise ValueError("No human message")
+            values["human_message"] = "{user_input}"
         if "message_placeholder" not in values:
-            raise ValueError("No message placeholder")
+            values["message_placeholder"] = "chat_history"
         system_message = values.pop("system_message")
         human_message = values.pop("human_message")
         message_placeholder = values.pop("message_placeholder")

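With these defaults, `CommonChatPromptTemplate` can be constructed without explicitly supplying the system message, the human message, or the history placeholder. The default-filling step in isolation (an equivalent `setdefault` form):

```python
def pre_fill(values: dict) -> dict:
    """Fill missing prompt fields with defaults instead of raising ValueError."""
    values.setdefault("system_message", "You are a helpful AI Assistant.")
    values.setdefault("human_message", "{user_input}")
    values.setdefault("message_placeholder", "chat_history")
    return values


print(pre_fill({}))  # all three keys populated with the defaults
```
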
@@ -1,6 +1,7 @@
+import json
 import logging
 import traceback
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast

 from fastapi import HTTPException

@@ -14,6 +15,7 @@ from dbgpt.core.awel import (
 from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
+from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.serve.core import BaseService
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC

@@ -276,12 +278,39 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         return self.dao.get_list_page(request, page, page_size)

-    async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
+    async def chat_flow(
+        self,
+        flow_uid: str,
+        request: CommonLLMHttpRequestBody,
+        incremental: bool = False,
+    ):
         """Chat with the AWEL flow.

         Args:
             flow_uid (str): The flow uid
             request (CommonLLMHttpRequestBody): The request
+            incremental (bool): Whether to return the result incrementally
+        """
+        try:
+            async for output in self._call_chat_flow(flow_uid, request, incremental):
+                yield output
+        except HTTPException as e:
+            yield f"data:[SERVER_ERROR]{e.detail}\n\n"
+        except Exception as e:
+            yield f"data:[SERVER_ERROR]{str(e)}\n\n"
+
+    async def _call_chat_flow(
+        self,
+        flow_uid: str,
+        request: CommonLLMHttpRequestBody,
+        incremental: bool = False,
+    ):
+        """Chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+            incremental (bool): Whether to return the result incrementally
         """
         flow = self.get({"uid": flow_uid})
         if not flow:

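`chat_flow` is now an async generator that converts any exception from the underlying flow into a `data:[SERVER_ERROR]...` SSE frame rather than letting it propagate, so the HTTP stream always terminates cleanly. The wrapping pattern in isolation:

```python
import asyncio


async def failing_flow():
    yield "data:hello\n\n"
    raise RuntimeError("boom")


async def safe_stream(inner):
    try:
        async for output in inner:
            yield output
    except Exception as e:  # mirrors chat_flow's catch-all handler
        yield f"data:[SERVER_ERROR]{str(e)}\n\n"


async def main():
    async for frame in safe_stream(failing_flow()):
        print(repr(frame))  # 'data:hello\n\n', then 'data:[SERVER_ERROR]boom\n\n'


asyncio.run(main())
```
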
@@ -291,18 +320,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
             raise HTTPException(
                 status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
             )
-        if flow.flow_category != FlowCategory.CHAT_FLOW:
-            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         dag = self.dag_manager.dag_map[dag_id]
+        if (
+            flow.flow_category != FlowCategory.CHAT_FLOW
+            and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
+        ):
+            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
         end_node = cast(BaseOperator, leaf_nodes[0])
-        if request.stream:
-            async for output in await end_node.call_stream(request):
-                yield output
-        else:
-            yield await end_node.call(request)
+        async for output in _chat_with_dag_task(end_node, request, incremental):
+            yield output

     def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         """Parse the flow category

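Rather than trusting only the stored `flow_category`, the service now re-derives the category from the DAG before rejecting a flow. The guard in isolation, with illustrative names:

```python
def ensure_chat_flow(stored_category: str, derived_category: str) -> None:
    # Reject only when both the stored and the freshly derived category disagree.
    if stored_category != "chat_flow" and derived_category != "chat_flow":
        raise ValueError("not a chat flow")


ensure_chat_flow("common", "chat_flow")  # passes: the derived category wins
```
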
@@ -335,9 +364,104 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         output = leaf_node.metadata.outputs[0]
         try:
             real_class = _get_type_cls(output.type_cls)
-            if common_http_trigger and (
-                real_class == str or real_class == CommonLLMHttpResponseBody
-            ):
+            if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
                 return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
+
+
+def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
+    try:
+        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+    except ImportError:
+        OpenAIStreamingOutputOperator = None
+    if is_class:
+        return (
+            obj == str
+            or obj == CommonLLMHttpResponseBody
+            or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
+        )
+    else:
+        chat_types = (str, CommonLLMHttpResponseBody)
+        if OpenAIStreamingOutputOperator:
+            chat_types += (OpenAIStreamingOutputOperator,)
+        return isinstance(obj, chat_types)
+
+
+async def _chat_with_dag_task(
+    task: BaseOperator,
+    request: CommonLLMHttpRequestBody,
+    incremental: bool = False,
+):
+    """Chat with the DAG task.
+
+    Args:
+        task (BaseOperator): The task
+        request (CommonLLMHttpRequestBody): The request
+    """
+    if request.stream and task.streaming_operator:
+        try:
+            from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+        except ImportError:
+            OpenAIStreamingOutputOperator = None
+        if incremental:
+            async for output in await task.call_stream(request):
+                yield output
+        else:
+            if OpenAIStreamingOutputOperator and isinstance(
+                task, OpenAIStreamingOutputOperator
+            ):
+                from fastchat.protocol.openai_api_protocol import (
+                    ChatCompletionResponseStreamChoice,
+                )
+
+                previous_text = ""
+                async for output in await task.call_stream(request):
+                    if not isinstance(output, str):
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+                    if output == "data: [DONE]\n\n":
+                        return
+                    json_data = "".join(output.split("data: ")[1:])
+                    dict_data = json.loads(json_data)
+                    if "choices" not in dict_data:
+                        error_msg = dict_data.get("text", "Unknown error")
+                        yield f"data:[SERVER_ERROR]{error_msg}\n\n"
+                        return
+                    choices = dict_data["choices"]
+                    if choices:
+                        choice = choices[0]
+                        delta_data = ChatCompletionResponseStreamChoice(**choice)
+                        if delta_data.delta.content:
+                            previous_text += delta_data.delta.content
+                        if previous_text:
+                            full_text = previous_text.replace("\n", "\\n")
+                            yield f"data:{full_text}\n\n"
+            else:
+                async for output in await task.call_stream(request):
+                    if isinstance(output, str):
+                        if output.strip():
+                            yield output
+                    else:
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+    else:
+        result = await task.call(request)
+        if result is None:
+            yield "data:[SERVER_ERROR]The result is None\n\n"
+        elif isinstance(result, str):
+            yield f"data:{result}\n\n"
+        elif isinstance(result, ModelOutput):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, CommonLLMHttpResponseBody):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, dict):
+            yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
+        else:
+            yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"

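`_chat_with_dag_task` re-parses OpenAI-style SSE frames (`data: {...}`) into dicts and accumulates the delta contents into the full text. The frame-parsing step in isolation; the sample frame is fabricated for illustration:

```python
import json

frame = 'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}}]}\n\n'

if frame != "data: [DONE]\n\n":
    # Strip the "data: " prefix and decode the JSON payload, as the service does.
    json_data = "".join(frame.split("data: ")[1:])
    dict_data = json.loads(json_data)
    choices = dict_data.get("choices", [])
    if choices:
        print(choices[0]["delta"].get("content"))  # Hi
```
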
@@ -140,7 +140,7 @@ def update_repo(repo: str):
         logger.info(f"Repo '{repo}' is not a git repository.")
         return
     logger.info(f"Updating repo '{repo}'...")
-    subprocess.run(["git", "pull"], check=True)
+    subprocess.run(["git", "pull"], check=False)


 def install(

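With `check=False`, a failing `git pull` no longer raises `subprocess.CalledProcessError`, so one broken repo cannot abort the whole update loop; callers that care can still inspect the exit code:

```python
import subprocess

result = subprocess.run(["git", "pull"], check=False)
if result.returncode != 0:
    print(f"git pull failed with exit code {result.returncode}")
```
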
@@ -1,4 +1,4 @@
-# Upgrade To v0.5.0(Draft)
+# Upgrade To v0.5.0

 ## Overview

setup.py
@@ -18,7 +18,7 @@ with open("README.md", mode="r", encoding="utf-8") as fh:
 IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
 # If you modify the version, please modify the version in the following files:
 # dbgpt/_version.py
-DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
+DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0")

 BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
 LLAMA_CPP_GPU_ACCELERATION = (