import logging
from functools import cache
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from dbgpt.component import ComponentType, SystemApp
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from dbgpt.serve.core import Result
from dbgpt.serve.evaluate.api.schemas import EvaluateServeRequest
from dbgpt.serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME
from dbgpt.serve.evaluate.service.service import Service

from ...prompt.service.service import Service as PromptService

router = APIRouter()

# Add your API endpoints here

# Set by ``init_endpoints``; the ``get_*`` helpers below read it at call time.
global_system_app: Optional[SystemApp] = None
logger = logging.getLogger(__name__)


def get_service() -> Service:
    """Get the service instance"""
    return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)


def get_prompt_service() -> PromptService:
    """Get the prompt service instance from the global system app"""
    return global_system_app.get_component("dbgpt_serve_prompt_service", PromptService)


def get_worker_manager() -> WorkerManager:
    """Get a worker manager created by the registered worker manager factory"""
    worker_manager = global_system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    return worker_manager


def get_model_controller() -> BaseModelController:
    """Get the model controller component from the global system app"""
    controller = global_system_app.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )
    return controller


# auto_error=False: a missing/malformed Authorization header yields ``None``
# instead of an automatic error, so ``check_api_key`` decides what to do.
get_bearer_token = HTTPBearer(auto_error=False)


@cache
def _parse_api_keys(api_keys: str) -> List[str]:
    """Parse the string api keys to a list

    Blank entries (e.g. from a trailing or doubled comma) are dropped so that
    an empty string is never treated as a valid api key.

    Args:
        api_keys (str): The comma separated api keys

    Returns:
        List[str]: The list of non-empty, whitespace-stripped api keys
    """
    if not api_keys:
        return []
    return [key for part in api_keys.split(",") if (key := part.strip())]


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Check the api key

    If the api key is not set, allow all.

    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Args:
        auth (Optional[HTTPAuthorizationCredentials]): The bearer credentials
            taken from the request, or None when none were supplied
        service (Service): The service whose ``config.api_keys`` is checked

    Returns:
        Optional[str]: The accepted token, or None when no api keys are
            configured

    Raises:
        HTTPException: 401 when api keys are configured and the request token
            is missing or not one of them
    """
    if not service.config.api_keys:
        # api_keys not set; allow all
        return None

    api_keys = _parse_api_keys(service.config.api_keys)
    if auth is None or (token := auth.credentials) not in api_keys:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token


@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Health check endpoint"""
    return {"status": "ok"}


@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Test auth endpoint"""
    return {"status": "ok"}


@router.get("/scenes")
async def get_scenes():
    """List the evaluation scenes

    Returns:
        Result: A successful result wrapping the list of scene key/label
            mappings
    """
    scene_list = [{"recall": "召回评测"}, {"app": "应用评测"}]

    return Result.succ(scene_list)


@router.post("/evaluation")
async def evaluation(
    request: EvaluateServeRequest,
    service: Service = Depends(get_service),
) -> Result:
    """Evaluate results by the scene

    Args:
        request (EvaluateServeRequest): The request
        service (Service): The service

    Returns:
        Result: A successful result wrapping the output of
            ``service.run_evaluation``
    """
    return Result.succ(
        await service.run_evaluation(
            request.scene_key,
            request.scene_value,
            request.datasets,
            request.context,
            request.evaluate_metrics,
        )
    )


def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints

    Registers ``Service`` on the given system app and stores the app in the
    module-level ``global_system_app`` used by the dependency helpers.

    Args:
        system_app (SystemApp): The system app to register on and remember
    """
    global global_system_app
    system_app.register(Service)
    global_system_app = system_app