mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 11:51:42 +00:00
feat: Support endpoint placeholder
This commit is contained in:
parent
c67b50052d
commit
bf63a967b5
@ -58,6 +58,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}"
|
||||
|
||||
|
||||
class AWELHttpError(RuntimeError):
|
||||
"""AWEL Http Error."""
|
||||
@ -465,14 +467,11 @@ class HttpTrigger(Trigger):
|
||||
router (APIRouter): The router to mount the trigger.
|
||||
global_prefix (Optional[str], optional): The global prefix of the router.
|
||||
"""
|
||||
path = (
|
||||
join_paths(global_prefix, self._endpoint)
|
||||
if global_prefix
|
||||
else self._endpoint
|
||||
)
|
||||
endpoint = self._resolved_endpoint()
|
||||
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
|
||||
dynamic_route_function = self._create_route_func()
|
||||
router.api_route(
|
||||
self._endpoint,
|
||||
endpoint,
|
||||
methods=self._methods,
|
||||
response_model=self._response_model,
|
||||
status_code=self._status_code,
|
||||
@ -498,11 +497,9 @@ class HttpTrigger(Trigger):
|
||||
"""
|
||||
from dbgpt.util.fastapi import PriorityAPIRouter
|
||||
|
||||
path = (
|
||||
join_paths(global_prefix, self._endpoint)
|
||||
if global_prefix
|
||||
else self._endpoint
|
||||
)
|
||||
endpoint = self._resolved_endpoint()
|
||||
|
||||
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
|
||||
dynamic_route_function = self._create_route_func()
|
||||
router = cast(PriorityAPIRouter, app.router)
|
||||
router.add_api_route(
|
||||
@ -533,17 +530,28 @@ class HttpTrigger(Trigger):
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
path = (
|
||||
join_paths(global_prefix, self._endpoint)
|
||||
if global_prefix
|
||||
else self._endpoint
|
||||
)
|
||||
endpoint = self._resolved_endpoint()
|
||||
|
||||
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
|
||||
app_router = cast(APIRouter, app.router)
|
||||
for i, r in enumerate(app_router.routes):
|
||||
if r.path_format == path: # type: ignore
|
||||
# TODO, remove with path and methods
|
||||
del app_router.routes[i]
|
||||
|
||||
def _resolved_endpoint(self) -> str:
|
||||
"""Get the resolved endpoint.
|
||||
|
||||
Replace the placeholder {dag_id} with the real dag_id.
|
||||
"""
|
||||
endpoint = self._endpoint
|
||||
if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint:
|
||||
return endpoint
|
||||
if not self.dag:
|
||||
raise AWELHttpError("DAG is not set")
|
||||
dag_id = self.dag.dag_id
|
||||
return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id)
|
||||
|
||||
def _trigger_mode(self) -> str:
|
||||
if (
|
||||
self._req_body
|
||||
@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger):
|
||||
),
|
||||
],
|
||||
parameters=[
|
||||
_PARAMETER_ENDPOINT.new(),
|
||||
Parameter.build_from(
|
||||
_("API Endpoint"),
|
||||
"endpoint",
|
||||
str,
|
||||
optional=True,
|
||||
default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
|
||||
description=_("The API endpoint"),
|
||||
),
|
||||
_PARAMETER_METHODS_POST_PUT.new(),
|
||||
_PARAMETER_STREAMING_RESPONSE.new(),
|
||||
_PARAMETER_RESPONSE_BODY.new(),
|
||||
@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
|
||||
methods: Optional[Union[str, List[str]]] = "POST",
|
||||
streaming_response: bool = False,
|
||||
http_response_body: Optional[Type[BaseHttpBody]] = None,
|
||||
|
@ -81,7 +81,8 @@ class HttpTriggerManager(TriggerManager):
|
||||
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
|
||||
trigger_id = trigger.node_id
|
||||
if trigger_id not in self._trigger_map:
|
||||
path = join_paths(self._router_prefix, trigger._endpoint)
|
||||
real_endpoint = trigger._resolved_endpoint()
|
||||
path = join_paths(self._router_prefix, real_endpoint)
|
||||
methods = trigger._methods
|
||||
# Check whether the route is already registered
|
||||
self._register_route_tables(path, methods)
|
||||
@ -116,9 +117,9 @@ class HttpTriggerManager(TriggerManager):
|
||||
if not app:
|
||||
raise ValueError("System app not initialized")
|
||||
trigger.remove_from_app(app, self._router_prefix)
|
||||
self._unregister_route_tables(
|
||||
join_paths(self._router_prefix, trigger._endpoint), trigger._methods
|
||||
)
|
||||
real_endpoint = trigger._resolved_endpoint()
|
||||
path = join_paths(self._router_prefix, real_endpoint)
|
||||
self._unregister_route_tables(path, trigger._methods)
|
||||
del self._trigger_map[trigger_id]
|
||||
|
||||
def _init_app(self, system_app: SystemApp):
|
||||
|
Loading…
Reference in New Issue
Block a user