feat: Run AWEL flow in CLI (#1341)

Fangyin Cheng
2024-03-27 12:50:05 +08:00
committed by GitHub
parent 340a9fbc35
commit 3a7a2cbbb8
42 changed files with 1454 additions and 422 deletions


@@ -180,6 +180,8 @@ async def query_page(
     sys_code: Optional[str] = Query(default=None, description="system code"),
     page: int = Query(default=1, description="current page"),
     page_size: int = Query(default=20, description="page size"),
+    name: Optional[str] = Query(default=None, description="flow name"),
+    uid: Optional[str] = Query(default=None, description="flow uid"),
     service: Service = Depends(get_service),
 ) -> Result[PaginationResult[ServerResponse]]:
     """Query Flow entities
@@ -189,13 +191,17 @@ async def query_page(
         sys_code (Optional[str]): The system code
         page (int): The page number
         page_size (int): The page size
+        name (Optional[str]): The flow name
+        uid (Optional[str]): The flow uid
         service (Service): The service
     Returns:
         ServerResponse: The response
     """
     return Result.succ(
         service.get_list_by_page(
-            {"user_name": user_name, "sys_code": sys_code}, page, page_size
+            {"user_name": user_name, "sys_code": sys_code, "name": name, "uid": uid},
+            page,
+            page_size,
         )
     )
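For illustration only (not part of this diff), a client-side sketch of the new `name`/`uid` filters. The mount path and port below are deployment assumptions, not confirmed by this commit:

```python
# Sketch: querying flows by name via the paginated endpoint.
# Assumptions: the flow serve API is mounted at /api/v1/serve/awel/flows
# on the default port 5670, and no auth header is required.
import requests

resp = requests.get(
    "http://localhost:5670/api/v1/serve/awel/flows",
    params={"page": 1, "page_size": 20, "name": "my_flow"},  # filter by flow name
)
payload = resp.json()
# Result.succ(...) wraps the page; field names follow the Result /
# PaginationResult serialization.
for item in payload["data"]["items"]:
    print(item["uid"], item["name"])
```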


@@ -1,11 +1,13 @@
 import json
 import logging
 import time
 import traceback
-from typing import Any, List, Optional, cast
+from typing import Any, AsyncIterator, List, Optional, cast

 import schedule
 from fastapi import HTTPException

+from dbgpt._private.pydantic import model_to_json
 from dbgpt.component import SystemApp
 from dbgpt.core.awel import (
     DAG,
@@ -22,6 +24,13 @@ from dbgpt.core.awel.flow.flow_factory import (
 )
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
 from dbgpt.core.interface.llm import ModelOutput
+from dbgpt.core.schema.api import (
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    DeltaMessage,
+)
 from dbgpt.serve.core import BaseService
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
@@ -365,39 +374,117 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         return self.dao.get_list_page(request, page, page_size)

-    async def chat_flow(
-        self,
-        flow_uid: str,
-        request: CommonLLMHttpRequestBody,
-        incremental: bool = False,
-    ):
+    async def chat_stream_flow_str(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> AsyncIterator[str]:
         """Stream chat with the AWEL flow.

         Args:
             flow_uid (str): The flow uid
             request (CommonLLMHttpRequestBody): The request
         """
+        # Must be non-incremental
+        request.incremental = False
+        async for output in self.safe_chat_stream_flow(flow_uid, request):
+            text = output.text
+            # if text:
+            #     text = text.replace("\n", "\\n")
+            if output.error_code != 0:
+                yield f"data:[SERVER_ERROR]{text}\n\n"
+                break
+            else:
+                yield f"data:{text}\n\n"
+    async def chat_stream_openai(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> AsyncIterator[str]:
+        conv_uid = request.conv_uid
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(role="assistant"),
+            finish_reason=None,
+        )
+        chunk = ChatCompletionStreamResponse(
+            id=conv_uid, choices=[choice_data], model=request.model
+        )
+        yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        request.incremental = True
+        async for output in self.safe_chat_stream_flow(flow_uid, request):
+            if not output.success:
+                yield f"data: {json.dumps(output.to_dict(), ensure_ascii=False)}\n\n"
+                yield "data: [DONE]\n\n"
+                return
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=DeltaMessage(role="assistant", content=output.text),
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=conv_uid,
+                choices=[choice_data],
+                model=request.model,
+            )
+            json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
+            yield f"data: {json_data}\n\n"
+        yield "data: [DONE]\n\n"
+    async def safe_chat_flow(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> ModelOutput:
+        """Chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
-            incremental (bool): Whether to return the result incrementally
+
+        Returns:
+            ModelOutput: The output
+        """
+        incremental = request.incremental
         try:
-            async for output in self._call_chat_flow(flow_uid, request, incremental):
+            task = await self._get_callable_task(flow_uid)
+            return await _safe_chat_with_dag_task(task, request)
+        except HTTPException as e:
+            return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
+        except Exception as e:
+            return ModelOutput(error_code=1, text=str(e), incremental=incremental)
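A usage sketch for the non-streaming path (`service` and `body` again assumed constructed elsewhere); note that errors come back as `ModelOutput.error_code != 0` rather than as raised exceptions:

```python
# Sketch: one-shot (non-streaming) flow chat from a script.
async def ask_flow(service, flow_uid: str, body) -> str:
    output = await service.safe_chat_flow(flow_uid, body)
    if output.error_code != 0:
        raise RuntimeError(f"flow failed: {output.text}")
    return output.text
```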
+    async def safe_chat_stream_flow(
+        self, flow_uid: str, request: CommonLLMHttpRequestBody
+    ) -> AsyncIterator[ModelOutput]:
+        """Stream chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+
+        Returns:
+            AsyncIterator[ModelOutput]: The output
+        """
+        incremental = request.incremental
+        try:
+            task = await self._get_callable_task(flow_uid)
+            async for output in _safe_chat_stream_with_dag_task(
+                task, request, incremental
+            ):
+                yield output
         except HTTPException as e:
-            yield f"data:[SERVER_ERROR]{e.detail}\n\n"
+            yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
         except Exception as e:
-            yield f"data:[SERVER_ERROR]{str(e)}\n\n"
+            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
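And the streaming counterpart, which now yields structured `ModelOutput` chunks instead of pre-rendered SSE strings; with `incremental=True` each chunk carries only the delta. A consumer sketch under the same assumptions:

```python
# Sketch: consuming the ModelOutput stream with incremental deltas.
async def stream_flow(service, flow_uid: str, body) -> None:
    body.incremental = True  # ask for deltas rather than full-text snapshots
    async for output in service.safe_chat_stream_flow(flow_uid, body):
        if output.error_code != 0:
            print(f"[error] {output.text}")
            break
        print(output.text, end="", flush=True)
```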
-    async def _call_chat_flow(
+    async def _get_callable_task(
         self,
         flow_uid: str,
-        request: CommonLLMHttpRequestBody,
-        incremental: bool = False,
-    ):
-        """Chat with the AWEL flow.
+    ) -> BaseOperator:
+        """Return the callable task.

         Args:
             flow_uid (str): The flow uid
-            request (CommonLLMHttpRequestBody): The request
-            incremental (bool): Whether to return the result incrementally
+
+        Returns:
+            BaseOperator: The callable task
+
+        Raises:
+            HTTPException: If the flow is not found
+            ValueError: If the flow is not a chat flow or the leaf node is not found.
         """
         flow = self.get({"uid": flow_uid})
         if not flow:
@@ -416,10 +503,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
-        end_node = cast(BaseOperator, leaf_nodes[0])
-        async for output in _chat_with_dag_task(end_node, request, incremental):
-            yield output
-        await dag._after_dag_end(end_node.current_event_loop_task_id)
+        return cast(BaseOperator, leaf_nodes[0])

     def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         """Parse the flow category
@@ -470,99 +554,202 @@ def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
     return isinstance(output_obj, chat_types)

-async def _chat_with_dag_task(
+async def _safe_chat_with_dag_task(task: BaseOperator, request: Any) -> ModelOutput:
+    """Chat with the DAG task."""
+    try:
+        finish_reason = None
+        usage = None
+        metrics = None
+        error_code = 0
+        text = ""
+        async for output in _safe_chat_stream_with_dag_task(task, request, False):
+            finish_reason = output.finish_reason
+            usage = output.usage
+            metrics = output.metrics
+            error_code = output.error_code
+            text = output.text
+        return ModelOutput(
+            error_code=error_code,
+            text=text,
+            metrics=metrics,
+            usage=usage,
+            finish_reason=finish_reason,
+        )
+    except Exception as e:
+        return ModelOutput(error_code=1, text=str(e), incremental=False)
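Note that `_safe_chat_with_dag_task` drains the stream with `incremental=False`, so every chunk carries the full text so far and keeping only the last chunk yields the complete answer. A toy illustration with a stand-in class (not the real `ModelOutput`) and hypothetical values:

```python
# Toy stand-in showing why keeping only the last chunk works when
# incremental=False: every chunk is a full snapshot, so the final one
# is the complete answer.
import asyncio
from dataclasses import dataclass


@dataclass
class FakeOutput:
    text: str


async def fake_stream():
    for text in ("Hel", "Hello", "Hello, world"):  # full snapshots, not deltas
        yield FakeOutput(text)


async def main() -> None:
    last = None
    async for out in fake_stream():
        last = out  # mirror the loop in _safe_chat_with_dag_task
    assert last is not None and last.text == "Hello, world"


asyncio.run(main())
```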
+async def _safe_chat_stream_with_dag_task(
     task: BaseOperator,
-    request: CommonLLMHttpRequestBody,
-    incremental: bool = False,
-):
-    """Chat with the DAG task.
+    request: Any,
+    incremental: bool,
+) -> AsyncIterator[ModelOutput]:
+    """Chat with the DAG task."""
+    try:
+        async for output in _chat_stream_with_dag_task(task, request, incremental):
+            yield output
+    except Exception as e:
+        yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
+    finally:
+        if task.streaming_operator:
+            if task.dag:
+                await task.dag._after_dag_end(task.current_event_loop_task_id)

-    Args:
-        task (BaseOperator): The task
-        request (CommonLLMHttpRequestBody): The request
-    """
-    if request.stream and task.streaming_operator:
+async def _chat_stream_with_dag_task(
+    task: BaseOperator,
+    request: Any,
+    incremental: bool,
+) -> AsyncIterator[ModelOutput]:
+    """Chat with the DAG task."""
+    is_sse = task.output_format and task.output_format.upper() == "SSE"
+    if not task.streaming_operator:
         try:
-            from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
-        except ImportError:
-            OpenAIStreamingOutputOperator = None
-        if incremental:
-            async for output in await task.call_stream(request):
-                yield output
-        else:
-            if OpenAIStreamingOutputOperator and isinstance(
-                task, OpenAIStreamingOutputOperator
-            ):
-                from dbgpt.core.schema.api import ChatCompletionResponseStreamChoice
-
-                previous_text = ""
-                async for output in await task.call_stream(request):
-                    if not isinstance(output, str):
-                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
-                        return
-                    if output == "data: [DONE]\n\n":
-                        return
-                    json_data = "".join(output.split("data: ")[1:])
-                    dict_data = json.loads(json_data)
-                    if "choices" not in dict_data:
-                        error_msg = dict_data.get("text", "Unknown error")
-                        yield f"data:[SERVER_ERROR]{error_msg}\n\n"
-                        return
-                    choices = dict_data["choices"]
-                    if choices:
-                        choice = choices[0]
-                        delta_data = ChatCompletionResponseStreamChoice(**choice)
-                        if delta_data.delta.content:
-                            previous_text += delta_data.delta.content
-                        if previous_text:
-                            full_text = previous_text.replace("\n", "\\n")
-                            yield f"data:{full_text}\n\n"
-            else:
-                async for output in await task.call_stream(request):
-                    str_msg = ""
-                    should_return = False
-                    if isinstance(output, str):
-                        if output.strip():
-                            str_msg = output
-                    elif isinstance(output, ModelOutput):
-                        if output.error_code != 0:
-                            str_msg = f"[SERVER_ERROR]{output.text}"
-                            should_return = True
-                        else:
-                            str_msg = output.text
-                    else:
-                        str_msg = (
-                            f"[SERVER_ERROR]The output is not a valid format"
-                            f"({type(output)})"
-                        )
-                        should_return = True
-                    if str_msg:
-                        str_msg = str_msg.replace("\n", "\\n")
-                        yield f"data:{str_msg}\n\n"
-                    if should_return:
-                        return
+            result = await task.call(request)
+            model_output = _parse_single_output(result, is_sse)
+            model_output.incremental = incremental
+            yield model_output
+        except Exception as e:
+            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
     else:
-        result = await task.call(request)
-        str_msg = ""
-        if result is None:
-            str_msg = "[SERVER_ERROR]The result is None!"
-        elif isinstance(result, str):
-            str_msg = result
-        elif isinstance(result, ModelOutput):
-            if result.error_code != 0:
-                str_msg = f"[SERVER_ERROR]{result.text}"
-            else:
-                str_msg = result.text
-        elif isinstance(result, CommonLLMHttpResponseBody):
-            if result.error_code != 0:
-                str_msg = f"[SERVER_ERROR]{result.text}"
-            else:
-                str_msg = result.text
-        elif isinstance(result, dict):
-            str_msg = json.dumps(result, ensure_ascii=False)
-        else:
-            str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"
+        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator

-        if str_msg:
-            str_msg = str_msg.replace("\n", "\\n")
-            yield f"data:{str_msg}\n\n"
+        if OpenAIStreamingOutputOperator and isinstance(
+            task, OpenAIStreamingOutputOperator
+        ):
+            full_text = ""
+            async for output in await task.call_stream(request):
+                model_output = _parse_openai_output(output)
+                # The output of the OpenAI streaming API is incremental
+                full_text += model_output.text
+                model_output.incremental = incremental
+                model_output.text = model_output.text if incremental else full_text
+                yield model_output
+                if not model_output.success:
+                    break
+        else:
+            full_text = ""
+            previous_text = ""
+            async for output in await task.call_stream(request):
+                model_output = _parse_single_output(output, is_sse)
+                model_output.incremental = incremental
+                if task.incremental_output:
+                    # Output is incremental, append the text
+                    full_text += model_output.text
+                else:
+                    # Output is not incremental, last output is the full text
+                    full_text = model_output.text
+                if not incremental:
+                    # Return the full text
+                    model_output.text = full_text
+                else:
+                    # Return the incremental text
+                    delta_text = full_text[len(previous_text) :]
+                    previous_text = (
+                        full_text
+                        if len(full_text) > len(previous_text)
+                        else previous_text
+                    )
+                    model_output.text = delta_text
+                yield model_output
+                if not model_output.success:
+                    break
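The final `else` branch converts between what the operator emits (deltas or full snapshots, per `task.incremental_output`) and what the caller asked for (`incremental`): deltas are recovered by slicing off the previously seen prefix. A toy illustration of that slicing with hypothetical snapshot values:

```python
# Toy illustration of the snapshot-to-delta conversion above.
snapshots = ["He", "Hello", "Hello, world"]  # hypothetical full-text outputs
previous_text = ""
for full_text in snapshots:
    delta_text = full_text[len(previous_text):]
    previous_text = full_text if len(full_text) > len(previous_text) else previous_text
    print(repr(delta_text))  # 'He', then 'llo', then ', world'
```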
+def _parse_single_output(output: Any, is_sse: bool) -> ModelOutput:
+    """Parse the single output."""
+    finish_reason = None
+    usage = None
+    metrics = None
+    if output is None:
+        error_code = 1
+        text = "The output is None!"
+    elif isinstance(output, str):
+        if is_sse:
+            sse_output = _parse_sse_data(output)
+            if sse_output is None:
+                error_code = 1
+                text = "The output is not a SSE format"
+            else:
+                error_code = 0
+                text = sse_output
+        else:
+            error_code = 0
+            text = output
+    elif isinstance(output, ModelOutput):
+        error_code = output.error_code
+        text = output.text
+        finish_reason = output.finish_reason
+        usage = output.usage
+        metrics = output.metrics
+    elif isinstance(output, CommonLLMHttpResponseBody):
+        error_code = output.error_code
+        text = output.text
+    elif isinstance(output, dict):
+        error_code = 0
+        text = json.dumps(output, ensure_ascii=False)
+    else:
+        error_code = 1
+        text = f"The output is not a valid format({type(output)})"
+    return ModelOutput(
+        error_code=error_code,
+        text=text,
+        finish_reason=finish_reason,
+        usage=usage,
+        metrics=metrics,
+    )
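`_parse_single_output` normalizes the heterogeneous task outputs into one `ModelOutput`. A few hypothetical inputs and their expected results:

```python
# Hypothetical inputs showing how _parse_single_output normalizes types.
out = _parse_single_output("data: hello\n\n", is_sse=True)
assert out.error_code == 0 and out.text == "hello\n\n"  # SSE payload unwrapped

out = _parse_single_output({"answer": 42}, is_sse=False)
assert out.text == '{"answer": 42}'  # dicts are serialized to JSON text

out = _parse_single_output(None, is_sse=False)
assert out.error_code == 1  # None is reported as an error
```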
+def _parse_openai_output(output: Any) -> ModelOutput:
+    """Parse the OpenAI output."""
+    text = ""
+    if not isinstance(output, str):
+        return ModelOutput(
+            error_code=1,
+            text="The output is not a stream format",
+        )
+    if output.strip() == "data: [DONE]" or output.strip() == "data:[DONE]":
+        return ModelOutput(error_code=0, text="")
+    if not output.startswith("data:"):
+        return ModelOutput(
+            error_code=1,
+            text="The output is not a stream format",
+        )
+
+    sse_output = _parse_sse_data(output)
+    if sse_output is None:
+        return ModelOutput(error_code=1, text="The output is not a SSE format")
+    json_data = sse_output.strip()
+    try:
+        dict_data = json.loads(json_data)
+    except Exception as e:
+        return ModelOutput(
+            error_code=1,
+            text=f"Invalid JSON data: {json_data}, {e}",
+        )
+    if "choices" not in dict_data:
+        return ModelOutput(
+            error_code=1,
+            text=dict_data.get("text", "Unknown error"),
+        )
+    choices = dict_data["choices"]
+    finish_reason: Optional[str] = None
+    if choices:
+        choice = choices[0]
+        delta_data = ChatCompletionResponseStreamChoice(**choice)
+        if delta_data.delta.content:
+            text = delta_data.delta.content
+        finish_reason = delta_data.finish_reason
+    return ModelOutput(error_code=0, text=text, finish_reason=finish_reason)
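And for the OpenAI wire format, a hypothetical frame and its parsed result:

```python
# A hypothetical OpenAI-style SSE frame and its parsed ModelOutput.
frame = 'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}}]}\n\n'
out = _parse_openai_output(frame)
assert out.error_code == 0 and out.text == "Hi"

assert _parse_openai_output("data: [DONE]").text == ""  # terminator frame
```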
+def _parse_sse_data(output: str) -> Optional[str]:
+    if output.startswith("data:"):
+        if output.startswith("data: "):
+            output = output[6:]
+        else:
+            output = output[5:]
+        return output
+    else:
+        return None
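`_parse_sse_data` strips only the leading `data:` marker (with or without a trailing space); any `\n\n` frame separator is left for the caller to handle:

```python
assert _parse_sse_data("data: hello\n\n") == "hello\n\n"
assert _parse_sse_data("data:hello") == "hello"
assert _parse_sse_data("event: ping") is None  # non-data frames are rejected
```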