mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-18 07:30:40 +00:00
feat(web): copy awel flow (#1200)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from dbgpt.core.awel import (
|
||||
CommonLLMHttpResponseBody,
|
||||
)
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory, State
|
||||
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
|
||||
from dbgpt.core.interface.llm import ModelOutput
|
||||
from dbgpt.serve.core import BaseService
|
||||
@@ -103,14 +103,55 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# Build DAG from request
|
||||
dag = self._flow_factory.build(request)
|
||||
request.dag_id = dag.dag_id
|
||||
# Save DAG to storage
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
|
||||
def create_and_save_dag(
|
||||
self, request: ServeRequest, save_failed_flow: bool = False
|
||||
) -> ServerResponse:
|
||||
"""Create a new Flow entity and save the DAG
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
save_failed_flow (bool): Whether to save the failed flow
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
try:
|
||||
# Build DAG from request
|
||||
dag = self._flow_factory.build(request)
|
||||
request.dag_id = dag.dag_id
|
||||
# Save DAG to storage
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
except Exception as e:
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = str(e)
|
||||
return self.dao.create(request)
|
||||
else:
|
||||
raise e
|
||||
res = self.dao.create(request)
|
||||
# Register the DAG
|
||||
self.dag_manager.register_dag(dag)
|
||||
|
||||
state = request.state
|
||||
try:
|
||||
if state == State.DEPLOYED:
|
||||
# Register the DAG
|
||||
self.dag_manager.register_dag(dag)
|
||||
# Update state to RUNNING
|
||||
request.state = State.RUNNING
|
||||
request.error_message = ""
|
||||
self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
logger.info(f"Flow state is {state}, skip register DAG")
|
||||
except Exception as e:
|
||||
logger.warning(f"Register DAG({dag.dag_id}) error: {str(e)}")
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = f"Register DAG error: {str(e)}"
|
||||
self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
# Rollback
|
||||
self.delete(request.uid)
|
||||
raise e
|
||||
return res
|
||||
|
||||
def _pre_load_dag_from_db(self):
|
||||
@@ -131,7 +172,15 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
for entity in entities:
|
||||
try:
|
||||
dag = self._flow_factory.build(entity)
|
||||
self.dag_manager.register_dag(dag)
|
||||
if entity.state in [State.DEPLOYED, State.RUNNING] or (
|
||||
entity.version == "0.1.0" and entity.state == State.INITIALIZING
|
||||
):
|
||||
# Register the DAG
|
||||
self.dag_manager.register_dag(dag)
|
||||
# Update state to RUNNING
|
||||
entity.state = State.RUNNING
|
||||
entity.error_message = ""
|
||||
self.dao.update({"uid": entity.uid}, entity)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Load DAG({entity.name}, {entity.dag_id}) from db error: {str(e)}"
|
||||
@@ -154,36 +203,48 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flows = self.dbgpts_loader.get_flows()
|
||||
for flow in flows:
|
||||
try:
|
||||
# Try to build the dag from the request
|
||||
self._flow_factory.build(flow)
|
||||
# Set state to DEPLOYED
|
||||
flow.state = State.DEPLOYED
|
||||
exist_inst = self.get({"name": flow.name})
|
||||
if not exist_inst:
|
||||
self.create(flow)
|
||||
self.create_and_save_dag(flow, save_failed_flow=True)
|
||||
else:
|
||||
# TODO check version, must be greater than the exist one
|
||||
flow.uid = exist_inst.uid
|
||||
self.update(flow, check_editable=False)
|
||||
self.update_flow(flow, check_editable=False, save_failed_flow=True)
|
||||
except Exception as e:
|
||||
message = traceback.format_exc()
|
||||
logger.warning(
|
||||
f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}"
|
||||
)
|
||||
|
||||
def update(
|
||||
self, request: ServeRequest, check_editable: bool = True
|
||||
def update_flow(
|
||||
self,
|
||||
request: ServeRequest,
|
||||
check_editable: bool = True,
|
||||
save_failed_flow: bool = False,
|
||||
) -> ServerResponse:
|
||||
"""Update a Flow entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
check_editable (bool): Whether to check the editable
|
||||
|
||||
save_failed_flow (bool): Whether to save the failed flow
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# Try to build the dag from the request
|
||||
dag = self._flow_factory.build(request)
|
||||
|
||||
new_state = request.state
|
||||
try:
|
||||
# Try to build the dag from the request
|
||||
dag = self._flow_factory.build(request)
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
except Exception as e:
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = str(e)
|
||||
return self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
raise e
|
||||
# Build the query request from the request
|
||||
query_request = {"uid": request.uid}
|
||||
inst = self.get(query_request)
|
||||
@@ -193,19 +254,26 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"Flow {request.uid} is not editable"
|
||||
)
|
||||
old_state = inst.state
|
||||
if not State.can_change_state(old_state, new_state):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Flow {request.uid} state can't change from {old_state} to "
|
||||
f"{new_state}",
|
||||
)
|
||||
old_data: Optional[ServerResponse] = None
|
||||
try:
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
update_obj = self.dao.update(query_request, update_request=request)
|
||||
old_data = self.delete(request.uid)
|
||||
old_data.state = old_state
|
||||
if not old_data:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Flow detail {request.uid} not found"
|
||||
)
|
||||
return self.create(update_obj)
|
||||
return self.create_and_save_dag(update_obj)
|
||||
except Exception as e:
|
||||
if old_data:
|
||||
self.create(old_data)
|
||||
self.create_and_save_dag(old_data)
|
||||
raise e
|
||||
|
||||
def get(self, request: QUERY_SPEC) -> Optional[ServerResponse]:
|
||||
|
Reference in New Issue
Block a user