feat: Support endpoint placeholder

This commit is contained in:
Fangyin Cheng 2024-08-28 16:48:50 +08:00
parent c67b50052d
commit bf63a967b5
2 changed files with 38 additions and 22 deletions

View File

@ -58,6 +58,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}"
class AWELHttpError(RuntimeError):
"""AWEL Http Error."""
@ -465,14 +467,11 @@ class HttpTrigger(Trigger):
router (APIRouter): The router to mount the trigger.
global_prefix (Optional[str], optional): The global prefix of the router.
"""
path = (
join_paths(global_prefix, self._endpoint)
if global_prefix
else self._endpoint
)
endpoint = self._resolved_endpoint()
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
dynamic_route_function = self._create_route_func()
router.api_route(
self._endpoint,
endpoint,
methods=self._methods,
response_model=self._response_model,
status_code=self._status_code,
@ -498,11 +497,9 @@ class HttpTrigger(Trigger):
"""
from dbgpt.util.fastapi import PriorityAPIRouter
path = (
join_paths(global_prefix, self._endpoint)
if global_prefix
else self._endpoint
)
endpoint = self._resolved_endpoint()
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
dynamic_route_function = self._create_route_func()
router = cast(PriorityAPIRouter, app.router)
router.add_api_route(
@ -533,17 +530,28 @@ class HttpTrigger(Trigger):
"""
from fastapi import APIRouter
path = (
join_paths(global_prefix, self._endpoint)
if global_prefix
else self._endpoint
)
endpoint = self._resolved_endpoint()
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
app_router = cast(APIRouter, app.router)
for i, r in enumerate(app_router.routes):
if r.path_format == path: # type: ignore
# TODO, remove with path and methods
del app_router.routes[i]
def _resolved_endpoint(self) -> str:
"""Get the resolved endpoint.
Replace the placeholder {dag_id} with the real dag_id.
"""
endpoint = self._endpoint
if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint:
return endpoint
if not self.dag:
raise AWELHttpError("DAG is not set")
dag_id = self.dag.dag_id
return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id)
def _trigger_mode(self) -> str:
if (
self._req_body
@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger):
),
],
parameters=[
_PARAMETER_ENDPOINT.new(),
Parameter.build_from(
_("API Endpoint"),
"endpoint",
str,
optional=True,
default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
description=_("The API endpoint"),
),
_PARAMETER_METHODS_POST_PUT.new(),
_PARAMETER_STREAMING_RESPONSE.new(),
_PARAMETER_RESPONSE_BODY.new(),
@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
def __init__(
self,
endpoint: str,
endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
methods: Optional[Union[str, List[str]]] = "POST",
streaming_response: bool = False,
http_response_body: Optional[Type[BaseHttpBody]] = None,

View File

@ -81,7 +81,8 @@ class HttpTriggerManager(TriggerManager):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger_id = trigger.node_id
if trigger_id not in self._trigger_map:
path = join_paths(self._router_prefix, trigger._endpoint)
real_endpoint = trigger._resolved_endpoint()
path = join_paths(self._router_prefix, real_endpoint)
methods = trigger._methods
# Check whether the route is already registered
self._register_route_tables(path, methods)
@ -116,9 +117,9 @@ class HttpTriggerManager(TriggerManager):
if not app:
raise ValueError("System app not initialized")
trigger.remove_from_app(app, self._router_prefix)
self._unregister_route_tables(
join_paths(self._router_prefix, trigger._endpoint), trigger._methods
)
real_endpoint = trigger._resolved_endpoint()
path = join_paths(self._router_prefix, real_endpoint)
self._unregister_route_tables(path, trigger._methods)
del self._trigger_map[trigger_id]
def _init_app(self, system_app: SystemApp):