feat(core): AWEL flow 2.0 backend code (#1879)

Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
Fangyin Cheng
2024-08-23 14:57:54 +08:00
committed by GitHub
parent 3a32344380
commit 9502251c08
67 changed files with 8289 additions and 190 deletions

View File

@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
FlowFactory,
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
raise ValueError(
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
) from e
res = self.dao.create(request)
self.dao.create(request)
# Query from database
res = self.get({"uid": request.uid})
state = request.state
try:
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON
async def debug_flow(
self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
) -> AsyncIterator[ModelOutput]:
"""Debug the flow.
Args:
request (FlowDebugRequest): The request
default_incremental (Optional[bool]): The default incremental configuration
Returns:
AsyncIterator[ModelOutput]: The output
"""
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
dag = self._flow_factory.build(request.flow)
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("Chat Flow just support one leaf node in dag")
task = cast(BaseOperator, leaf_nodes[0])
dag_metadata = _parse_metadata(dag)
# TODO: Run task with variables
variables = request.variables
dag_request = request.request
if isinstance(request.request, CommonLLMHttpRequestBody):
incremental = request.request.incremental
elif isinstance(request.request, dict):
incremental = request.request.get("incremental", False)
else:
raise ValueError("Invalid request type")
if default_incremental is not None:
incremental = default_incremental
try:
async for output in safe_chat_stream_with_dag_task(
task, dag_request, incremental
):
yield output
except HTTPException as e:
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
except Exception as e:
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
async def _wrapper_chat_stream_flow_str(
self, stream_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
async for output in stream_iter:
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"