mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-04 01:50:08 +00:00
700 lines
25 KiB
Python
700 lines
25 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from typing import AsyncIterator, List, Optional, cast
|
|
|
|
import schedule
|
|
from fastapi import HTTPException
|
|
|
|
from dbgpt._private.pydantic import model_to_json
|
|
from dbgpt.agent import AgentDummyTrigger
|
|
from dbgpt.component import SystemApp
|
|
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
|
from dbgpt.core.awel.flow.flow_factory import (
|
|
FlowCategory,
|
|
FlowFactory,
|
|
State,
|
|
fill_flow_panel,
|
|
)
|
|
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
|
|
from dbgpt.core.awel.util.chat_util import (
|
|
is_chat_flow_type,
|
|
safe_chat_stream_with_dag_task,
|
|
safe_chat_with_dag_task,
|
|
)
|
|
from dbgpt.core.interface.llm import ModelOutput
|
|
from dbgpt.core.schema.api import (
|
|
ChatCompletionResponseStreamChoice,
|
|
ChatCompletionStreamResponse,
|
|
DeltaMessage,
|
|
)
|
|
from dbgpt.serve.core import BaseService, blocking_func_to_async
|
|
from dbgpt.storage.metadata import BaseDao
|
|
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
|
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
|
from dbgpt.util.pagination_utils import PaginationResult
|
|
|
|
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
|
|
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
|
from ..models.models import ServeDao, ServeEntity
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|
"""The service class for Flow"""
|
|
|
|
name = SERVE_SERVICE_COMPONENT_NAME
|
|
|
|
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
|
|
self._system_app = None
|
|
self._serve_config: ServeConfig = None
|
|
self._dao: ServeDao = dao
|
|
self._flow_factory: FlowFactory = FlowFactory()
|
|
self._dbgpts_loader: Optional[DBGPTsLoader] = None
|
|
|
|
super().__init__(system_app)
|
|
|
|
def init_app(self, system_app: SystemApp) -> None:
|
|
"""Initialize the service
|
|
|
|
Args:
|
|
system_app (SystemApp): The system app
|
|
"""
|
|
super().init_app(system_app)
|
|
|
|
self._serve_config = ServeConfig.from_app_config(
|
|
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
|
)
|
|
self._dao = self._dao or ServeDao(self._serve_config)
|
|
self._system_app = system_app
|
|
self._dbgpts_loader = system_app.get_component(
|
|
DBGPTsLoader.name,
|
|
DBGPTsLoader,
|
|
or_register_component=DBGPTsLoader,
|
|
load_dbgpts_interval=self._serve_config.load_dbgpts_interval,
|
|
)
|
|
|
|
def before_start(self):
|
|
"""Execute before the application starts"""
|
|
super().before_start()
|
|
self._pre_load_dag_from_db()
|
|
self._pre_load_dag_from_dbgpts()
|
|
|
|
def after_start(self):
|
|
"""Execute after the application starts"""
|
|
self.load_dag_from_db()
|
|
self.load_dag_from_dbgpts(is_first_load=True)
|
|
schedule.every(self._serve_config.load_dbgpts_interval).seconds.do(
|
|
self.load_dag_from_dbgpts
|
|
)
|
|
|
|
@property
|
|
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
|
|
"""Returns the internal DAO."""
|
|
return self._dao
|
|
|
|
@property
|
|
def dbgpts_loader(self) -> DBGPTsLoader:
|
|
"""Returns the internal DBGPTsLoader."""
|
|
if self._dbgpts_loader is None:
|
|
raise ValueError("DBGPTsLoader is not initialized")
|
|
return self._dbgpts_loader
|
|
|
|
@property
|
|
def config(self) -> ServeConfig:
|
|
"""Returns the internal ServeConfig."""
|
|
return self._serve_config
|
|
|
|
def create(self, request: ServeRequest) -> ServerResponse:
|
|
"""Create a new Flow entity
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
|
|
def create_and_save_dag(
|
|
self, request: ServeRequest, save_failed_flow: bool = False
|
|
) -> ServerResponse:
|
|
"""Create a new Flow entity and save the DAG
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
save_failed_flow (bool): Whether to save the failed flow
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
try:
|
|
# Build DAG from request
|
|
if request.define_type == "json":
|
|
dag = self._flow_factory.build(request)
|
|
else:
|
|
dag = request.flow_dag
|
|
request.dag_id = dag.dag_id
|
|
# Save DAG to storage
|
|
request.flow_category = self._parse_flow_category(dag)
|
|
except Exception as e:
|
|
if save_failed_flow:
|
|
request.state = State.LOAD_FAILED
|
|
request.error_message = str(e)
|
|
request.dag_id = ""
|
|
return self.dao.create(request)
|
|
else:
|
|
raise ValueError(
|
|
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
|
|
) from e
|
|
self.dao.create(request)
|
|
# Query from database
|
|
res = self.get({"uid": request.uid})
|
|
|
|
state = request.state
|
|
try:
|
|
if state == State.DEPLOYED:
|
|
# Register the DAG
|
|
self.dag_manager.register_dag(dag, request.uid)
|
|
# Update state to RUNNING
|
|
request.state = State.RUNNING
|
|
request.error_message = ""
|
|
self.dao.update({"uid": request.uid}, request)
|
|
else:
|
|
logger.info(f"Flow state is {state}, skip register DAG")
|
|
except Exception as e:
|
|
logger.warning(f"Register DAG({dag.dag_id}) error: {str(e)}")
|
|
if save_failed_flow:
|
|
request.state = State.LOAD_FAILED
|
|
request.error_message = f"Register DAG error: {str(e)}"
|
|
request.dag_id = ""
|
|
self.dao.update({"uid": request.uid}, request)
|
|
else:
|
|
# Rollback
|
|
self.delete(request.uid)
|
|
raise e
|
|
return res
|
|
|
|
def _pre_load_dag_from_db(self):
|
|
"""Pre load DAG from db"""
|
|
entities = self.dao.get_list({})
|
|
for entity in entities:
|
|
try:
|
|
self._flow_factory.pre_load_requirements(entity)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Pre load requirements for DAG({entity.name}, {entity.dag_id}) "
|
|
f"from db error: {str(e)}"
|
|
)
|
|
|
|
def load_dag_from_db(self):
|
|
"""Load DAG from db"""
|
|
entities = self.dao.get_list({})
|
|
for entity in entities:
|
|
try:
|
|
if entity.define_type != "json":
|
|
continue
|
|
dag = self._flow_factory.build(entity)
|
|
if entity.state in [State.DEPLOYED, State.RUNNING] or (
|
|
entity.version == "0.1.0" and entity.state == State.INITIALIZING
|
|
):
|
|
# Register the DAG
|
|
self.dag_manager.register_dag(dag, entity.uid)
|
|
# Update state to RUNNING
|
|
entity.state = State.RUNNING
|
|
entity.error_message = ""
|
|
self.dao.update({"uid": entity.uid}, entity)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Load DAG({entity.name}, {entity.dag_id}) from db error: {str(e)}"
|
|
)
|
|
|
|
def _pre_load_dag_from_dbgpts(self):
|
|
"""Pre load DAG from dbgpts"""
|
|
flows = self.dbgpts_loader.get_flows()
|
|
for flow in flows:
|
|
try:
|
|
if flow.define_type == "json":
|
|
self._flow_factory.pre_load_requirements(flow)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Pre load requirements for DAG({flow.name}) from "
|
|
f"dbgpts error: {str(e)}"
|
|
)
|
|
|
|
def load_dag_from_dbgpts(self, is_first_load: bool = False):
|
|
"""Load DAG from dbgpts"""
|
|
flows = self.dbgpts_loader.get_flows()
|
|
for flow in flows:
|
|
try:
|
|
if flow.define_type == "python" and flow.flow_dag is None:
|
|
continue
|
|
# Set state to DEPLOYED
|
|
flow.state = State.DEPLOYED
|
|
exist_inst = self.dao.get_one({"name": flow.name})
|
|
if not exist_inst:
|
|
self.create_and_save_dag(flow, save_failed_flow=True)
|
|
elif is_first_load or exist_inst.state != State.RUNNING:
|
|
# TODO check version, must be greater than the exist one
|
|
flow.uid = exist_inst.uid
|
|
self.update_flow(flow, check_editable=False, save_failed_flow=True)
|
|
except Exception as e:
|
|
import traceback
|
|
|
|
message = traceback.format_exc()
|
|
logger.warning(
|
|
f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}"
|
|
)
|
|
|
|
def update_flow(
|
|
self,
|
|
request: ServeRequest,
|
|
check_editable: bool = True,
|
|
save_failed_flow: bool = False,
|
|
) -> ServerResponse:
|
|
"""Update a Flow entity
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
check_editable (bool): Whether to check the editable
|
|
save_failed_flow (bool): Whether to save the failed flow
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
new_state = State.DEPLOYED
|
|
try:
|
|
# Try to build the dag from the request
|
|
if request.define_type == "json":
|
|
dag = self._flow_factory.build(request)
|
|
else:
|
|
dag = request.flow_dag
|
|
request.flow_category = self._parse_flow_category(dag)
|
|
except Exception as e:
|
|
if save_failed_flow:
|
|
request.state = State.LOAD_FAILED
|
|
request.error_message = str(e)
|
|
request.dag_id = ""
|
|
return self.dao.update({"uid": request.uid}, request)
|
|
else:
|
|
raise e
|
|
# Build the query request from the request
|
|
query_request = {"uid": request.uid}
|
|
inst = self.get(query_request)
|
|
if not inst:
|
|
raise HTTPException(status_code=404, detail=f"Flow {request.uid} not found")
|
|
if check_editable and not inst.editable:
|
|
raise HTTPException(
|
|
status_code=403, detail=f"Flow {request.uid} is not editable"
|
|
)
|
|
old_state = inst.state
|
|
if not State.can_change_state(old_state, new_state):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Flow {request.uid} state can't change from {old_state} to "
|
|
f"{new_state}",
|
|
)
|
|
old_data: Optional[ServerResponse] = None
|
|
try:
|
|
update_obj = self.dao.update(query_request, update_request=request)
|
|
old_data = self.delete(request.uid)
|
|
old_data.state = old_state
|
|
if not old_data:
|
|
raise HTTPException(
|
|
status_code=404, detail=f"Flow detail {request.uid} not found"
|
|
)
|
|
update_obj.flow_dag = request.flow_dag
|
|
return self.create_and_save_dag(update_obj)
|
|
except Exception as e:
|
|
if old_data and old_data.state == State.RUNNING:
|
|
# Old flow is running, try to recover it
|
|
# first set the state to DEPLOYED
|
|
old_data.state = State.DEPLOYED
|
|
self.create_and_save_dag(old_data)
|
|
raise e
|
|
|
|
def get(self, request: QUERY_SPEC) -> Optional[ServerResponse]:
|
|
"""Get a Flow entity
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
# TODO: implement your own logic here
|
|
# Build the query request from the request
|
|
query_request = request
|
|
flow = self.dao.get_one(query_request)
|
|
if flow:
|
|
fill_flow_panel(flow)
|
|
metadata = self.dag_manager.get_dag_metadata(
|
|
flow.dag_id, alias_name=flow.uid
|
|
)
|
|
if metadata:
|
|
flow.metadata = metadata.to_dict()
|
|
return flow
|
|
|
|
def delete(self, uid: str) -> Optional[ServerResponse]:
|
|
"""Delete a Flow entity
|
|
|
|
Args:
|
|
uid (str): The uid
|
|
|
|
Returns:
|
|
ServerResponse: The data after deletion
|
|
"""
|
|
|
|
# TODO: implement your own logic here
|
|
# Build the query request from the request
|
|
query_request = {"uid": uid}
|
|
inst = self.get(query_request)
|
|
if inst is None:
|
|
raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
|
|
if inst.state == State.RUNNING and not inst.dag_id:
|
|
raise HTTPException(
|
|
status_code=404, detail=f"Running flow {uid}'s dag id not found"
|
|
)
|
|
try:
|
|
if inst.dag_id:
|
|
self.dag_manager.unregister_dag(inst.dag_id)
|
|
except Exception as e:
|
|
logger.warning(f"Unregister DAG({inst.dag_id}) error: {str(e)}")
|
|
self.dao.delete(query_request)
|
|
return inst
|
|
|
|
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
|
|
"""Get a list of Flow entities
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
List[ServerResponse]: The response
|
|
"""
|
|
# TODO: implement your own logic here
|
|
# Build the query request from the request
|
|
query_request = request
|
|
return self.dao.get_list(query_request)
|
|
|
|
def get_list_by_page(
|
|
self, request: QUERY_SPEC, page: int, page_size: int
|
|
) -> PaginationResult[ServerResponse]:
|
|
"""Get a list of Flow entities by page
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
page (int): The page number
|
|
page_size (int): The page size
|
|
|
|
Returns:
|
|
List[ServerResponse]: The response
|
|
"""
|
|
page_result = self.dao.get_list_page(
|
|
request, page, page_size, desc_order_column=ServeEntity.gmt_modified.name
|
|
)
|
|
for item in page_result.items:
|
|
metadata = self.dag_manager.get_dag_metadata(
|
|
item.dag_id, alias_name=item.uid
|
|
)
|
|
if metadata:
|
|
item.metadata = metadata.to_dict()
|
|
return page_result
|
|
|
|
def get_flow_templates(
|
|
self,
|
|
user_name: Optional[str] = None,
|
|
sys_code: Optional[str] = None,
|
|
page: int = 1,
|
|
page_size: int = 20,
|
|
) -> PaginationResult[ServerResponse]:
|
|
"""Get a list of Flow templates
|
|
|
|
Args:
|
|
user_name (Optional[str]): The user name
|
|
sys_code (Optional[str]): The system code
|
|
page (int): The page number
|
|
page_size (int): The page size
|
|
Returns:
|
|
List[ServerResponse]: The response
|
|
"""
|
|
local_file_templates = self._get_flow_templates_from_files()
|
|
return PaginationResult.build_from_all(local_file_templates, page, page_size)
|
|
|
|
def _get_flow_templates_from_files(self) -> List[ServerResponse]:
|
|
"""Get a list of Flow templates from files"""
|
|
user_lang = self._system_app.config.get_current_lang(default="en")
|
|
# List files in current directory
|
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
template_dir = os.path.join(parent_dir, "templates", user_lang)
|
|
default_template_dir = os.path.join(parent_dir, "templates", "en")
|
|
if not os.path.exists(template_dir):
|
|
template_dir = default_template_dir
|
|
templates = []
|
|
for root, _, files in os.walk(template_dir):
|
|
for file in files:
|
|
if file.endswith(".json"):
|
|
try:
|
|
with open(os.path.join(root, file), "r") as f:
|
|
data = json.load(f)
|
|
templates.append(_parse_flow_template_from_json(data))
|
|
except Exception as e:
|
|
logger.warning(f"Load template {file} error: {str(e)}")
|
|
return templates
|
|
|
|
async def chat_stream_flow_str(
|
|
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
|
) -> AsyncIterator[str]:
|
|
"""Stream chat with the AWEL flow.
|
|
|
|
Args:
|
|
flow_uid (str): The flow uid
|
|
request (CommonLLMHttpRequestBody): The request
|
|
"""
|
|
# Must be non-incremental
|
|
request.incremental = False
|
|
async for output in self.safe_chat_stream_flow(flow_uid, request):
|
|
text = output.text
|
|
if text:
|
|
text = text.replace("\n", "\\n")
|
|
if output.error_code != 0:
|
|
yield f"data:[SERVER_ERROR]{text}\n\n"
|
|
break
|
|
else:
|
|
yield f"data:{text}\n\n"
|
|
|
|
async def chat_stream_openai(
|
|
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
|
) -> AsyncIterator[str]:
|
|
conv_uid = request.conv_uid
|
|
choice_data = ChatCompletionResponseStreamChoice(
|
|
index=0,
|
|
delta=DeltaMessage(role="assistant"),
|
|
finish_reason=None,
|
|
)
|
|
chunk = ChatCompletionStreamResponse(
|
|
id=conv_uid, choices=[choice_data], model=request.model
|
|
)
|
|
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
|
|
|
yield f"data: {json_data}\n\n"
|
|
|
|
request.incremental = True
|
|
async for output in self.safe_chat_stream_flow(flow_uid, request):
|
|
if not output.success:
|
|
yield f"data: {json.dumps(output.to_dict(), ensure_ascii=False)}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
return
|
|
choice_data = ChatCompletionResponseStreamChoice(
|
|
index=0,
|
|
delta=DeltaMessage(role="assistant", content=output.text),
|
|
)
|
|
chunk = ChatCompletionStreamResponse(
|
|
id=conv_uid,
|
|
choices=[choice_data],
|
|
model=request.model,
|
|
)
|
|
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
|
yield f"data: {json_data}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
|
|
async def safe_chat_flow(
|
|
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
|
) -> ModelOutput:
|
|
"""Chat with the AWEL flow.
|
|
|
|
Args:
|
|
flow_uid (str): The flow uid
|
|
request (CommonLLMHttpRequestBody): The request
|
|
|
|
Returns:
|
|
ModelOutput: The output
|
|
"""
|
|
incremental = request.incremental
|
|
try:
|
|
task = await self._get_callable_task(flow_uid)
|
|
return await safe_chat_with_dag_task(task, request)
|
|
except HTTPException as e:
|
|
return ModelOutput(error_code=1, text=e.detail, incremental=incremental)
|
|
except Exception as e:
|
|
return ModelOutput(error_code=1, text=str(e), incremental=incremental)
|
|
|
|
async def safe_chat_stream_flow(
|
|
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
|
) -> AsyncIterator[ModelOutput]:
|
|
"""Stream chat with the AWEL flow.
|
|
|
|
Args:
|
|
flow_uid (str): The flow uid
|
|
request (CommonLLMHttpRequestBody): The request
|
|
|
|
Returns:
|
|
AsyncIterator[ModelOutput]: The output
|
|
"""
|
|
incremental = request.incremental
|
|
try:
|
|
task = await self._get_callable_task(flow_uid)
|
|
async for output in safe_chat_stream_with_dag_task(
|
|
task, request, incremental
|
|
):
|
|
yield output
|
|
except HTTPException as e:
|
|
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
|
|
except Exception as e:
|
|
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
|
|
|
|
async def _get_callable_task(
|
|
self,
|
|
flow_uid: str,
|
|
) -> BaseOperator:
|
|
"""Return the callable task.
|
|
|
|
Returns:
|
|
BaseOperator: The callable task
|
|
|
|
Raises:
|
|
HTTPException: If the flow is not found
|
|
ValueError: If the flow is not a chat flow or the leaf node is not found.
|
|
"""
|
|
flow = self.get({"uid": flow_uid})
|
|
if not flow:
|
|
raise HTTPException(status_code=404, detail=f"Flow {flow_uid} not found")
|
|
dag_id = flow.dag_id
|
|
if not dag_id or dag_id not in self.dag_manager.dag_map:
|
|
raise HTTPException(
|
|
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
|
|
)
|
|
dag = self.dag_manager.dag_map[dag_id]
|
|
# if (
|
|
# flow.flow_category != FlowCategory.CHAT_FLOW
|
|
# and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
|
|
# ):
|
|
# raise ValueError(f"Flow {flow_uid} is not a chat flow")
|
|
leaf_nodes = dag.leaf_nodes
|
|
if len(leaf_nodes) != 1:
|
|
raise ValueError("Chat Flow just support one leaf node in dag")
|
|
return cast(BaseOperator, leaf_nodes[0])
|
|
|
|
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
|
|
"""Parse the flow category
|
|
|
|
Args:
|
|
flow_category (str): The flow category
|
|
|
|
Returns:
|
|
FlowCategory: The flow category
|
|
"""
|
|
from dbgpt.core.awel.flow.base import _get_type_cls
|
|
|
|
triggers = dag.trigger_nodes
|
|
leaf_nodes = dag.leaf_nodes
|
|
if (
|
|
not triggers
|
|
or not leaf_nodes
|
|
or len(leaf_nodes) > 1
|
|
or not isinstance(leaf_nodes[0], BaseOperator)
|
|
):
|
|
return FlowCategory.COMMON
|
|
|
|
leaf_node = cast(BaseOperator, leaf_nodes[0])
|
|
if not leaf_node.metadata or not leaf_node.metadata.outputs:
|
|
return FlowCategory.COMMON
|
|
|
|
common_http_trigger = False
|
|
agent_trigger = False
|
|
for trigger in triggers:
|
|
if isinstance(trigger, CommonLLMHttpTrigger):
|
|
common_http_trigger = True
|
|
break
|
|
|
|
if isinstance(trigger, AgentDummyTrigger):
|
|
agent_trigger = True
|
|
break
|
|
|
|
output = leaf_node.metadata.outputs[0]
|
|
try:
|
|
real_class = _get_type_cls(output.type_cls)
|
|
if agent_trigger:
|
|
return FlowCategory.CHAT_AGENT
|
|
elif common_http_trigger and is_chat_flow_type(real_class, is_class=True):
|
|
return FlowCategory.CHAT_FLOW
|
|
except Exception:
|
|
return FlowCategory.COMMON
|
|
|
|
async def debug_flow(
|
|
self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
|
|
) -> AsyncIterator[ModelOutput]:
|
|
"""Debug the flow.
|
|
|
|
Args:
|
|
request (FlowDebugRequest): The request
|
|
default_incremental (Optional[bool]): The default incremental configuration
|
|
|
|
Returns:
|
|
AsyncIterator[ModelOutput]: The output
|
|
"""
|
|
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
|
|
|
|
dag = await blocking_func_to_async(
|
|
self._system_app,
|
|
self._flow_factory.build,
|
|
request.flow,
|
|
)
|
|
leaf_nodes = dag.leaf_nodes
|
|
if len(leaf_nodes) != 1:
|
|
raise ValueError("Chat Flow just support one leaf node in dag")
|
|
task = cast(BaseOperator, leaf_nodes[0])
|
|
dag_metadata = _parse_metadata(dag)
|
|
# TODO: Run task with variables
|
|
variables = request.variables
|
|
dag_request = request.request
|
|
|
|
if isinstance(request.request, CommonLLMHttpRequestBody):
|
|
incremental = request.request.incremental
|
|
elif isinstance(request.request, dict):
|
|
incremental = request.request.get("incremental", False)
|
|
else:
|
|
raise ValueError("Invalid request type")
|
|
|
|
if default_incremental is not None:
|
|
incremental = default_incremental
|
|
|
|
try:
|
|
async for output in safe_chat_stream_with_dag_task(
|
|
task, dag_request, incremental
|
|
):
|
|
yield output
|
|
except HTTPException as e:
|
|
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
|
|
except Exception as e:
|
|
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
|
|
|
|
async def _wrapper_chat_stream_flow_str(
|
|
self, stream_iter: AsyncIterator[ModelOutput]
|
|
) -> AsyncIterator[str]:
|
|
async for output in stream_iter:
|
|
text = output.text
|
|
if text:
|
|
text = text.replace("\n", "\\n")
|
|
if output.error_code != 0:
|
|
yield f"data:[SERVER_ERROR]{text}\n\n"
|
|
break
|
|
else:
|
|
yield f"data:{text}\n\n"
|
|
|
|
|
|
def _parse_flow_template_from_json(json_dict: dict) -> ServerResponse:
|
|
"""Parse the flow from json
|
|
|
|
Args:
|
|
json_dict (dict): The json dict
|
|
|
|
Returns:
|
|
ServerResponse: The flow
|
|
"""
|
|
flow_json = json_dict["flow"]
|
|
flow_json["editable"] = False
|
|
del flow_json["uid"]
|
|
flow_json["state"] = State.INITIALIZING
|
|
flow_json["dag_id"] = None
|
|
return ServerResponse(**flow_json)
|