mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(web): copy awel flow (#1200)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
@@ -7,6 +8,12 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
try:
|
||||
from exceptiongroup import ExceptionGroup
|
||||
except ImportError:
|
||||
ExceptionGroup = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
@@ -71,8 +78,16 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
|
||||
async def common_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Common exception handler"""
|
||||
|
||||
if ExceptionGroup and isinstance(exc, ExceptionGroup):
|
||||
err_strs = []
|
||||
for e in exc.exceptions:
|
||||
err_strs.append(str(e))
|
||||
err_msg = ";".join(err_strs)
|
||||
else:
|
||||
err_msg = str(exc)
|
||||
res = Result.failed(
|
||||
msg=str(exc),
|
||||
msg=err_msg,
|
||||
err_code="E0003",
|
||||
)
|
||||
logger.error(f"common_exception_handler catch Exception: {res}")
|
||||
|
@@ -109,7 +109,7 @@ async def create(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.create(request))
|
||||
return Result.succ(service.create_and_save_dag(request))
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -129,7 +129,7 @@ async def update(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update(request))
|
||||
return Result.succ(service.update_flow(request))
|
||||
|
||||
|
||||
@router.delete("/flows/{uid}")
|
||||
|
@@ -28,6 +28,7 @@ class ServeEntity(Model):
|
||||
flow_data = Column(Text, nullable=True, comment="Flow data, JSON format")
|
||||
description = Column(String(512), nullable=True, comment="Flow description")
|
||||
state = Column(String(32), nullable=True, comment="Flow state")
|
||||
error_message = Column(String(512), nullable=True, comment="Error message")
|
||||
source = Column(String(64), nullable=True, comment="Flow source")
|
||||
source_url = Column(String(512), nullable=True, comment="Flow source url")
|
||||
version = Column(String(32), nullable=True, comment="Flow version")
|
||||
@@ -84,6 +85,9 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
request_dict = request.dict() if isinstance(request, ServeRequest) else request
|
||||
flow_data = json.dumps(request_dict.get("flow_data"), ensure_ascii=False)
|
||||
state = request_dict.get("state", State.INITIALIZING.value)
|
||||
error_message = request_dict.get("error_message")
|
||||
if error_message:
|
||||
error_message = error_message[:500]
|
||||
new_dict = {
|
||||
"uid": request_dict.get("uid"),
|
||||
"dag_id": request_dict.get("dag_id"),
|
||||
@@ -92,6 +96,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"flow_category": request_dict.get("flow_category"),
|
||||
"flow_data": flow_data,
|
||||
"state": state,
|
||||
"error_message": error_message,
|
||||
"source": request_dict.get("source"),
|
||||
"source_url": request_dict.get("source_url"),
|
||||
"version": request_dict.get("version"),
|
||||
@@ -121,6 +126,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flow_category=entity.flow_category,
|
||||
flow_data=flow_data,
|
||||
state=State.value_of(entity.state),
|
||||
error_message=entity.error_message,
|
||||
source=entity.source,
|
||||
source_url=entity.source_url,
|
||||
version=entity.version,
|
||||
@@ -151,6 +157,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flow_data=flow_data,
|
||||
description=entity.description,
|
||||
state=State.value_of(entity.state),
|
||||
error_message=entity.error_message,
|
||||
source=entity.source,
|
||||
source_url=entity.source_url,
|
||||
version=entity.version,
|
||||
@@ -183,14 +190,16 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
entry.description = update_request.description
|
||||
if update_request.state:
|
||||
entry.state = update_request.state.value
|
||||
if update_request.error_message is not None:
|
||||
# Keep first 500 characters
|
||||
entry.error_message = update_request.error_message[:500]
|
||||
if update_request.source:
|
||||
entry.source = update_request.source
|
||||
if update_request.source_url:
|
||||
entry.source_url = update_request.source_url
|
||||
if update_request.version:
|
||||
entry.version = update_request.version
|
||||
if update_request.editable:
|
||||
entry.editable = ServeEntity.parse_editable(update_request.editable)
|
||||
entry.editable = ServeEntity.parse_editable(update_request.editable)
|
||||
if update_request.user_name:
|
||||
entry.user_name = update_request.user_name
|
||||
if update_request.sys_code:
|
||||
|
@@ -13,7 +13,7 @@ from dbgpt.core.awel import (
|
||||
CommonLLMHttpResponseBody,
|
||||
)
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory, State
|
||||
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
|
||||
from dbgpt.core.interface.llm import ModelOutput
|
||||
from dbgpt.serve.core import BaseService
|
||||
@@ -103,14 +103,55 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# Build DAG from request
|
||||
dag = self._flow_factory.build(request)
|
||||
request.dag_id = dag.dag_id
|
||||
# Save DAG to storage
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
|
||||
def create_and_save_dag(
|
||||
self, request: ServeRequest, save_failed_flow: bool = False
|
||||
) -> ServerResponse:
|
||||
"""Create a new Flow entity and save the DAG
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
save_failed_flow (bool): Whether to save the failed flow
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
try:
|
||||
# Build DAG from request
|
||||
dag = self._flow_factory.build(request)
|
||||
request.dag_id = dag.dag_id
|
||||
# Save DAG to storage
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
except Exception as e:
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = str(e)
|
||||
return self.dao.create(request)
|
||||
else:
|
||||
raise e
|
||||
res = self.dao.create(request)
|
||||
# Register the DAG
|
||||
self.dag_manager.register_dag(dag)
|
||||
|
||||
state = request.state
|
||||
try:
|
||||
if state == State.DEPLOYED:
|
||||
# Register the DAG
|
||||
self.dag_manager.register_dag(dag)
|
||||
# Update state to RUNNING
|
||||
request.state = State.RUNNING
|
||||
request.error_message = ""
|
||||
self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
logger.info(f"Flow state is {state}, skip register DAG")
|
||||
except Exception as e:
|
||||
logger.warning(f"Register DAG({dag.dag_id}) error: {str(e)}")
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = f"Register DAG error: {str(e)}"
|
||||
self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
# Rollback
|
||||
self.delete(request.uid)
|
||||
raise e
|
||||
return res
|
||||
|
||||
def _pre_load_dag_from_db(self):
|
||||
@@ -131,7 +172,15 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
for entity in entities:
|
||||
try:
|
||||
dag = self._flow_factory.build(entity)
|
||||
self.dag_manager.register_dag(dag)
|
||||
if entity.state in [State.DEPLOYED, State.RUNNING] or (
|
||||
entity.version == "0.1.0" and entity.state == State.INITIALIZING
|
||||
):
|
||||
# Register the DAG
|
||||
self.dag_manager.register_dag(dag)
|
||||
# Update state to RUNNING
|
||||
entity.state = State.RUNNING
|
||||
entity.error_message = ""
|
||||
self.dao.update({"uid": entity.uid}, entity)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Load DAG({entity.name}, {entity.dag_id}) from db error: {str(e)}"
|
||||
@@ -154,36 +203,48 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flows = self.dbgpts_loader.get_flows()
|
||||
for flow in flows:
|
||||
try:
|
||||
# Try to build the dag from the request
|
||||
self._flow_factory.build(flow)
|
||||
# Set state to DEPLOYED
|
||||
flow.state = State.DEPLOYED
|
||||
exist_inst = self.get({"name": flow.name})
|
||||
if not exist_inst:
|
||||
self.create(flow)
|
||||
self.create_and_save_dag(flow, save_failed_flow=True)
|
||||
else:
|
||||
# TODO check version, must be greater than the exist one
|
||||
flow.uid = exist_inst.uid
|
||||
self.update(flow, check_editable=False)
|
||||
self.update_flow(flow, check_editable=False, save_failed_flow=True)
|
||||
except Exception as e:
|
||||
message = traceback.format_exc()
|
||||
logger.warning(
|
||||
f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}"
|
||||
)
|
||||
|
||||
def update(
|
||||
self, request: ServeRequest, check_editable: bool = True
|
||||
def update_flow(
|
||||
self,
|
||||
request: ServeRequest,
|
||||
check_editable: bool = True,
|
||||
save_failed_flow: bool = False,
|
||||
) -> ServerResponse:
|
||||
"""Update a Flow entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
check_editable (bool): Whether to check the editable
|
||||
|
||||
save_failed_flow (bool): Whether to save the failed flow
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# Try to build the dag from the request
|
||||
dag = self._flow_factory.build(request)
|
||||
|
||||
new_state = request.state
|
||||
try:
|
||||
# Try to build the dag from the request
|
||||
dag = self._flow_factory.build(request)
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
except Exception as e:
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = str(e)
|
||||
return self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
raise e
|
||||
# Build the query request from the request
|
||||
query_request = {"uid": request.uid}
|
||||
inst = self.get(query_request)
|
||||
@@ -193,19 +254,26 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"Flow {request.uid} is not editable"
|
||||
)
|
||||
old_state = inst.state
|
||||
if not State.can_change_state(old_state, new_state):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Flow {request.uid} state can't change from {old_state} to "
|
||||
f"{new_state}",
|
||||
)
|
||||
old_data: Optional[ServerResponse] = None
|
||||
try:
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
update_obj = self.dao.update(query_request, update_request=request)
|
||||
old_data = self.delete(request.uid)
|
||||
old_data.state = old_state
|
||||
if not old_data:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Flow detail {request.uid} not found"
|
||||
)
|
||||
return self.create(update_obj)
|
||||
return self.create_and_save_dag(update_obj)
|
||||
except Exception as e:
|
||||
if old_data:
|
||||
self.create(old_data)
|
||||
self.create_and_save_dag(old_data)
|
||||
raise e
|
||||
|
||||
def get(self, request: QUERY_SPEC) -> Optional[ServerResponse]:
|
||||
|
Reference in New Issue
Block a user