mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-11 13:12:18 +00:00
154 lines
4.1 KiB
Python
154 lines
4.1 KiB
Python
import logging
|
|
from functools import cache
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
from dbgpt.component import ComponentType, SystemApp
|
|
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
|
from dbgpt.serve.core import Result
|
|
from dbgpt.serve.evaluate.api.schemas import EvaluateServeRequest
|
|
from dbgpt.serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME
|
|
from dbgpt.serve.evaluate.service.service import Service
|
|
|
|
from ...prompt.service.service import Service as PromptService
|
|
|
|
# Router on which this module's endpoints are registered via the
# ``@router.get`` / ``@router.post`` decorators.
router = APIRouter()
|
|
|
|
# Add your API endpoints here
|
|
|
|
# Application container read by the ``get_*`` dependency helpers in this
# module; it stays ``None`` until ``init_endpoints`` assigns it.
global_system_app: Optional[SystemApp] = None
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_service() -> Service:
    """Look up the evaluate ``Service`` component on the global system app.

    Returns:
        Service: The component registered under
            ``SERVE_SERVICE_COMPONENT_NAME``.
    """
    # ``global_system_app`` is assigned by ``init_endpoints``; calling this
    # helper before then fails because the global is still ``None``.
    service = global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
    return service
|
|
|
|
|
|
def get_prompt_service() -> PromptService:
    """Look up the prompt service component on the global system app.

    Returns:
        PromptService: The component registered under the name
            ``"dbgpt_serve_prompt_service"``.
    """
    prompt_service = global_system_app.get_component(
        "dbgpt_serve_prompt_service", PromptService
    )
    return prompt_service
|
|
|
|
|
|
def get_worker_manager() -> WorkerManager:
    """Create a worker manager from the registered worker-manager factory.

    Returns:
        WorkerManager: The object returned by the factory's ``create()``.
    """
    factory = global_system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    )
    return factory.create()
|
|
|
|
|
|
def get_model_controller() -> BaseModelController:
    """Look up the model controller component on the global system app.

    Returns:
        BaseModelController: The component registered under
            ``ComponentType.MODEL_CONTROLLER``.
    """
    return global_system_app.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )
|
|
|
|
|
|
# Bearer-credential extractor used as a dependency by ``check_api_key``.
# NOTE(review): per the FastAPI docs, ``auto_error=False`` makes a missing or
# malformed Authorization header yield ``None`` instead of an automatic error
# response; ``check_api_key`` handles that ``None`` case itself.
get_bearer_token = HTTPBearer(auto_error=False)
|
|
|
|
|
|
@cache
|
|
def _parse_api_keys(api_keys: str) -> List[str]:
|
|
"""Parse the string api keys to a list
|
|
|
|
Args:
|
|
api_keys (str): The string api keys
|
|
|
|
Returns:
|
|
List[str]: The list of api keys
|
|
"""
|
|
if not api_keys:
|
|
return []
|
|
return [key.strip() for key in api_keys.split(",")]
|
|
|
|
|
|
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Check the api key carried by the request's bearer token.

    If the api key is not set on the service config, allow all requests.

    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Args:
        auth (Optional[HTTPAuthorizationCredentials]): The bearer credentials
            produced by ``get_bearer_token``; ``None`` when none were supplied.
        service (Service): The evaluate service whose ``config.api_keys``
            holds the comma-separated accepted keys.

    Returns:
        Optional[str]: The accepted token, or ``None`` when no api keys are
            configured.

    Raises:
        HTTPException: With status 401 and an ``invalid_api_key`` error body
            when api keys are configured and the token is missing or does not
            match any of them.
    """
    # Function-scope import so the block does not depend on a file-level one.
    import secrets

    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None

    token = auth.credentials if auth is not None else None
    if token is not None:
        presented = token.encode("utf-8")
        # Compare in constant time so response timing does not reveal how much
        # of a key matched. Bytes are used because ``compare_digest`` only
        # accepts ASCII-only ``str`` arguments.
        if any(
            secrets.compare_digest(presented, key.encode("utf-8"))
            for key in _parse_api_keys(configured_keys)
        ):
            return token

    raise HTTPException(
        status_code=401,
        detail={
            "error": {
                "message": "",
                "type": "invalid_request_error",
                "param": None,
                "code": "invalid_api_key",
            }
        },
    )
|
|
|
|
|
|
@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Health check endpoint.

    The route runs ``check_api_key`` as a dependency before responding.

    Returns:
        dict: ``{"status": "ok"}``.
    """
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Test auth endpoint.

    Responds only after the ``check_api_key`` dependency has passed, so it can
    be used to verify that a bearer token is accepted.

    Returns:
        dict: ``{"status": "ok"}``.
    """
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@router.get("/scenes")
async def get_scenes():
    """List the available evaluation scenes.

    Returns:
        The value of ``Result.succ`` applied to a list of single-entry dicts,
        each mapping a scene key to its display label.
    """
    # NOTE(review): this route declares no ``check_api_key`` dependency, unlike
    # ``/health`` and ``/test_auth`` in this module - confirm that is intended.
    return Result.succ([{"recall": "召回评测"}, {"app": "应用评测"}])
|
|
|
|
|
|
@router.post("/evaluation")
async def evaluation(
    request: EvaluateServeRequest,
    service: Service = Depends(get_service),
) -> Result:
    """Evaluate results by the scene.

    Args:
        request (EvaluateServeRequest): The request; its ``scene_key``,
            ``scene_value``, ``datasets``, ``context`` and
            ``evaluate_metrics`` fields are forwarded to the service.
        service (Service): The service that runs the evaluation.

    Returns:
        Result: The awaited ``service.run_evaluation`` output wrapped by
            ``Result.succ``.
    """
    # NOTE(review): this route declares no ``check_api_key`` dependency, unlike
    # ``/health`` and ``/test_auth`` in this module - confirm that is intended.
    outcome = await service.run_evaluation(
        request.scene_key,
        request.scene_value,
        request.datasets,
        request.context,
        request.evaluate_metrics,
    )
    return Result.succ(outcome)
|
|
|
|
|
|
def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints.

    Registers ``Service`` on the given system app and then stores that app in
    the module-level ``global_system_app`` used by the dependency helpers.

    Args:
        system_app (SystemApp): The application container to register with.
    """
    global global_system_app

    # Register first: if registration raises, the module global stays unset.
    system_app.register(Service)
    global_system_app = system_app
|