mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat: Support endpoint placeholder
This commit is contained in:
parent
c67b50052d
commit
bf63a967b5
@ -58,6 +58,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}"
|
||||||
|
|
||||||
|
|
||||||
class AWELHttpError(RuntimeError):
|
class AWELHttpError(RuntimeError):
|
||||||
"""AWEL Http Error."""
|
"""AWEL Http Error."""
|
||||||
@ -465,14 +467,11 @@ class HttpTrigger(Trigger):
|
|||||||
router (APIRouter): The router to mount the trigger.
|
router (APIRouter): The router to mount the trigger.
|
||||||
global_prefix (Optional[str], optional): The global prefix of the router.
|
global_prefix (Optional[str], optional): The global prefix of the router.
|
||||||
"""
|
"""
|
||||||
path = (
|
endpoint = self._resolved_endpoint()
|
||||||
join_paths(global_prefix, self._endpoint)
|
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
|
||||||
if global_prefix
|
|
||||||
else self._endpoint
|
|
||||||
)
|
|
||||||
dynamic_route_function = self._create_route_func()
|
dynamic_route_function = self._create_route_func()
|
||||||
router.api_route(
|
router.api_route(
|
||||||
self._endpoint,
|
endpoint,
|
||||||
methods=self._methods,
|
methods=self._methods,
|
||||||
response_model=self._response_model,
|
response_model=self._response_model,
|
||||||
status_code=self._status_code,
|
status_code=self._status_code,
|
||||||
@ -498,11 +497,9 @@ class HttpTrigger(Trigger):
|
|||||||
"""
|
"""
|
||||||
from dbgpt.util.fastapi import PriorityAPIRouter
|
from dbgpt.util.fastapi import PriorityAPIRouter
|
||||||
|
|
||||||
path = (
|
endpoint = self._resolved_endpoint()
|
||||||
join_paths(global_prefix, self._endpoint)
|
|
||||||
if global_prefix
|
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
|
||||||
else self._endpoint
|
|
||||||
)
|
|
||||||
dynamic_route_function = self._create_route_func()
|
dynamic_route_function = self._create_route_func()
|
||||||
router = cast(PriorityAPIRouter, app.router)
|
router = cast(PriorityAPIRouter, app.router)
|
||||||
router.add_api_route(
|
router.add_api_route(
|
||||||
@ -533,17 +530,28 @@ class HttpTrigger(Trigger):
|
|||||||
"""
|
"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
path = (
|
endpoint = self._resolved_endpoint()
|
||||||
join_paths(global_prefix, self._endpoint)
|
|
||||||
if global_prefix
|
path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
|
||||||
else self._endpoint
|
|
||||||
)
|
|
||||||
app_router = cast(APIRouter, app.router)
|
app_router = cast(APIRouter, app.router)
|
||||||
for i, r in enumerate(app_router.routes):
|
for i, r in enumerate(app_router.routes):
|
||||||
if r.path_format == path: # type: ignore
|
if r.path_format == path: # type: ignore
|
||||||
# TODO, remove with path and methods
|
# TODO, remove with path and methods
|
||||||
del app_router.routes[i]
|
del app_router.routes[i]
|
||||||
|
|
||||||
|
def _resolved_endpoint(self) -> str:
|
||||||
|
"""Get the resolved endpoint.
|
||||||
|
|
||||||
|
Replace the placeholder {dag_id} with the real dag_id.
|
||||||
|
"""
|
||||||
|
endpoint = self._endpoint
|
||||||
|
if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint:
|
||||||
|
return endpoint
|
||||||
|
if not self.dag:
|
||||||
|
raise AWELHttpError("DAG is not set")
|
||||||
|
dag_id = self.dag.dag_id
|
||||||
|
return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id)
|
||||||
|
|
||||||
def _trigger_mode(self) -> str:
|
def _trigger_mode(self) -> str:
|
||||||
if (
|
if (
|
||||||
self._req_body
|
self._req_body
|
||||||
@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
parameters=[
|
parameters=[
|
||||||
_PARAMETER_ENDPOINT.new(),
|
Parameter.build_from(
|
||||||
|
_("API Endpoint"),
|
||||||
|
"endpoint",
|
||||||
|
str,
|
||||||
|
optional=True,
|
||||||
|
default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
|
||||||
|
description=_("The API endpoint"),
|
||||||
|
),
|
||||||
_PARAMETER_METHODS_POST_PUT.new(),
|
_PARAMETER_METHODS_POST_PUT.new(),
|
||||||
_PARAMETER_STREAMING_RESPONSE.new(),
|
_PARAMETER_STREAMING_RESPONSE.new(),
|
||||||
_PARAMETER_RESPONSE_BODY.new(),
|
_PARAMETER_RESPONSE_BODY.new(),
|
||||||
@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
endpoint: str,
|
endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
|
||||||
methods: Optional[Union[str, List[str]]] = "POST",
|
methods: Optional[Union[str, List[str]]] = "POST",
|
||||||
streaming_response: bool = False,
|
streaming_response: bool = False,
|
||||||
http_response_body: Optional[Type[BaseHttpBody]] = None,
|
http_response_body: Optional[Type[BaseHttpBody]] = None,
|
||||||
|
@ -81,7 +81,8 @@ class HttpTriggerManager(TriggerManager):
|
|||||||
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
|
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
|
||||||
trigger_id = trigger.node_id
|
trigger_id = trigger.node_id
|
||||||
if trigger_id not in self._trigger_map:
|
if trigger_id not in self._trigger_map:
|
||||||
path = join_paths(self._router_prefix, trigger._endpoint)
|
real_endpoint = trigger._resolved_endpoint()
|
||||||
|
path = join_paths(self._router_prefix, real_endpoint)
|
||||||
methods = trigger._methods
|
methods = trigger._methods
|
||||||
# Check whether the route is already registered
|
# Check whether the route is already registered
|
||||||
self._register_route_tables(path, methods)
|
self._register_route_tables(path, methods)
|
||||||
@ -116,9 +117,9 @@ class HttpTriggerManager(TriggerManager):
|
|||||||
if not app:
|
if not app:
|
||||||
raise ValueError("System app not initialized")
|
raise ValueError("System app not initialized")
|
||||||
trigger.remove_from_app(app, self._router_prefix)
|
trigger.remove_from_app(app, self._router_prefix)
|
||||||
self._unregister_route_tables(
|
real_endpoint = trigger._resolved_endpoint()
|
||||||
join_paths(self._router_prefix, trigger._endpoint), trigger._methods
|
path = join_paths(self._router_prefix, real_endpoint)
|
||||||
)
|
self._unregister_route_tables(path, trigger._methods)
|
||||||
del self._trigger_map[trigger_id]
|
del self._trigger_map[trigger_id]
|
||||||
|
|
||||||
def _init_app(self, system_app: SystemApp):
|
def _init_app(self, system_app: SystemApp):
|
||||||
|
Loading…
Reference in New Issue
Block a user