mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 05:47:47 +00:00
131 lines
3.8 KiB
Python
131 lines
3.8 KiB
Python
"""FastAPI utilities."""
|
|
|
|
import importlib.metadata as metadata
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.routing import APIRouter
|
|
|
|
class _VersionStr(str):
    """A version string whose ordering comparisons are numeric.

    Plain ``str`` comparison is lexicographic, so ``"0.99.0" >= "0.109.1"`` is
    True even though 0.99.0 is the older release. This subclass keeps the value
    usable as an ordinary string (equality and hashing are inherited from
    ``str``) but orders ``<``, ``<=``, ``>`` and ``>=`` by the leading dotted
    integer segment, e.g. ``"0.110.0rc1"`` is ordered as ``(0, 110, 0)``.

    Note: only ordering is numeric; ``==`` is still exact string equality.
    """

    __slots__ = ()

    @staticmethod
    def _release(version: str) -> tuple:
        """Return the leading numeric release segment as a tuple of ints."""
        numbers = []
        for piece in version.split("."):
            digits = ""
            for char in piece:
                if char not in "0123456789":
                    break
                digits += char
            if not digits:
                # Non-numeric component such as "post1" or "dev0": stop here.
                break
            numbers.append(int(digits))
            if len(digits) != len(piece):
                # Component like "0rc1": keep the number, ignore the suffix.
                break
        return tuple(numbers)

    def __lt__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) < self._release(other)

    def __le__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) <= self._release(other)

    def __gt__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) > self._release(other)

    def __ge__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) >= self._release(other)


# Installed FastAPI version. Still a ``str``, but comparisons such as
# ``_FASTAPI_VERSION >= "0.109.1"`` are evaluated numerically.
_FASTAPI_VERSION = _VersionStr(metadata.version("fastapi"))
|
|
|
|
|
|
class PriorityAPIRouter(APIRouter):
    """A router with priority.

    The route with higher priority will be put in the front of the route list.
    Routes whose path is ``""`` or ``"/"`` are always sorted last.
    """

    def __init__(self, *args, **kwargs):
        """Init a PriorityAPIRouter."""
        super().__init__(*args, **kwargs)
        # Maps a route's full path (router prefix included) to its priority.
        # Routes sharing one path share one entry; the last registration wins.
        self.route_priority: Dict[str, int] = {}

    def add_api_route(
        self, path: str, endpoint: Callable, *, priority: int = 0, **kwargs: Any
    ):
        """Add a route with priority.

        Args:
            path (str): The path of the route.
            endpoint (Callable): The endpoint of the route.
            priority (int, optional): The priority of the route. Defaults to 0.
            **kwargs (Any): Other arguments.
        """
        super().add_api_route(path, endpoint, **kwargs)
        # The parent registers the route under ``self.prefix + path``, and the
        # sort key below looks priorities up by ``route.path``. Store the
        # priority under that same full path, otherwise it is silently ignored
        # for any router created with a prefix.
        self.route_priority[self.prefix + path] = priority
        # Sort the routes by priority.
        self.sort_routes_by_priority()

    def sort_routes_by_priority(self):
        """Sort the routes by priority."""

        def my_func(route):
            # Not every route type is guaranteed to expose ``path``; such
            # routes fall back to the default priority instead of raising.
            route_path = getattr(route, "path", None)
            if route_path in ["", "/"]:
                # Root routes go last so they cannot shadow specific routes.
                return -100
            return self.route_priority.get(route_path, 0)

        # ``list.sort`` is stable, so equal priorities keep registration order.
        self.routes.sort(key=my_func, reverse=True)
|
|
|
|
|
|
# Lifecycle flags. ``lifespan`` sets each one to True after running the
# corresponding handlers; ``register_event_handler`` refuses new handlers
# for a phase whose flag is already set.
_HAS_STARTUP: bool = False
_HAS_SHUTDOWN: bool = False

