From bf63a967b51ad0ee880071335229d393eaeae809 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Wed, 28 Aug 2024 16:48:50 +0800 Subject: [PATCH] feat: Support endpoint placeholder --- dbgpt/core/awel/trigger/http_trigger.py | 51 ++++++++++++++-------- dbgpt/core/awel/trigger/trigger_manager.py | 9 ++-- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 33692a423..6e17be15e 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -58,6 +58,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}" + class AWELHttpError(RuntimeError): """AWEL Http Error.""" @@ -465,14 +467,11 @@ class HttpTrigger(Trigger): router (APIRouter): The router to mount the trigger. global_prefix (Optional[str], optional): The global prefix of the router. """ - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint dynamic_route_function = self._create_route_func() router.api_route( - self._endpoint, + endpoint, methods=self._methods, response_model=self._response_model, status_code=self._status_code, @@ -498,11 +497,9 @@ class HttpTrigger(Trigger): """ from dbgpt.util.fastapi import PriorityAPIRouter - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint dynamic_route_function = self._create_route_func() router = cast(PriorityAPIRouter, app.router) router.add_api_route( @@ -533,17 +530,28 @@ class HttpTrigger(Trigger): """ from fastapi import APIRouter - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint app_router = cast(APIRouter, app.router) for i, r in enumerate(app_router.routes): if r.path_format == path: # type: ignore # TODO, remove with path and methods del app_router.routes[i] + def _resolved_endpoint(self) -> str: + """Get the resolved endpoint. + + Replace the placeholder {dag_id} with the real dag_id. + """ + endpoint = self._endpoint + if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint: + return endpoint + if not self.dag: + raise AWELHttpError("DAG is not set") + dag_id = self.dag.dag_id + return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id) + def _trigger_mode(self) -> str: if ( self._req_body @@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger): ), ], parameters=[ - _PARAMETER_ENDPOINT.new(), + Parameter.build_from( + _("API Endpoint"), + "endpoint", + str, + optional=True, + default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID, + description=_("The API endpoint"), + ), _PARAMETER_METHODS_POST_PUT.new(), _PARAMETER_STREAMING_RESPONSE.new(), _PARAMETER_RESPONSE_BODY.new(), @@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger): def __init__( self, - endpoint: str, + endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID, methods: Optional[Union[str, List[str]]] = "POST", streaming_response: bool = False, http_response_body: Optional[Type[BaseHttpBody]] = None, diff --git a/dbgpt/core/awel/trigger/trigger_manager.py b/dbgpt/core/awel/trigger/trigger_manager.py index 45b040147..94563226e 100644 --- a/dbgpt/core/awel/trigger/trigger_manager.py +++ b/dbgpt/core/awel/trigger/trigger_manager.py @@ -81,7 +81,8 @@ class HttpTriggerManager(TriggerManager): raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger") trigger_id = trigger.node_id if trigger_id not in self._trigger_map: - path = join_paths(self._router_prefix, trigger._endpoint) + real_endpoint = trigger._resolved_endpoint() + path = join_paths(self._router_prefix, real_endpoint) methods = trigger._methods # Check whether the route is already registered self._register_route_tables(path, methods) @@ -116,9 +117,9 @@ class HttpTriggerManager(TriggerManager): if not app: raise ValueError("System app not initialized") trigger.remove_from_app(app, self._router_prefix) - self._unregister_route_tables( - join_paths(self._router_prefix, trigger._endpoint), trigger._methods - ) + real_endpoint = trigger._resolved_endpoint() + path = join_paths(self._router_prefix, real_endpoint) + self._unregister_route_tables(path, trigger._methods) del self._trigger_map[trigger_id] def _init_app(self, system_app: SystemApp):