feat(core): Support more chat flows (#1180)

Fangyin Cheng
2024-02-22 12:19:04 +08:00
committed by GitHub
parent 16fa68d4f2
commit ab5e1c7ea1
10 changed files with 175 additions and 55 deletions

View File

@@ -1 +1 @@
version = "0.4.7" version = "0.5.0"

View File

@@ -366,11 +366,7 @@ async def chat_completions(
             context=flow_ctx,
         )
         return StreamingResponse(
-            flow_stream_generator(
-                flow_service.chat_flow(dialogue.select_param, flow_req),
-                dialogue.incremental,
-                dialogue.model_name,
-            ),
+            flow_service.chat_flow(dialogue.select_param, flow_req),
             headers=headers,
             media_type="text/event-stream",
         )
@@ -426,32 +422,6 @@ async def no_stream_generator(chat):
yield f"data: {msg}\n\n" yield f"data: {msg}\n\n"
async def flow_stream_generator(func, incremental: bool, model_name: str):
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
previous_response = ""
async for chunk in func:
if chunk:
msg = chunk.replace("\ufffd", "")
if incremental:
incremental_output = msg[len(previous_response) :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=incremental_output),
)
chunk = ChatCompletionStreamResponse(
id=stream_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
else:
# TODO generate an openai-compatible streaming responses
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)
if incremental:
yield "data: [DONE]\n\n"
async def stream_generator(chat, incremental: bool, model_name: str): async def stream_generator(chat, incremental: bool, model_name: str):
"""Generate streaming responses """Generate streaming responses

View File

@@ -632,16 +632,36 @@ class BaseMetadata(BaseResource):
         runnable_parameters: Dict[str, Any] = {}
         if not self.parameters or not view_parameters:
             return runnable_parameters
-        if len(self.parameters) != len(view_parameters):
+        view_required_parameters = {
+            parameter.name: parameter
+            for parameter in view_parameters
+            if not parameter.optional
+        }
+        current_required_parameters = {
+            parameter.name: parameter
+            for parameter in self.parameters
+            if not parameter.optional
+        }
+        current_parameters = {
+            parameter.name: parameter for parameter in self.parameters
+        }
+        if len(view_required_parameters) < len(current_required_parameters):
             # TODO, skip the optional parameters.
             raise FlowParameterMetadataException(
-                f"Parameters count not match. Expected {len(self.parameters)}, "
+                f"Parameters count not match(current key: {self.id}). "
+                f"Expected {len(self.parameters)}, "
                 f"but got {len(view_parameters)} from JSON metadata."
+                f"Required parameters: {current_required_parameters.keys()}, "
+                f"but got {view_required_parameters.keys()}."
             )
-        for i, parameter in enumerate(self.parameters):
-            view_param = view_parameters[i]
+        for view_param in view_parameters:
+            view_param_key = view_param.name
+            if view_param_key not in current_parameters:
+                raise FlowParameterMetadataException(
+                    f"Parameter {view_param_key} not found in the metadata."
+                )
             runnable_parameters.update(
-                parameter.to_runnable_parameter(
+                current_parameters[view_param_key].to_runnable_parameter(
                     view_param.get_typed_value(), resources, key_to_resource_instance
                 )
             )
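The effect of this hunk is that view parameters are now matched by name rather than by position, and only required parameters are counted, so omitting an optional parameter no longer fails validation. A minimal sketch of that logic under those assumptions; Param and match_parameters are illustrative stand-ins, not the real classes:

from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Param:
    name: str
    optional: bool = False
    value: Any = None


def match_parameters(current: List[Param], view: List[Param]) -> Dict[str, Any]:
    # Compare counts of *required* entries only, then resolve each view
    # parameter by name instead of by list position.
    view_required = {p.name for p in view if not p.optional}
    current_required = {p.name for p in current if not p.optional}
    current_by_name = {p.name: p for p in current}
    if len(view_required) < len(current_required):
        raise ValueError(
            f"Required parameters {current_required} not satisfied, got {view_required}."
        )
    resolved: Dict[str, Any] = {}
    for view_param in view:
        if view_param.name not in current_by_name:
            raise ValueError(f"Parameter {view_param.name} not found in the metadata.")
        resolved[view_param.name] = view_param.value
    return resolved


# An omitted optional parameter ("temperature") no longer fails validation:
print(match_parameters(
    [Param("model"), Param("temperature", optional=True)],
    [Param("model", value="proxyllm")],
))  # -> {'model': 'proxyllm'}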

View File

@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     This class extends DAGNode by adding execution capabilities.
     """
 
+    streaming_operator: bool = False
+
     def __init__(
         self,
         task_id: Optional[str] = None,

View File

@@ -10,6 +10,8 @@ from .base import BaseOperator
 class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
     """An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
 
+    streaming_operator = True
+
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         call_data = curr_task_ctx.call_data
@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
     AsyncIterator[IN] to another AsyncIterator[OUT].
     """
 
+    streaming_operator = True
+
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
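With the class-level flag set on both stream operator bases, downstream code can decide between call and call_stream without inspecting operator types. A minimal sketch of that dispatch, mirroring how the new _chat_with_dag_task in the flow service uses the flag; run_task is illustrative:

from typing import Any, AsyncIterator


async def run_task(task: Any, request: Any) -> AsyncIterator[Any]:
    # Streaming operators yield an async iterator from call_stream();
    # non-streaming ones return a single value from call().
    if getattr(task, "streaming_operator", False):
        async for chunk in await task.call_stream(request):
            yield chunk
    else:
        yield await task.call(request)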

View File

@@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the messages."""
         if "system_message" not in values:
-            raise ValueError("No system message")
+            values["system_message"] = "You are a helpful AI Assistant."
         if "human_message" not in values:
-            raise ValueError("No human message")
+            values["human_message"] = "{user_input}"
         if "message_placeholder" not in values:
-            raise ValueError("No message placeholder")
+            values["message_placeholder"] = "chat_history"
         system_message = values.pop("system_message")
         human_message = values.pop("human_message")
         message_placeholder = values.pop("message_placeholder")
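A minimal sketch of the new fallback behavior (an illustrative helper, not the real class): missing keys are filled with the defaults above instead of raising ValueError, while explicit values still win.

from typing import Any, Dict

_DEFAULTS = {
    "system_message": "You are a helpful AI Assistant.",
    "human_message": "{user_input}",
    "message_placeholder": "chat_history",
}


def pre_fill(values: Dict[str, Any]) -> Dict[str, Any]:
    # Fill in any key the caller omitted; explicit values are kept.
    for key, default in _DEFAULTS.items():
        values.setdefault(key, default)
    return values


print(pre_fill({}))  # all three defaults applied
print(pre_fill({"system_message": "You are a SQL expert."}))  # only the rest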

View File

@@ -1,6 +1,7 @@
+import json
 import logging
 import traceback
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast
 
 from fastapi import HTTPException
@@ -14,6 +15,7 @@ from dbgpt.core.awel import (
 from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
+from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.serve.core import BaseService
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
@@ -276,12 +278,39 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
""" """
return self.dao.get_list_page(request, page, page_size) return self.dao.get_list_page(request, page, page_size)
async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody): async def chat_flow(
self,
flow_uid: str,
request: CommonLLMHttpRequestBody,
incremental: bool = False,
):
"""Chat with the AWEL flow. """Chat with the AWEL flow.
Args: Args:
flow_uid (str): The flow uid flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request request (CommonLLMHttpRequestBody): The request
incremental (bool): Whether to return the result incrementally
"""
try:
async for output in self._call_chat_flow(flow_uid, request, incremental):
yield output
except HTTPException as e:
yield f"data:[SERVER_ERROR]{e.detail}\n\n"
except Exception as e:
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
async def _call_chat_flow(
self,
flow_uid: str,
request: CommonLLMHttpRequestBody,
incremental: bool = False,
):
"""Chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
incremental (bool): Whether to return the result incrementally
""" """
flow = self.get({"uid": flow_uid}) flow = self.get({"uid": flow_uid})
if not flow: if not flow:
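The public chat_flow is now a thin wrapper that converts exceptions from _call_chat_flow into in-band SSE error frames, so the HTTP stream always terminates cleanly instead of breaking mid-response. The pattern in isolation, as a hedged sketch:

from typing import AsyncIterator


async def safe_stream(inner: AsyncIterator[str]) -> AsyncIterator[str]:
    # Any exception becomes a "data:[SERVER_ERROR]..." frame, so the
    # client sees a well-formed final frame rather than a dropped connection.
    try:
        async for frame in inner:
            yield frame
    except Exception as e:
        yield f"data:[SERVER_ERROR]{str(e)}\n\n"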
@@ -291,18 +320,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
             raise HTTPException(
                 status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
             )
-        if flow.flow_category != FlowCategory.CHAT_FLOW:
-            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         dag = self.dag_manager.dag_map[dag_id]
+        if (
+            flow.flow_category != FlowCategory.CHAT_FLOW
+            and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
+        ):
+            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
         end_node = cast(BaseOperator, leaf_nodes[0])
-        if request.stream:
-            async for output in await end_node.call_stream(request):
-                yield output
-        else:
-            yield await end_node.call(request)
+        async for output in _chat_with_dag_task(end_node, request, incremental):
+            yield output
 
     def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         """Parse the flow category
@@ -335,9 +364,104 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         output = leaf_node.metadata.outputs[0]
         try:
             real_class = _get_type_cls(output.type_cls)
-            if common_http_trigger and (
-                real_class == str or real_class == CommonLLMHttpResponseBody
-            ):
+            if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
                 return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
+
+
+def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
+    try:
+        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+    except ImportError:
+        OpenAIStreamingOutputOperator = None
+    if is_class:
+        return (
+            obj == str
+            or obj == CommonLLMHttpResponseBody
+            or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
+        )
+    else:
+        chat_types = (str, CommonLLMHttpResponseBody)
+        if OpenAIStreamingOutputOperator:
+            chat_types += (OpenAIStreamingOutputOperator,)
+        return isinstance(obj, chat_types)
+
+
+async def _chat_with_dag_task(
+    task: BaseOperator,
+    request: CommonLLMHttpRequestBody,
+    incremental: bool = False,
+):
+    """Chat with the DAG task.
+
+    Args:
+        task (BaseOperator): The task
+        request (CommonLLMHttpRequestBody): The request
+    """
+    if request.stream and task.streaming_operator:
+        try:
+            from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+        except ImportError:
+            OpenAIStreamingOutputOperator = None
+        if incremental:
+            async for output in await task.call_stream(request):
+                yield output
+        else:
+            if OpenAIStreamingOutputOperator and isinstance(
+                task, OpenAIStreamingOutputOperator
+            ):
+                from fastchat.protocol.openai_api_protocol import (
+                    ChatCompletionResponseStreamChoice,
+                )
+
+                previous_text = ""
+                async for output in await task.call_stream(request):
+                    if not isinstance(output, str):
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+                    if output == "data: [DONE]\n\n":
+                        return
+                    json_data = "".join(output.split("data: ")[1:])
+                    dict_data = json.loads(json_data)
+                    if "choices" not in dict_data:
+                        error_msg = dict_data.get("text", "Unknown error")
+                        yield f"data:[SERVER_ERROR]{error_msg}\n\n"
+                        return
+                    choices = dict_data["choices"]
+                    if choices:
+                        choice = choices[0]
+                        delta_data = ChatCompletionResponseStreamChoice(**choice)
+                        if delta_data.delta.content:
+                            previous_text += delta_data.delta.content
+                        if previous_text:
+                            full_text = previous_text.replace("\n", "\\n")
+                            yield f"data:{full_text}\n\n"
+            else:
+                async for output in await task.call_stream(request):
+                    if isinstance(output, str):
+                        if output.strip():
+                            yield output
+                    else:
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+    else:
+        result = await task.call(request)
+        if result is None:
+            yield "data:[SERVER_ERROR]The result is None\n\n"
+        elif isinstance(result, str):
+            yield f"data:{result}\n\n"
+        elif isinstance(result, ModelOutput):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, CommonLLMHttpResponseBody):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, dict):
+            yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
+        else:
+            yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"

View File

@@ -140,7 +140,7 @@ def update_repo(repo: str):
logger.info(f"Repo '{repo}' is not a git repository.") logger.info(f"Repo '{repo}' is not a git repository.")
return return
logger.info(f"Updating repo '{repo}'...") logger.info(f"Updating repo '{repo}'...")
subprocess.run(["git", "pull"], check=True) subprocess.run(["git", "pull"], check=False)
def install( def install(
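With check=False, a failing git pull no longer raises and aborts the surrounding command; the standard-library semantics in brief:

import subprocess

# check=False reports a non-zero exit code on the returned CompletedProcess
# instead of raising subprocess.CalledProcessError, so one failing repo
# no longer aborts the whole update.
result = subprocess.run(["git", "pull"], check=False)
if result.returncode != 0:
    print(f"'git pull' failed with exit code {result.returncode}")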

View File

@@ -1,4 +1,4 @@
-# Upgrade To v0.5.0(Draft)
+# Upgrade To v0.5.0
 
 ## Overview

View File

@@ -18,7 +18,7 @@ with open("README.md", mode="r", encoding="utf-8") as fh:
 IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
 
 # If you modify the version, please modify the version in the following files:
 # dbgpt/_version.py
-DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
+DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0")
 BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
 LLAMA_CPP_GPU_ACCELERATION = (