# Handlers queued by ``register_event_handler`` and run, in registration
# order, by ``lifespan``.
_GLOBAL_STARTUP_HANDLERS: List[Callable] = []
_GLOBAL_SHUTDOWN_HANDLERS: List[Callable] = []
|
|
|
|
|
|
def register_event_handler(app: FastAPI, event: str, handler: Callable):
    """Register a startup or shutdown handler.

    When ``_FASTAPI_VERSION >= "0.109.1"`` holds, the handler is queued in the
    module-level handler lists, which the module's ``lifespan`` context manager
    runs. Otherwise it is attached directly to ``app`` through
    ``app.add_event_handler``.

    Args:
        app (FastAPI): The FastAPI app.
        event (str): The event type, either ``"startup"`` or ``"shutdown"``.
        handler (Callable): The handler function.

    Raises:
        ValueError: If ``event`` is neither ``"startup"`` nor ``"shutdown"``,
            or (queued mode only) if that phase has already run.
    """
    # https://fastapi.tiangolo.com/release-notes/#01091
    queue_for_lifespan = _FASTAPI_VERSION >= "0.109.1"

    if event == "startup":
        if not queue_for_lifespan:
            app.add_event_handler("startup", handler)
            return
        if _HAS_STARTUP:
            raise ValueError(
                "FastAPI app already started. Cannot add startup handler."
            )
        _GLOBAL_STARTUP_HANDLERS.append(handler)
    elif event == "shutdown":
        if not queue_for_lifespan:
            app.add_event_handler("shutdown", handler)
            return
        if _HAS_SHUTDOWN:
            raise ValueError(
                "FastAPI app already shutdown. Cannot add shutdown handler."
            )
        _GLOBAL_SHUTDOWN_HANDLERS.append(handler)
    else:
        raise ValueError(f"Invalid event: {event}")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the globally registered startup and shutdown handlers.

    Every handler in ``_GLOBAL_STARTUP_HANDLERS`` is called before control is
    yielded, and every handler in ``_GLOBAL_SHUTDOWN_HANDLERS`` after control
    returns. Handlers may be coroutine functions or plain functions; a result
    is awaited only when it is awaitable. ``_HAS_STARTUP`` and
    ``_HAS_SHUTDOWN`` are set to True once the respective phase has finished.

    Args:
        app (FastAPI): The FastAPI app (not used by the handlers themselves,
            which are called without arguments).
    """
    # Function-scope import so this block does not rely on a file-level import.
    import inspect

    global _HAS_STARTUP, _HAS_SHUTDOWN

    async def _run_all(handlers):
        # Call each handler; only await results that are actually awaitable,
        # so plain synchronous callables do not raise ``TypeError``.
        for handler in handlers:
            outcome = handler()
            if inspect.isawaitable(outcome):
                await outcome

    # Trigger the startup event.
    await _run_all(_GLOBAL_STARTUP_HANDLERS)
    _HAS_STARTUP = True
    yield
    # Trigger the shutdown event.
    await _run_all(_GLOBAL_SHUTDOWN_HANDLERS)
    _HAS_SHUTDOWN = True
|
|
|
|
|
|
def create_app(*args, **kwargs) -> FastAPI:
    """Create a FastAPI app.

    When ``_FASTAPI_VERSION >= "0.109.1"`` holds, the module's ``lifespan``
    context manager is installed unless the caller supplied a ``lifespan``
    keyword argument. A truthy lifespan is additionally stored on the app as
    the ``__dbgpt_custom_lifespan`` attribute, which ``replace_router`` reads.

    Args:
        *args: Positional arguments forwarded to ``FastAPI``.
        **kwargs: Keyword arguments forwarded to ``FastAPI``.

    Returns:
        FastAPI: The newly created app.
    """
    chosen_lifespan = None
    if _FASTAPI_VERSION >= "0.109.1":
        # Keep a caller-supplied lifespan; otherwise fall back to ours.
        chosen_lifespan = kwargs.setdefault("lifespan", lifespan)

    application = FastAPI(*args, **kwargs)
    if chosen_lifespan:
        setattr(application, "__dbgpt_custom_lifespan", chosen_lifespan)
    return application
|
|
|
|
|
|
def replace_router(app: FastAPI, router: Optional[APIRouter] = None):
    """Replace the router of the FastAPI app.

    Args:
        app (FastAPI): The app whose router is swapped.
        router (Optional[APIRouter]): The replacement router. A new
            ``PriorityAPIRouter`` is created when none is given.

    Returns:
        FastAPI: The same ``app`` instance, after ``app.setup()`` was called.
    """
    new_router = router if router else PriorityAPIRouter()

    if _FASTAPI_VERSION >= "0.109.1":
        # Carry over the lifespan recorded on the app (see ``create_app``),
        # if any, so the new router keeps using it.
        missing = object()
        saved_lifespan = getattr(app, "__dbgpt_custom_lifespan", missing)
        if saved_lifespan is not missing:
            new_router.lifespan_context = saved_lifespan

    app.router = new_router
    # NOTE(review): presumably re-registers the app's built-in routes on the
    # new router - confirm against the FastAPI implementation.
    app.setup()
    return app
|