feat: Support endpoint placeholder

This commit is contained in:
Fangyin Cheng 2024-08-28 16:48:50 +08:00
parent c67b50052d
commit bf63a967b5
2 changed files with 38 additions and 22 deletions

View File

@ -58,6 +58,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}"
class AWELHttpError(RuntimeError): class AWELHttpError(RuntimeError):
"""AWEL Http Error.""" """AWEL Http Error."""
@ -465,14 +467,11 @@ class HttpTrigger(Trigger):
router (APIRouter): The router to mount the trigger. router (APIRouter): The router to mount the trigger.
global_prefix (Optional[str], optional): The global prefix of the router. global_prefix (Optional[str], optional): The global prefix of the router.
""" """
path = ( endpoint = self._resolved_endpoint()
join_paths(global_prefix, self._endpoint) path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
if global_prefix
else self._endpoint
)
dynamic_route_function = self._create_route_func() dynamic_route_function = self._create_route_func()
router.api_route( router.api_route(
self._endpoint, endpoint,
methods=self._methods, methods=self._methods,
response_model=self._response_model, response_model=self._response_model,
status_code=self._status_code, status_code=self._status_code,
@ -498,11 +497,9 @@ class HttpTrigger(Trigger):
""" """
from dbgpt.util.fastapi import PriorityAPIRouter from dbgpt.util.fastapi import PriorityAPIRouter
path = ( endpoint = self._resolved_endpoint()
join_paths(global_prefix, self._endpoint)
if global_prefix path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
else self._endpoint
)
dynamic_route_function = self._create_route_func() dynamic_route_function = self._create_route_func()
router = cast(PriorityAPIRouter, app.router) router = cast(PriorityAPIRouter, app.router)
router.add_api_route( router.add_api_route(
@ -533,17 +530,28 @@ class HttpTrigger(Trigger):
""" """
from fastapi import APIRouter from fastapi import APIRouter
path = ( endpoint = self._resolved_endpoint()
join_paths(global_prefix, self._endpoint)
if global_prefix path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
else self._endpoint
)
app_router = cast(APIRouter, app.router) app_router = cast(APIRouter, app.router)
for i, r in enumerate(app_router.routes): for i, r in enumerate(app_router.routes):
if r.path_format == path: # type: ignore if r.path_format == path: # type: ignore
# TODO, remove with path and methods # TODO, remove with path and methods
del app_router.routes[i] del app_router.routes[i]
def _resolved_endpoint(self) -> str:
"""Get the resolved endpoint.
Replace the placeholder {dag_id} with the real dag_id.
"""
endpoint = self._endpoint
if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint:
return endpoint
if not self.dag:
raise AWELHttpError("DAG is not set")
dag_id = self.dag.dag_id
return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id)
def _trigger_mode(self) -> str: def _trigger_mode(self) -> str:
if ( if (
self._req_body self._req_body
@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger):
), ),
], ],
parameters=[ parameters=[
_PARAMETER_ENDPOINT.new(), Parameter.build_from(
_("API Endpoint"),
"endpoint",
str,
optional=True,
default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
description=_("The API endpoint"),
),
_PARAMETER_METHODS_POST_PUT.new(), _PARAMETER_METHODS_POST_PUT.new(),
_PARAMETER_STREAMING_RESPONSE.new(), _PARAMETER_STREAMING_RESPONSE.new(),
_PARAMETER_RESPONSE_BODY.new(), _PARAMETER_RESPONSE_BODY.new(),
@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
def __init__( def __init__(
self, self,
endpoint: str, endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
methods: Optional[Union[str, List[str]]] = "POST", methods: Optional[Union[str, List[str]]] = "POST",
streaming_response: bool = False, streaming_response: bool = False,
http_response_body: Optional[Type[BaseHttpBody]] = None, http_response_body: Optional[Type[BaseHttpBody]] = None,

View File

@ -81,7 +81,8 @@ class HttpTriggerManager(TriggerManager):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger") raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger_id = trigger.node_id trigger_id = trigger.node_id
if trigger_id not in self._trigger_map: if trigger_id not in self._trigger_map:
path = join_paths(self._router_prefix, trigger._endpoint) real_endpoint = trigger._resolved_endpoint()
path = join_paths(self._router_prefix, real_endpoint)
methods = trigger._methods methods = trigger._methods
# Check whether the route is already registered # Check whether the route is already registered
self._register_route_tables(path, methods) self._register_route_tables(path, methods)
@ -116,9 +117,9 @@ class HttpTriggerManager(TriggerManager):
if not app: if not app:
raise ValueError("System app not initialized") raise ValueError("System app not initialized")
trigger.remove_from_app(app, self._router_prefix) trigger.remove_from_app(app, self._router_prefix)
self._unregister_route_tables( real_endpoint = trigger._resolved_endpoint()
join_paths(self._router_prefix, trigger._endpoint), trigger._methods path = join_paths(self._router_prefix, real_endpoint)
) self._unregister_route_tables(path, trigger._methods)
del self._trigger_map[trigger_id] del self._trigger_map[trigger_id]
def _init_app(self, system_app: SystemApp): def _init_app(self, system_app: SystemApp):