diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 56e4dcbf1..fdeb1afbf 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -446,27 +446,25 @@ class DAGContext: def __init__( self, + node_to_outputs: Dict[str, TaskContext], + share_data: Dict[str, Any], streaming_call: bool = False, - node_to_outputs: Optional[Dict[str, TaskContext]] = None, node_name_to_ids: Optional[Dict[str, str]] = None, ) -> None: """Initialize a DAGContext. Args: + node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG. + share_data (Dict[str, Any]): The share data of current DAG. streaming_call (bool, optional): Whether the current DAG is streaming call. Defaults to False. - node_to_outputs (Optional[Dict[str, TaskContext]], optional): - The task outputs of current DAG. Defaults to None. - node_name_to_ids (Optional[Dict[str, str]], optional): - The task name to task id mapping. Defaults to None. + node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node """ - if not node_to_outputs: - node_to_outputs = {} if not node_name_to_ids: node_name_to_ids = {} self._streaming_call = streaming_call self._curr_task_ctx: Optional[TaskContext] = None - self._share_data: Dict[str, Any] = {} + self._share_data: Dict[str, Any] = share_data self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs self._node_name_to_ids: Dict[str, str] = node_name_to_ids @@ -530,6 +528,7 @@ class DAGContext: Returns: Any: The share data, you can cast it to the real type """ + logger.debug(f"Get share data by key {key} from {id(self._share_data)}") return self._share_data.get(key) async def save_to_share_data( @@ -545,6 +544,7 @@ class DAGContext: """ if key in self._share_data and not overwrite: raise ValueError(f"Share data key {key} already exists") + logger.debug(f"Save share data by key {key} to {id(self._share_data)}") self._share_data[key] = data async def get_task_share_data(self, task_name: str, key: str) -> Any: diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index bfd4a8da5..480f3b89a 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -3,7 +3,7 @@ This runner will run the workflow in the current process. """ import logging -from typing import Dict, List, Optional, Set, cast +from typing import Any, Dict, List, Optional, Set, cast from dbgpt.component import SystemApp @@ -20,6 +20,10 @@ logger = logging.getLogger(__name__) class DefaultWorkflowRunner(WorkflowRunner): """The default workflow runner.""" + def __init__(self): + """Init the default workflow runner.""" + self._running_dag_ctx: Dict[str, DAGContext] = {} + async def execute_workflow( self, node: BaseOperator, @@ -44,15 +48,22 @@ class DefaultWorkflowRunner(WorkflowRunner): if not exist_dag_ctx: # Create DAG context node_outputs: Dict[str, TaskContext] = {} + share_data: Dict[str, Any] = {} else: # Share node output with exist dag context node_outputs = exist_dag_ctx._node_to_outputs + share_data = exist_dag_ctx._share_data dag_ctx = DAGContext( - streaming_call=streaming_call, node_to_outputs=node_outputs, + share_data=share_data, + streaming_call=streaming_call, node_name_to_ids=job_manager._node_name_to_ids, ) - logger.info(f"Begin run workflow from end operator, id: {node.node_id}") + if node.dag: + self._running_dag_ctx[node.dag.dag_id] = dag_ctx + logger.info( + f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}" + ) logger.debug(f"Node id {node.node_id}, call_data: {call_data}") skip_node_ids: Set[str] = set() system_app: Optional[SystemApp] = DAGVar.get_current_system_app() @@ -64,7 +75,8 @@ class DefaultWorkflowRunner(WorkflowRunner): if not streaming_call and node.dag: # streaming call not work for dag end await node.dag._after_dag_end() - + if node.dag: + del self._running_dag_ctx[node.dag.dag_id] return dag_ctx async def _execute_node( diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index f803bb945..444a1fd00 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -14,11 +14,21 @@ if TYPE_CHECKING: from starlette.requests import Request RequestBody = Union[Type[Request], Type[BaseModel], Type[str]] - StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool] + StreamingPredictFunc = Callable[[Union[Request, BaseModel, str, None]], bool] logger = logging.getLogger(__name__) +class AWELHttpError(RuntimeError): + """AWEL Http Error.""" + + def __init__(self, msg: str, code: Optional[str] = None): + """Init the AWELHttpError.""" + super().__init__(msg) + self.msg = msg + self.code = code + + class HttpTrigger(Trigger): """Http trigger for AWEL. @@ -65,29 +75,74 @@ class HttpTrigger(Trigger): Args: router (APIRouter): The router to mount the trigger. """ - from fastapi import Depends + from inspect import Parameter, Signature + from typing import get_type_hints + from starlette.requests import Request methods = [self._methods] if isinstance(self._methods, str) else self._methods + is_query_method = ( + all(method in ["GET", "DELETE"] for method in methods) if methods else True + ) + + async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]): + streaming_response = self._streaming_response + if self._streaming_predict_func: + streaming_response = self._streaming_predict_func(body) + dag = self.dag + if not dag: + raise AWELHttpError("DAG is not set") + return await _trigger_dag( + body, + dag, + streaming_response, + self._response_headers, + self._response_media_type, + ) def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]): - async def _request_body_dependency(request: Request): - return await _parse_request_body(request, self._req_body) + async def route_function_request(request: Request): + return await _trigger_dag_func(request) - async def route_function(body=Depends(_request_body_dependency)): - streaming_response = self._streaming_response - if self._streaming_predict_func: - streaming_response = self._streaming_predict_func(body) - return await _trigger_dag( - body, - self.dag, - streaming_response, - self._response_headers, - self._response_media_type, - ) + async def route_function_none(): + return await _trigger_dag_func(None) - route_function.__name__ = name - return route_function + route_function_request.__name__ = name + route_function_none.__name__ = name + + if not req_body_cls: + return route_function_none + if req_body_cls == Request: + return route_function_request + + if is_query_method: + if req_body_cls == str: + raise AWELHttpError(f"Query methods {methods} not support str type") + + async def route_function_get(**kwargs): + body = req_body_cls(**kwargs) + return await _trigger_dag_func(body) + + parameters = [ + Parameter( + name=field_name, + kind=Parameter.KEYWORD_ONLY, + default=Parameter.empty, + annotation=field.outer_type_, + ) + for field_name, field in req_body_cls.__fields__.items() + ] + route_function_get.__signature__ = Signature(parameters) # type: ignore + route_function_get.__annotations__ = get_type_hints(req_body_cls) + route_function_get.__name__ = name + return route_function_get + else: + + async def route_function(body: req_body_cls): # type: ignore + return await _trigger_dag_func(body) + + route_function.__name__ = name + return route_function function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}" request_model = ( @@ -111,32 +166,6 @@ class HttpTrigger(Trigger): )(dynamic_route_function) -async def _parse_request_body( - request: "Request", request_body_cls: Optional["RequestBody"] -): - from starlette.requests import Request - - if not request_body_cls: - return None - if request_body_cls == Request: - return request - if request.method == "POST": - if request_body_cls == str: - bytes_body = await request.body() - str_body = bytes_body.decode("utf-8") - return str_body - elif issubclass(request_body_cls, BaseModel): - json_data = await request.json() - return request_body_cls(**json_data) - else: - raise ValueError(f"Invalid request body cls: {request_body_cls}") - elif request.method == "GET": - if issubclass(request_body_cls, BaseModel): - return request_body_cls(**request.query_params) - else: - raise ValueError(f"Invalid request body cls: {request_body_cls}") - - async def _trigger_dag( body: Any, dag: DAG,