# Mirror of https://github.com/csunny/DB-GPT.git
# Synced 2025-10-22 09:28:42 +00:00
# 102 lines, 3.2 KiB, Python
import logging
|
|
import sys
|
|
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
|
|
|
from fastapi import HTTPException, Request
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from dbgpt._private.pydantic import BaseModel, Field
|
|
|
|
# ``ExceptionGroup`` is a builtin starting with Python 3.11. On older
# interpreters, try the ``exceptiongroup`` backport; if it is not installed,
# bind the name to ``None`` so code in this module can check whether
# exception-group support is available before using it in ``isinstance``.
if sys.version_info < (3, 11):
    try:
        from exceptiongroup import ExceptionGroup
    except ImportError:
        ExceptionGroup = None

# Only needed for the string annotation on ``add_exception_handler``; not
# imported at runtime.
if TYPE_CHECKING:
    from fastapi import FastAPI

logger = logging.getLogger(__name__)

# Type variable for the payload carried in ``Result.data``.
T = TypeVar("T")
|
|
|
|
|
|
class Result(BaseModel, Generic[T]):
    """Uniform response envelope.

    ``success`` tells the caller whether the operation worked. A failed
    result carries ``err_code`` and ``err_msg``; a successful one carries
    its payload in ``data``.
    """

    success: bool = Field(
        ..., description="Whether it is successful, True: success, False: failure"
    )
    err_code: str | None = Field(default=None, description="Error code")
    err_msg: str | None = Field(default=None, description="Error message")
    data: T | None = Field(default=None, description="Return data")

    @staticmethod
    def succ(data: T) -> "Result[T]":
        """Wrap ``data`` in a successful result.

        Args:
            data (T): Payload to hand back to the caller.

        Returns:
            Result[T]: A result with ``success=True`` and both error fields
            set to None.
        """
        return Result(data=data, success=True, err_msg=None, err_code=None)

    @staticmethod
    def failed(msg: str, err_code: Optional[str] = "E000X") -> "Result[Any]":
        """Build a result describing a failure.

        Args:
            msg (str): Error message.
            err_code (Optional[str], optional): Error code. Defaults to "E000X".

        Returns:
            Result[Any]: A result with ``success=False`` and ``data`` set to
            None.
        """
        return Result(err_msg=msg, err_code=err_code, success=False, data=None)
|
|
|
|
|
|
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Convert a request validation error into a failed ``Result`` response.

    Each entry of ``exc.errors()`` is rendered as ``"<loc>:<msg>;"``, where
    ``loc`` is the entry's ``loc`` items joined with ``"."``. All entries are
    concatenated into the result's ``err_msg``.

    Args:
        request (Request): The incoming request (unused).
        exc (RequestValidationError): The validation error being handled.

    Returns:
        JSONResponse: HTTP 400 response whose body is a failed ``Result``
        with error code ``E0001``.
    """
    parts = []
    for error in exc.errors():
        # Tolerate entries missing "loc" or "msg" instead of raising a
        # TypeError from inside the exception handler itself.
        loc = ".".join(map(str, error.get("loc") or ()))
        msg = error.get("msg") or ""
        parts.append(f"{loc}:{msg};")
    # Join once instead of growing a string with += in the loop.
    message = "".join(parts)
    res = Result.failed(msg=message, err_code="E0001")
    logger.error(
        "validation_exception_handler catch RequestValidationError: %s", res
    )
    return JSONResponse(status_code=400, content=res.dict())
|
|
|
|
|
|
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Convert an ``HTTPException`` into a failed ``Result`` response.

    Args:
        request (Request): The incoming request (unused).
        exc (HTTPException): The raised exception; its ``detail`` becomes the
            result's error message.

    Returns:
        JSONResponse: HTTP 400 response whose body is a failed ``Result``
        with error code ``E0002``.
    """
    detail = exc.detail
    # FastAPI allows ``detail`` to be a non-string value (e.g. a dict), but
    # ``Result.err_msg`` is typed ``str | None``. Passing such a value
    # straight through would fail model validation inside this handler, so
    # stringify it first.
    msg = detail if detail is None or isinstance(detail, str) else str(detail)
    res = Result.failed(msg=msg, err_code="E0002")
    logger.error("http_exception_handler catch HTTPException: %s", res)
    # NOTE(review): exc.status_code is not propagated; every HTTPException
    # is reported to the client as 400.
    return JSONResponse(status_code=400, content=res.dict())
|
|
|
|
|
|
def _exception_messages(exc: BaseException) -> list[str]:
    """Return the message(s) of ``exc``, recursively flattening exception groups."""
    if ExceptionGroup and isinstance(exc, ExceptionGroup):
        messages: list[str] = []
        for sub_exc in exc.exceptions:
            # Recurse so nested groups contribute their leaf messages instead
            # of a summary like "msg (2 sub-exceptions)".
            messages.extend(_exception_messages(sub_exc))
        return messages
    # Exceptions raised without arguments stringify to ""; report the type
    # name so the client still gets something meaningful.
    return [str(exc) or type(exc).__name__]


async def common_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler turning any exception into a failed ``Result``.

    For an exception group, the messages of all leaf exceptions (including
    those in nested groups) are joined with ``";"``. Otherwise the
    exception's own message is used.

    Args:
        request (Request): The incoming request (unused).
        exc (Exception): The unhandled exception.

    Returns:
        JSONResponse: HTTP 400 response whose body is a failed ``Result``
        with error code ``E0003``.
    """
    err_msg = ";".join(_exception_messages(exc))
    res = Result.failed(msg=err_msg, err_code="E0003")
    # Attach the traceback: this is the last place an unexpected error is
    # seen, so logging only its message would make it hard to diagnose.
    logger.error("common_exception_handler catch Exception: %s", res, exc_info=exc)
    return JSONResponse(status_code=400, content=res.dict())
|
|
|
|
|
|
def add_exception_handler(app: "FastAPI"):
    """Install this module's JSON error handlers on ``app``.

    Registers handlers for ``RequestValidationError``, ``HTTPException`` and
    ``Exception``, in that order.
    """
    registrations = (
        (RequestValidationError, validation_exception_handler),
        (HTTPException, http_exception_handler),
        (Exception, common_exception_handler),
    )
    for exc_class, handler in registrations:
        app.add_exception_handler(exc_class, handler)
|