DB-GPT/dbgpt/util/fastapi.py
2024-04-20 09:41:16 +08:00

131 lines
3.8 KiB
Python

"""FastAPI utilities."""
import importlib.metadata as metadata
import inspect
import re
from contextlib import asynccontextmanager
from typing import Any, Callable, Dict, List, Optional, Tuple

from fastapi import FastAPI
from fastapi.routing import APIRouter
class _FastAPIVersion(str):
    """A version ``str`` whose ordering operators compare numerically.

    ``importlib.metadata.version`` returns a plain ``str``, and this module
    gates behaviour on comparisons such as ``_FASTAPI_VERSION >= "0.109.1"``.
    Plain string comparison is lexicographic, so ``"0.99.0" >= "0.109.1"``
    is ``True`` (``"9" > "1"``) and an old FastAPI would take the new code
    path. This subclass keeps every other ``str`` behaviour (equality,
    hashing, formatting, methods) but makes ``<``, ``<=``, ``>`` and ``>=``
    compare the numeric release components of both operands.
    """

    @staticmethod
    def _release(version: str) -> Tuple[int, ...]:
        """Return the numeric release tuple of ``version``.

        Each dot-separated segment contributes its leading decimal digits
        (``"0rc1"`` -> ``0``); parsing stops at the first segment with no
        leading digit. Trailing zeros are dropped so that ``"1.0"`` and
        ``"1.0.0"`` order as equal.
        """
        release: List[int] = []
        for segment in version.split("."):
            leading_digits = re.match(r"[0-9]+", segment)
            if leading_digits is None:
                break
            release.append(int(leading_digits.group()))
        while release and release[-1] == 0:
            release.pop()
        return tuple(release)

    def __lt__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) < self._release(other)

    def __le__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) <= self._release(other)

    def __gt__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) > self._release(other)

    def __ge__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return self._release(self) >= self._release(other)


# Installed FastAPI version; compares correctly against literals like "0.109.1".
_FASTAPI_VERSION = _FastAPIVersion(metadata.version("fastapi"))
class PriorityAPIRouter(APIRouter):
    """An ``APIRouter`` that keeps its route list ordered by priority.

    Routes added through :meth:`add_api_route` may carry an integer
    ``priority``. After every addition the whole route list is re-sorted
    so that routes with a larger priority appear earlier in the list.
    """

    def __init__(self, *args, **kwargs):
        """Create the router with an empty path -> priority table."""
        super().__init__(*args, **kwargs)
        # Priority registered for each route path (last registration wins).
        self.route_priority: Dict[str, int] = {}

    def add_api_route(
        self, path: str, endpoint: Callable, *, priority: int = 0, **kwargs: Any
    ):
        """Register an API route, record its priority and re-sort the routes.

        Args:
            path (str): URL path of the route.
            endpoint (Callable): Function that handles requests to ``path``.
            priority (int, optional): Ordering weight; larger values are
                placed earlier in the route list. Defaults to 0.
            **kwargs (Any): Forwarded unchanged to
                ``APIRouter.add_api_route``.
        """
        super().add_api_route(path, endpoint, **kwargs)
        self.route_priority[path] = priority
        self.sort_routes_by_priority()

    def sort_routes_by_priority(self):
        """Reorder ``self.routes`` in place, highest priority first.

        Routes without a recorded priority count as 0. Routes whose path is
        ``""`` or ``"/"`` always get the fixed key -100, regardless of any
        recorded priority. The sort is stable, so routes with equal keys keep
        their relative order.
        """
        priorities = self.route_priority

        def _sort_key(route) -> int:
            # Presumably keeps a catch-all root route/mount from shadowing
            # more specific paths -- NOTE(review): confirm intent.
            if route.path in ("", "/"):
                return -100
            return priorities.get(route.path, 0)

        self.routes.sort(key=_sort_key, reverse=True)
# Flipped to True by ``lifespan`` once the startup / shutdown handlers have
# run; ``register_event_handler`` refuses new handlers for a finished phase.
_HAS_STARTUP: bool = False
_HAS_SHUTDOWN: bool = False

