mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-26 13:27:46 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
306 lines
7.6 KiB
Python
306 lines
7.6 KiB
Python
import logging
|
|
from functools import cache
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
|
from starlette.responses import StreamingResponse
|
|
|
|
from dbgpt.component import SystemApp
|
|
from dbgpt.serve.core import Result
|
|
from dbgpt.util import PaginationResult
|
|
|
|
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
|
from ..service.service import Service
|
|
from .schemas import (
|
|
PromptDebugInput,
|
|
PromptType,
|
|
PromptVerifyInput,
|
|
ServeRequest,
|
|
ServerResponse,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
# Add your API endpoints here
|
|
|
|
global_system_app: Optional[SystemApp] = None
|
|
|
|
|
|
def get_service() -> Service:
|
|
"""Get the service instance"""
|
|
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
|
|
|
|
|
get_bearer_token = HTTPBearer(auto_error=False)
|
|
|
|
|
|
@cache
|
|
def _parse_api_keys(api_keys: str) -> List[str]:
|
|
"""Parse the string api keys to a list
|
|
|
|
Args:
|
|
api_keys (str): The string api keys
|
|
|
|
Returns:
|
|
List[str]: The list of api keys
|
|
"""
|
|
if not api_keys:
|
|
return []
|
|
return [key.strip() for key in api_keys.split(",")]
|
|
|
|
|
|
async def check_api_key(
|
|
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
|
request: Request = None,
|
|
service: Service = Depends(get_service),
|
|
) -> Optional[str]:
|
|
"""Check the api key
|
|
|
|
If the api key is not set, allow all.
|
|
|
|
Your can pass the token in you request header like this:
|
|
|
|
.. code-block:: python
|
|
|
|
import requests
|
|
|
|
client_api_key = "your_api_key"
|
|
headers = {"Authorization": "Bearer " + client_api_key}
|
|
res = requests.get("http://test/hello", headers=headers)
|
|
assert res.status_code == 200
|
|
|
|
"""
|
|
if request.url.path.startswith(f"/api/v1"):
|
|
return None
|
|
|
|
|
|
@router.get("/health")
|
|
async def health():
|
|
"""Health check endpoint"""
|
|
return {"status": "ok"}
|
|
|
|
|
|
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
|
async def test_auth():
|
|
"""Test auth endpoint"""
|
|
return {"status": "ok"}
|
|
|
|
|
|
# TODO: Compatible with old API, will be modified in the future
|
|
@router.post(
|
|
"/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
|
)
|
|
async def create(
|
|
request: ServeRequest, service: Service = Depends(get_service)
|
|
) -> Result[ServerResponse]:
|
|
"""Create a new Prompt entity
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
service (Service): The service
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
return Result.succ(service.create(request))
|
|
|
|
|
|
@router.post(
|
|
"/update",
|
|
response_model=Result[ServerResponse],
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def update(
|
|
request: ServeRequest, service: Service = Depends(get_service)
|
|
) -> Result[ServerResponse]:
|
|
"""Update a Prompt entity
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
service (Service): The service
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
try:
|
|
data = service.update(request)
|
|
return Result.succ(data)
|
|
except Exception as e:
|
|
logger.exception("Update prompt failed!")
|
|
return Result.failed(msg=str(e))
|
|
|
|
|
|
@router.post(
|
|
"/delete", response_model=Result[None], dependencies=[Depends(check_api_key)]
|
|
)
|
|
async def delete(
|
|
request: ServeRequest, service: Service = Depends(get_service)
|
|
) -> Result[None]:
|
|
"""Delete a Prompt entity
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
service (Service): The service
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
return Result.succ(service.delete(request))
|
|
|
|
|
|
@router.post(
|
|
"/list",
|
|
response_model=Result[List[ServerResponse]],
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def query(
|
|
request: ServeRequest, service: Service = Depends(get_service)
|
|
) -> Result[List[ServerResponse]]:
|
|
"""Query Prompt entities
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
service (Service): The service
|
|
Returns:
|
|
List[ServerResponse]: The response
|
|
"""
|
|
return Result.succ(service.get_list(request))
|
|
|
|
|
|
@router.post(
|
|
"/query_page",
|
|
response_model=Result[PaginationResult[ServerResponse]],
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def query_page(
|
|
request: ServeRequest,
|
|
page: Optional[int] = Query(default=1, description="current page"),
|
|
page_size: Optional[int] = Query(default=20, description="page size"),
|
|
service: Service = Depends(get_service),
|
|
) -> Result[PaginationResult[ServerResponse]]:
|
|
"""Query Prompt entities
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
page (int): The page number
|
|
page_size (int): The page size
|
|
service (Service): The service
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
return Result.succ(service.get_list_by_page(request, page, page_size))
|
|
|
|
|
|
@router.get(
|
|
"/type/targets",
|
|
response_model=Result,
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def prompt_type_targets(
|
|
prompt_type: str = Query(
|
|
default=PromptType.NORMAL, description="Prompt template type"
|
|
),
|
|
service: Service = Depends(get_service),
|
|
) -> Result:
|
|
"""get Prompt type
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
return Result.succ(service.get_type_targets(prompt_type))
|
|
|
|
|
|
@router.post(
|
|
"/template/load",
|
|
response_model=Result,
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def load_template(
|
|
prompt_type: str = Query(
|
|
default=PromptType.NORMAL, description="Prompt template type"
|
|
),
|
|
target: Optional[str] = Query(
|
|
default=None, description="The target to load the template from"
|
|
),
|
|
service: Service = Depends(get_service),
|
|
) -> Result:
|
|
"""load Prompt from target
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
return Result.succ(service.load_template(prompt_type, target))
|
|
|
|
|
|
@router.post(
|
|
"/template/debug",
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def template_debug(
|
|
debug_input: PromptDebugInput,
|
|
service: Service = Depends(get_service),
|
|
):
|
|
"""test Prompt
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
headers = {
|
|
"Content-Type": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
"Connection": "keep-alive",
|
|
"Transfer-Encoding": "chunked",
|
|
}
|
|
try:
|
|
return StreamingResponse(
|
|
service.debug_prompt(
|
|
debug_input=debug_input,
|
|
),
|
|
headers=headers,
|
|
media_type="text/event-stream",
|
|
)
|
|
except Exception as e:
|
|
return Result.failed(msg=str(e))
|
|
|
|
|
|
@router.post(
|
|
"/response/verify",
|
|
response_model=Result[bool],
|
|
dependencies=[Depends(check_api_key)],
|
|
)
|
|
async def response_verify(
|
|
request: PromptVerifyInput,
|
|
service: Service = Depends(get_service),
|
|
) -> Result[bool]:
|
|
"""test Prompt
|
|
|
|
Args:
|
|
request (ServeRequest): The request
|
|
|
|
Returns:
|
|
ServerResponse: The response
|
|
"""
|
|
try:
|
|
return Result.succ(
|
|
service.verify_response(
|
|
request.llm_out, request.prompt_type, request.chat_scene
|
|
)
|
|
)
|
|
except Exception as e:
|
|
return Result.failed(msg=str(e))
|
|
|
|
|
|
def init_endpoints(system_app: SystemApp) -> None:
|
|
"""Initialize the endpoints"""
|
|
global global_system_app
|
|
system_app.register(Service)
|
|
global_system_app = system_app
|