Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-15 22:19:28 +00:00)
feat: Run AWEL flow in CLI (#1341)
@@ -180,6 +180,8 @@ async def query_page(
     sys_code: Optional[str] = Query(default=None, description="system code"),
     page: int = Query(default=1, description="current page"),
     page_size: int = Query(default=20, description="page size"),
+    name: Optional[str] = Query(default=None, description="flow name"),
+    uid: Optional[str] = Query(default=None, description="flow uid"),
     service: Service = Depends(get_service),
 ) -> Result[PaginationResult[ServerResponse]]:
     """Query Flow entities
@@ -189,13 +191,17 @@ async def query_page(
         sys_code (Optional[str]): The system code
         page (int): The page number
         page_size (int): The page size
+        name (Optional[str]): The flow name
+        uid (Optional[str]): The flow uid
         service (Service): The service

     Returns:
         ServerResponse: The response
     """
     return Result.succ(
         service.get_list_by_page(
-            {"user_name": user_name, "sys_code": sys_code}, page, page_size
+            {"user_name": user_name, "sys_code": sys_code, "name": name, "uid": uid},
+            page,
+            page_size,
         )
     )
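With the two new filters, a client can page through flows matching a name or look one up by uid. A minimal sketch, assuming the serve module is mounted under /api/v1/serve/awel (the route prefix is not shown in this diff) and an assumed PaginationResult payload shape:

    import requests

    resp = requests.get(
        "http://localhost:5000/api/v1/serve/awel/flows",  # assumed base URL/path
        params={"name": "my_flow", "page": 1, "page_size": 20},
    )
    body = resp.json()
    for item in body["data"]["items"]:  # assumed payload shape
        print(item["uid"], item["name"])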
@@ -1,11 +1,13 @@
 import json
 import logging
 import time
 import traceback
-from typing import Any, List, Optional, cast
+from typing import Any, AsyncIterator, List, Optional, cast

 import schedule
+from fastapi import HTTPException

+from dbgpt._private.pydantic import model_to_json
 from dbgpt.component import SystemApp
 from dbgpt.core.awel import (
     DAG,
@@ -22,6 +24,13 @@ from dbgpt.core.awel.flow.flow_factory import (
 )
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
 from dbgpt.core.interface.llm import ModelOutput
+from dbgpt.core.schema.api import (
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    DeltaMessage,
+)
 from dbgpt.serve.core import BaseService
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
@@ -365,39 +374,117 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         return self.dao.get_list_page(request, page, page_size)

-    async def chat_flow(
-        self,
-        flow_uid: str,
-        request: CommonLLMHttpRequestBody,
-        incremental: bool = False,
-    ):
-        """Chat with the AWEL flow.
-
-        Args:
-            flow_uid (str): The flow uid
-            request (CommonLLMHttpRequestBody): The request
-            incremental (bool): Whether to return the result incrementally
-        """
-        try:
-            async for output in self._call_chat_flow(flow_uid, request, incremental):
-                yield output
-        except HTTPException as e:
-            yield f"data:[SERVER_ERROR]{e.detail}\n\n"
-        except Exception as e:
-            yield f"data:[SERVER_ERROR]{str(e)}\n\n"
+    async def chat_stream_flow_str(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> AsyncIterator[str]:
+        """Stream chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+        """
+        # Must be non-incremental
+        request.incremental = False
+        async for output in self.safe_chat_stream_flow(flow_uid, request):
+            text = output.text
+            # if text:
+            #     text = text.replace("\n", "\\n")
+            if output.error_code != 0:
+                yield f"data:[SERVER_ERROR]{text}\n\n"
+                break
+            else:
+                yield f"data:{text}\n\n"
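Because chat_stream_flow_str pins request.incremental to False, each data: line carries the full text generated so far rather than a delta, and an error short-circuits the stream. For illustration (invented values), a successful run emits lines like:

    data:Hello
    data:Hello, wor
    data:Hello, world!

while a failure ends with a single data:[SERVER_ERROR]<message> line.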
+
+    async def chat_stream_openai(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> AsyncIterator[str]:
+        conv_uid = request.conv_uid
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(role="assistant"),
+            finish_reason=None,
+        )
+        chunk = ChatCompletionStreamResponse(
+            id=conv_uid, choices=[choice_data], model=request.model
+        )
+        yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        request.incremental = True
+        async for output in self.safe_chat_stream_flow(flow_uid, request):
+            if not output.success:
+                yield f"data: {json.dumps(output.to_dict(), ensure_ascii=False)}\n\n"
+                yield "data: [DONE]\n\n"
+                return
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=DeltaMessage(role="assistant", content=output.text),
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=conv_uid,
+                choices=[choice_data],
+                model=request.model,
+            )
+            json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
+            yield f"data: {json_data}\n\n"
+        yield "data: [DONE]\n\n"
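chat_stream_openai emits OpenAI-style chat.completion.chunk payloads and terminates with data: [DONE], so any SSE-capable client can consume it. A minimal sketch with httpx; the endpoint path and request-body fields here are assumptions for illustration, not part of this diff:

    import json

    import httpx

    async def consume_flow_stream(flow_uid: str, prompt: str) -> str:
        # Accumulate delta content from an assumed flow-chat endpoint.
        full_text = ""
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                # Hypothetical route; the real one is defined outside this diff.
                f"http://localhost:5000/api/v1/serve/awel/flows/{flow_uid}/chat",
                json={"messages": prompt, "stream": True},
            ) as resp:
                async for line in resp.aiter_lines():
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if payload == "[DONE]":
                        break
                    chunk = json.loads(payload)
                    for choice in chunk.get("choices", []):
                        full_text += (choice.get("delta") or {}).get("content") or ""
        return full_text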
+
+    async def safe_chat_flow(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> ModelOutput:
+        """Chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+
+        Returns:
+            ModelOutput: The output
+        """
+        incremental = request.incremental
+        try:
+            task = await self._get_callable_task(flow_uid)
+            return await _safe_chat_with_dag_task(task, request)
+        except HTTPException as e:
+            return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
+        except Exception as e:
+            return ModelOutput(error_code=1, text=str(e), incremental=incremental)
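safe_chat_flow is the non-streaming entry point, which a one-shot caller such as the new CLI path can await for a single ModelOutput; it never raises, folding failures into error_code=1. A minimal usage sketch, assuming you already hold the Service instance and using an invented model name:

    from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpRequestBody

    async def run_once(service, flow_uid: str) -> None:
        request = CommonLLMHttpRequestBody(
            model="chatgpt_proxyllm",  # assumed model name
            messages="What is AWEL?",
            stream=False,
        )
        output = await service.safe_chat_flow(flow_uid, request)
        if output.success:
            print(output.text)
        else:
            print(f"error {output.error_code}: {output.text}")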
+
+    async def safe_chat_stream_flow(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> AsyncIterator[ModelOutput]:
+        """Stream chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+
+        Returns:
+            AsyncIterator[ModelOutput]: The output
+        """
+        incremental = request.incremental
+        try:
+            task = await self._get_callable_task(flow_uid)
+            async for output in _safe_chat_stream_with_dag_task(
+                task, request, incremental
+            ):
+                yield output
+        except HTTPException as e:
+            yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
+        except Exception as e:
+            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)

-    async def _call_chat_flow(
+    async def _get_callable_task(
         self,
         flow_uid: str,
-        request: CommonLLMHttpRequestBody,
-        incremental: bool = False,
-    ):
-        """Chat with the AWEL flow.
+    ) -> BaseOperator:
+        """Return the callable task.

         Args:
             flow_uid (str): The flow uid
-            request (CommonLLMHttpRequestBody): The request
-            incremental (bool): Whether to return the result incrementally
+
+        Returns:
+            BaseOperator: The callable task
+
+        Raises:
+            HTTPException: If the flow is not found
+            ValueError: If the flow is not a chat flow or the leaf node is not found.
         """
         flow = self.get({"uid": flow_uid})
         if not flow:
@@ -416,10 +503,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
-        end_node = cast(BaseOperator, leaf_nodes[0])
-        async for output in _chat_with_dag_task(end_node, request, incremental):
-            yield output
-        await dag._after_dag_end(end_node.current_event_loop_task_id)
+        return cast(BaseOperator, leaf_nodes[0])

     def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         """Parse the flow category
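_get_callable_task now just resolves and returns the single leaf operator; running the DAG and the _after_dag_end cleanup moved into the task-level helpers below. A minimal sketch of a flow that satisfies the one-leaf-node rule (the operator wiring is illustrative, not from this commit):

    from dbgpt.core.awel import DAG, MapOperator
    from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger

    with DAG("example_chat_flow") as dag:
        trigger = CommonLLMHttpTrigger("/example/chat")
        # Exactly one downstream operator, so dag.leaf_nodes has one entry.
        end_task = MapOperator(lambda req: req.messages)
        trigger >> end_task

    assert len(dag.leaf_nodes) == 1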
@@ -470,99 +554,202 @@ def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
     return isinstance(output_obj, chat_types)


-async def _chat_with_dag_task(
-    task: BaseOperator,
-    request: CommonLLMHttpRequestBody,
-    incremental: bool = False,
-):
-    """Chat with the DAG task.
-
-    Args:
-        task (BaseOperator): The task
-        request (CommonLLMHttpRequestBody): The request
-    """
-    if request.stream and task.streaming_operator:
-        try:
-            from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
-        except ImportError:
-            OpenAIStreamingOutputOperator = None
-        if incremental:
-            async for output in await task.call_stream(request):
-                yield output
-        else:
-            if OpenAIStreamingOutputOperator and isinstance(
-                task, OpenAIStreamingOutputOperator
-            ):
-                from dbgpt.core.schema.api import ChatCompletionResponseStreamChoice
-
-                previous_text = ""
-                async for output in await task.call_stream(request):
-                    if not isinstance(output, str):
-                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
-                        return
-                    if output == "data: [DONE]\n\n":
-                        return
-                    json_data = "".join(output.split("data: ")[1:])
-                    dict_data = json.loads(json_data)
-                    if "choices" not in dict_data:
-                        error_msg = dict_data.get("text", "Unknown error")
-                        yield f"data:[SERVER_ERROR]{error_msg}\n\n"
-                        return
-                    choices = dict_data["choices"]
-                    if choices:
-                        choice = choices[0]
-                        delta_data = ChatCompletionResponseStreamChoice(**choice)
-                        if delta_data.delta.content:
-                            previous_text += delta_data.delta.content
-                        if previous_text:
-                            full_text = previous_text.replace("\n", "\\n")
-                            yield f"data:{full_text}\n\n"
-            else:
-                async for output in await task.call_stream(request):
-                    str_msg = ""
-                    should_return = False
-                    if isinstance(output, str):
-                        if output.strip():
-                            str_msg = output
-                    elif isinstance(output, ModelOutput):
-                        if output.error_code != 0:
-                            str_msg = f"[SERVER_ERROR]{output.text}"
-                            should_return = True
-                        else:
-                            str_msg = output.text
-                    else:
-                        str_msg = (
-                            f"[SERVER_ERROR]The output is not a valid format"
-                            f"({type(output)})"
-                        )
-                        should_return = True
-                    if str_msg:
-                        str_msg = str_msg.replace("\n", "\\n")
-                        yield f"data:{str_msg}\n\n"
-                    if should_return:
-                        return
-    else:
-        result = await task.call(request)
-        str_msg = ""
-        if result is None:
-            str_msg = "[SERVER_ERROR]The result is None!"
-        elif isinstance(result, str):
-            str_msg = result
-        elif isinstance(result, ModelOutput):
-            if result.error_code != 0:
-                str_msg = f"[SERVER_ERROR]{result.text}"
-            else:
-                str_msg = result.text
-        elif isinstance(result, CommonLLMHttpResponseBody):
-            if result.error_code != 0:
-                str_msg = f"[SERVER_ERROR]{result.text}"
-            else:
-                str_msg = result.text
-        elif isinstance(result, dict):
-            str_msg = json.dumps(result, ensure_ascii=False)
-        else:
-            str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"
-        if str_msg:
-            str_msg = str_msg.replace("\n", "\\n")
-            yield f"data:{str_msg}\n\n"
+async def _safe_chat_with_dag_task(task: BaseOperator, request: Any) -> ModelOutput:
+    """Chat with the DAG task."""
+    try:
+        finish_reason = None
+        usage = None
+        metrics = None
+        error_code = 0
+        text = ""
+        async for output in _safe_chat_stream_with_dag_task(task, request, False):
+            finish_reason = output.finish_reason
+            usage = output.usage
+            metrics = output.metrics
+            error_code = output.error_code
+            text = output.text
+        return ModelOutput(
+            error_code=error_code,
+            text=text,
+            metrics=metrics,
+            usage=usage,
+            finish_reason=finish_reason,
+        )
+    except Exception as e:
+        return ModelOutput(error_code=1, text=str(e), incremental=False)
+
+
+async def _safe_chat_stream_with_dag_task(
+    task: BaseOperator,
+    request: Any,
+    incremental: bool,
+) -> AsyncIterator[ModelOutput]:
+    """Chat with the DAG task."""
+    try:
+        async for output in _chat_stream_with_dag_task(task, request, incremental):
+            yield output
+    except Exception as e:
+        yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
+    finally:
+        if task.streaming_operator:
+            if task.dag:
+                await task.dag._after_dag_end(task.current_event_loop_task_id)
+
+
+async def _chat_stream_with_dag_task(
+    task: BaseOperator,
+    request: Any,
+    incremental: bool,
+) -> AsyncIterator[ModelOutput]:
+    """Chat with the DAG task."""
+    is_sse = task.output_format and task.output_format.upper() == "SSE"
+    if not task.streaming_operator:
+        try:
+            result = await task.call(request)
+            model_output = _parse_single_output(result, is_sse)
+            model_output.incremental = incremental
+            yield model_output
+        except Exception as e:
+            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
+    else:
+        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+
+        if OpenAIStreamingOutputOperator and isinstance(
+            task, OpenAIStreamingOutputOperator
+        ):
+            full_text = ""
+            async for output in await task.call_stream(request):
+                model_output = _parse_openai_output(output)
+                # The output of the OpenAI streaming API is incremental
+                full_text += model_output.text
+                model_output.incremental = incremental
+                model_output.text = model_output.text if incremental else full_text
+                yield model_output
+                if not model_output.success:
+                    break
+        else:
+            full_text = ""
+            previous_text = ""
+            async for output in await task.call_stream(request):
+                model_output = _parse_single_output(output, is_sse)
+                model_output.incremental = incremental
+                if task.incremental_output:
+                    # Output is incremental, append the text
+                    full_text += model_output.text
+                else:
+                    # Output is not incremental, last output is the full text
+                    full_text = model_output.text
+                if not incremental:
+                    # Return the full text
+                    model_output.text = full_text
+                else:
+                    # Return the incremental text
+                    delta_text = full_text[len(previous_text) :]
+                    previous_text = (
+                        full_text
+                        if len(full_text) > len(previous_text)
+                        else previous_text
+                    )
+                    model_output.text = delta_text
+                yield model_output
+                if not model_output.success:
+                    break
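The previous_text bookkeeping above turns a full-text stream into deltas when incremental output is requested, and keeps the longest text seen so far if an upstream chunk shrinks. A short walk-through with invented values:

    chunks = ["Hello", "Hello, wor", "Hello, world!"]  # full text so far
    previous_text = ""
    for full_text in chunks:
        delta_text = full_text[len(previous_text):]
        previous_text = (
            full_text if len(full_text) > len(previous_text) else previous_text
        )
        print(repr(delta_text))  # 'Hello', ', wor', 'ld!'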
+
+
+def _parse_single_output(output: Any, is_sse: bool) -> ModelOutput:
+    """Parse the single output."""
+    finish_reason = None
+    usage = None
+    metrics = None
+    if output is None:
+        error_code = 1
+        text = "The output is None!"
+    elif isinstance(output, str):
+        if is_sse:
+            sse_output = _parse_sse_data(output)
+            if sse_output is None:
+                error_code = 1
+                text = "The output is not a SSE format"
+            else:
+                error_code = 0
+                text = sse_output
+        else:
+            error_code = 0
+            text = output
+    elif isinstance(output, ModelOutput):
+        error_code = output.error_code
+        text = output.text
+        finish_reason = output.finish_reason
+        usage = output.usage
+        metrics = output.metrics
+    elif isinstance(output, CommonLLMHttpResponseBody):
+        error_code = output.error_code
+        text = output.text
+    elif isinstance(output, dict):
+        error_code = 0
+        text = json.dumps(output, ensure_ascii=False)
+    else:
+        error_code = 1
+        text = f"The output is not a valid format({type(output)})"
+    return ModelOutput(
+        error_code=error_code,
+        text=text,
+        finish_reason=finish_reason,
+        usage=usage,
+        metrics=metrics,
+    )
+
+
+def _parse_openai_output(output: Any) -> ModelOutput:
+    """Parse the OpenAI output."""
+    text = ""
+    if not isinstance(output, str):
+        return ModelOutput(
+            error_code=1,
+            text="The output is not a stream format",
+        )
+    if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
+        return ModelOutput(error_code=0, text="")
+    if not output.startswith("data:"):
+        return ModelOutput(
+            error_code=1,
+            text="The output is not a stream format",
+        )
+
+    sse_output = _parse_sse_data(output)
+    if sse_output is None:
+        return ModelOutput(error_code=1, text="The output is not a SSE format")
+    json_data = sse_output.strip()
+    try:
+        dict_data = json.loads(json_data)
+    except Exception as e:
+        return ModelOutput(
+            error_code=1,
+            text=f"Invalid JSON data: {json_data}, {e}",
+        )
+    if "choices" not in dict_data:
+        return ModelOutput(
+            error_code=1,
+            text=dict_data.get("text", "Unknown error"),
+        )
+    choices = dict_data["choices"]
+    finish_reason: Optional[str] = None
+    if choices:
+        choice = choices[0]
+        delta_data = ChatCompletionResponseStreamChoice(**choice)
+        if delta_data.delta.content:
+            text = delta_data.delta.content
+        finish_reason = delta_data.finish_reason
+    return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
+
+
+def _parse_sse_data(output: str) -> Optional[str]:
+    if output.startswith("data:"):
+        if output.startswith("data: "):
+            output = output[6:]
+        else:
+            output = output[5:]
+        return output
+    else:
+        return None
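_parse_sse_data strips either the "data: " or the bare "data:" prefix and returns None for anything else, for example:

    assert _parse_sse_data("data: hello\n\n") == "hello\n\n"
    assert _parse_sse_data("data:hello") == "hello"
    assert _parse_sse_data("event: done") is None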