# Handlers queued by ``register_event_handler`` on FastAPI >= 0.109.1 and
# run in registration order by ``lifespan``. Two distinct list objects.
_GLOBAL_STARTUP_HANDLERS: List[Callable] = list()
_GLOBAL_SHUTDOWN_HANDLERS: List[Callable] = list()
def register_event_handler(app: FastAPI, event: str, handler: Callable):
    """Attach ``handler`` to the ``"startup"`` or ``"shutdown"`` event.

    On FastAPI older than 0.109.1 the handler is registered on ``app`` via
    ``app.add_event_handler``. On 0.109.1 and newer it is appended to a
    module-level list that this module's ``lifespan`` runs; ``app`` is not
    used on that path.

    Args:
        app (FastAPI): The application (only used on FastAPI < 0.109.1).
        event (str): Either ``"startup"`` or ``"shutdown"``.
        handler (Callable): The callback to run for ``event``.

    Raises:
        ValueError: If ``event`` is neither ``"startup"`` nor ``"shutdown"``,
            or (FastAPI >= 0.109.1) if that phase has already run.
    """
    if event not in ("startup", "shutdown"):
        raise ValueError(f"Invalid event: {event}")

    if _FASTAPI_VERSION < "0.109.1":
        app.add_event_handler(event, handler)
        return

    # https://fastapi.tiangolo.com/release-notes/#01091
    if event == "startup":
        if _HAS_STARTUP:
            raise ValueError(
                "FastAPI app already started. Cannot add startup handler."
            )
        _GLOBAL_STARTUP_HANDLERS.append(handler)
    else:
        if _HAS_SHUTDOWN:
            raise ValueError(
                "FastAPI app already shutdown. Cannot add shutdown handler."
            )
        _GLOBAL_SHUTDOWN_HANDLERS.append(handler)
async def _run_event_handler(handler: Callable) -> None:
    """Call ``handler`` and await its result only if it is awaitable.

    Lets both plain functions and coroutine functions be used as event
    handlers, matching what ``app.add_event_handler`` accepts.
    """
    result = handler()
    if inspect.isawaitable(result):
        await result


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the globally registered startup and shutdown handlers.

    ``create_app`` installs this as the app's ``lifespan`` on FastAPI
    0.109.1+, where ``register_event_handler`` collects handlers in
    module-level lists instead of calling ``app.add_event_handler``.
    Handlers run in registration order; each may be sync or async.

    Args:
        app (FastAPI): The application being started (not used directly).
    """
    global _HAS_STARTUP, _HAS_SHUTDOWN
    # Trigger the startup event.
    for handler in _GLOBAL_STARTUP_HANDLERS:
        await _run_event_handler(handler)
    _HAS_STARTUP = True
    yield
    # Trigger the shutdown event.
    for handler in _GLOBAL_SHUTDOWN_HANDLERS:
        await _run_event_handler(handler)
    _HAS_SHUTDOWN = True
def create_app(*args, **kwargs) -> FastAPI:
    """Build a ``FastAPI`` application wired to this module's lifespan.

    On FastAPI 0.109.1+ the ``lifespan`` keyword defaults to this module's
    ``lifespan``; a caller-supplied value is kept as-is. When the chosen
    lifespan is truthy it is also stored on the app under the attribute
    name ``"__dbgpt_custom_lifespan"``, which ``replace_router`` reads back.

    Args:
        *args: Positional arguments forwarded to ``FastAPI``.
        **kwargs: Keyword arguments forwarded to ``FastAPI``.

    Returns:
        FastAPI: The new application.
    """
    chosen_lifespan = None
    if _FASTAPI_VERSION >= "0.109.1":
        chosen_lifespan = kwargs.setdefault("lifespan", lifespan)
    app = FastAPI(*args, **kwargs)
    if chosen_lifespan:
        # Literal attribute name (no name mangling outside a class body).
        setattr(app, "__dbgpt_custom_lifespan", chosen_lifespan)
    return app
def replace_router(app: FastAPI, router: Optional[APIRouter] = None):
    """Swap ``app.router`` for ``router`` and call ``app.setup()`` again.

    A falsy ``router`` (e.g. ``None``) is replaced by a new
    ``PriorityAPIRouter``. On FastAPI 0.109.1+, a lifespan stored on the app
    under ``"__dbgpt_custom_lifespan"`` (see ``create_app``) is assigned to
    the new router's ``lifespan_context``. Nothing else is copied from the
    previous router. ``app.setup()`` presumably re-registers FastAPI's
    built-in routes on the new router -- NOTE(review): confirm.

    Returns:
        FastAPI: ``app``, mutated in place.
    """
    new_router = router or PriorityAPIRouter()
    if _FASTAPI_VERSION >= "0.109.1":
        missing = object()
        stored_lifespan = getattr(app, "__dbgpt_custom_lifespan", missing)
        if stored_lifespan is not missing:
            new_router.lifespan_context = stored_lifespan
    app.router = new_router
    app.setup()
    return app