feat(web): copy awel flow (#1200)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Hzh_97
2024-02-28 21:03:23 +08:00
committed by GitHub
parent 0837da48ba
commit 673ddaab5b
68 changed files with 898 additions and 190 deletions

View File

@@ -1,4 +1,5 @@
import logging
import sys
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from fastapi import HTTPException, Request
@@ -7,6 +8,12 @@ from fastapi.responses import JSONResponse
from dbgpt._private.pydantic import BaseModel, Field
if sys.version_info < (3, 11):
try:
from exceptiongroup import ExceptionGroup
except ImportError:
ExceptionGroup = None
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -71,8 +78,16 @@ async def http_exception_handler(request: Request, exc: HTTPException):
async def common_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Common exception handler"""
if ExceptionGroup and isinstance(exc, ExceptionGroup):
err_strs = []
for e in exc.exceptions:
err_strs.append(str(e))
err_msg = ";".join(err_strs)
else:
err_msg = str(exc)
res = Result.failed(
msg=str(exc),
msg=err_msg,
err_code="E0003",
)
logger.error(f"common_exception_handler catch Exception: {res}")

View File

@@ -109,7 +109,7 @@ async def create(
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
return Result.succ(service.create_and_save_dag(request))
@router.put(
@@ -129,7 +129,7 @@ async def update(
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
return Result.succ(service.update_flow(request))
@router.delete("/flows/{uid}")

View File

@@ -28,6 +28,7 @@ class ServeEntity(Model):
flow_data = Column(Text, nullable=True, comment="Flow data, JSON format")
description = Column(String(512), nullable=True, comment="Flow description")
state = Column(String(32), nullable=True, comment="Flow state")
error_message = Column(String(512), nullable=True, comment="Error message")
source = Column(String(64), nullable=True, comment="Flow source")
source_url = Column(String(512), nullable=True, comment="Flow source url")
version = Column(String(32), nullable=True, comment="Flow version")
@@ -84,6 +85,9 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
request_dict = request.dict() if isinstance(request, ServeRequest) else request
flow_data = json.dumps(request_dict.get("flow_data"), ensure_ascii=False)
state = request_dict.get("state", State.INITIALIZING.value)
error_message = request_dict.get("error_message")
if error_message:
error_message = error_message[:500]
new_dict = {
"uid": request_dict.get("uid"),
"dag_id": request_dict.get("dag_id"),
@@ -92,6 +96,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"flow_category": request_dict.get("flow_category"),
"flow_data": flow_data,
"state": state,
"error_message": error_message,
"source": request_dict.get("source"),
"source_url": request_dict.get("source_url"),
"version": request_dict.get("version"),
@@ -121,6 +126,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
flow_category=entity.flow_category,
flow_data=flow_data,
state=State.value_of(entity.state),
error_message=entity.error_message,
source=entity.source,
source_url=entity.source_url,
version=entity.version,
@@ -151,6 +157,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
flow_data=flow_data,
description=entity.description,
state=State.value_of(entity.state),
error_message=entity.error_message,
source=entity.source,
source_url=entity.source_url,
version=entity.version,
@@ -183,14 +190,16 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
entry.description = update_request.description
if update_request.state:
entry.state = update_request.state.value
if update_request.error_message is not None:
# Keep first 500 characters
entry.error_message = update_request.error_message[:500]
if update_request.source:
entry.source = update_request.source
if update_request.source_url:
entry.source_url = update_request.source_url
if update_request.version:
entry.version = update_request.version
if update_request.editable:
entry.editable = ServeEntity.parse_editable(update_request.editable)
entry.editable = ServeEntity.parse_editable(update_request.editable)
if update_request.user_name:
entry.user_name = update_request.user_name
if update_request.sys_code:

View File

@@ -13,7 +13,7 @@ from dbgpt.core.awel import (
CommonLLMHttpResponseBody,
)
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory, State
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.serve.core import BaseService
@@ -103,14 +103,55 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
Returns:
ServerResponse: The response
"""
# Build DAG from request
dag = self._flow_factory.build(request)
request.dag_id = dag.dag_id
# Save DAG to storage
request.flow_category = self._parse_flow_category(dag)
def create_and_save_dag(
self, request: ServeRequest, save_failed_flow: bool = False
) -> ServerResponse:
"""Create a new Flow entity and save the DAG
Args:
request (ServeRequest): The request
save_failed_flow (bool): Whether to save the failed flow
Returns:
ServerResponse: The response
"""
try:
# Build DAG from request
dag = self._flow_factory.build(request)
request.dag_id = dag.dag_id
# Save DAG to storage
request.flow_category = self._parse_flow_category(dag)
except Exception as e:
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = str(e)
return self.dao.create(request)
else:
raise e
res = self.dao.create(request)
# Register the DAG
self.dag_manager.register_dag(dag)
state = request.state
try:
if state == State.DEPLOYED:
# Register the DAG
self.dag_manager.register_dag(dag)
# Update state to RUNNING
request.state = State.RUNNING
request.error_message = ""
self.dao.update({"uid": request.uid}, request)
else:
logger.info(f"Flow state is {state}, skip register DAG")
except Exception as e:
logger.warning(f"Register DAG({dag.dag_id}) error: {str(e)}")
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = f"Register DAG error: {str(e)}"
self.dao.update({"uid": request.uid}, request)
else:
# Rollback
self.delete(request.uid)
raise e
return res
def _pre_load_dag_from_db(self):
@@ -131,7 +172,15 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
for entity in entities:
try:
dag = self._flow_factory.build(entity)
self.dag_manager.register_dag(dag)
if entity.state in [State.DEPLOYED, State.RUNNING] or (
entity.version == "0.1.0" and entity.state == State.INITIALIZING
):
# Register the DAG
self.dag_manager.register_dag(dag)
# Update state to RUNNING
entity.state = State.RUNNING
entity.error_message = ""
self.dao.update({"uid": entity.uid}, entity)
except Exception as e:
logger.warning(
f"Load DAG({entity.name}, {entity.dag_id}) from db error: {str(e)}"
@@ -154,36 +203,48 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
flows = self.dbgpts_loader.get_flows()
for flow in flows:
try:
# Try to build the dag from the request
self._flow_factory.build(flow)
# Set state to DEPLOYED
flow.state = State.DEPLOYED
exist_inst = self.get({"name": flow.name})
if not exist_inst:
self.create(flow)
self.create_and_save_dag(flow, save_failed_flow=True)
else:
# TODO check version, must be greater than the exist one
flow.uid = exist_inst.uid
self.update(flow, check_editable=False)
self.update_flow(flow, check_editable=False, save_failed_flow=True)
except Exception as e:
message = traceback.format_exc()
logger.warning(
f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}"
)
def update(
self, request: ServeRequest, check_editable: bool = True
def update_flow(
self,
request: ServeRequest,
check_editable: bool = True,
save_failed_flow: bool = False,
) -> ServerResponse:
"""Update a Flow entity
Args:
request (ServeRequest): The request
check_editable (bool): Whether to check the editable
save_failed_flow (bool): Whether to save the failed flow
Returns:
ServerResponse: The response
"""
# Try to build the dag from the request
dag = self._flow_factory.build(request)
new_state = request.state
try:
# Try to build the dag from the request
dag = self._flow_factory.build(request)
request.flow_category = self._parse_flow_category(dag)
except Exception as e:
if save_failed_flow:
request.state = State.LOAD_FAILED
request.error_message = str(e)
return self.dao.update({"uid": request.uid}, request)
else:
raise e
# Build the query request from the request
query_request = {"uid": request.uid}
inst = self.get(query_request)
@@ -193,19 +254,26 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
raise HTTPException(
status_code=403, detail=f"Flow {request.uid} is not editable"
)
old_state = inst.state
if not State.can_change_state(old_state, new_state):
raise HTTPException(
status_code=400,
detail=f"Flow {request.uid} state can't change from {old_state} to "
f"{new_state}",
)
old_data: Optional[ServerResponse] = None
try:
request.flow_category = self._parse_flow_category(dag)
update_obj = self.dao.update(query_request, update_request=request)
old_data = self.delete(request.uid)
old_data.state = old_state
if not old_data:
raise HTTPException(
status_code=404, detail=f"Flow detail {request.uid} not found"
)
return self.create(update_obj)
return self.create_and_save_dag(update_obj)
except Exception as e:
if old_data:
self.create(old_data)
self.create_and_save_dag(old_data)
raise e
def get(self, request: QUERY_SPEC) -> Optional[ServerResponse]: