mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-06 03:20:41 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
438  dbgpt/model/cluster/apiserver/api.py  Normal file
@@ -0,0 +1,438 @@
"""A server that provides OpenAI-compatible RESTful APIs. It supports:

- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)

Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
from typing import Optional, List, Dict, Any, Generator

import logging
import asyncio
import shortuuid
import json
from fastapi import APIRouter, FastAPI
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from pydantic import BaseSettings

from fastchat.protocol.openai_api_protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    DeltaMessage,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)
from fastchat.protocol.api_protocol import (
    APIChatCompletionRequest,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.constants import ErrorCode

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.model.base import ModelInstance
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
from dbgpt.model.cluster import ModelRegistry
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from dbgpt.util.utils import setup_logging

logger = logging.getLogger(__name__)

# Set by initialize_apiserver(); declared at module level so that
# get_api_server() references a defined name before initialization.
global_system_app: Optional[SystemApp] = None


class APIServerException(Exception):
    def __init__(self, code: int, message: str):
        self.code = code
        self.message = message


class APISettings(BaseSettings):
    api_keys: Optional[List[str]] = None


api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> Optional[str]:
    if api_settings.api_keys:
        if auth is None or (token := auth.credentials) not in api_settings.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all
        return None
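
# --- Hedged example (not part of the original file) ---
# A minimal sketch of how a client authenticates against this server,
# assuming it listens on http://127.0.0.1:8100 (host and port are
# hypothetical) and was started with api_keys=["sk-test-1"]. Uses the
# `requests` library; the route path follows the "/api" router prefix below.
def _example_authenticated_request():
    import requests

    resp = requests.get(
        "http://127.0.0.1:8100/api/v1/models",
        headers={"Authorization": "Bearer sk-test-1"},
    )
    # A missing or unknown key yields HTTP 401 with code "invalid_api_key".
    resp.raise_for_status()
    return resp.json()
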

class APIServer(BaseComponent):
    name = ComponentType.MODEL_API_SERVER

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def get_worker_manager(self) -> WorkerManager:
        """Get the worker manager component instance

        Raises:
            APIServerException: If the worker manager component instance can't be obtained
        """
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
            )
        return worker_manager

    def get_model_registry(self) -> ModelRegistry:
        """Get the model registry component instance

        Raises:
            APIServerException: If the model registry component instance can't be obtained
        """
        controller = self.system_app.get_component(
            ComponentType.MODEL_REGISTRY, ModelRegistry
        )
        if not controller:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
            )
        return controller

    async def get_model_instances_or_raise(
        self, model_name: str
    ) -> List[ModelInstance]:
        """Get healthy model instances for the requested model name

        Args:
            model_name (str): Model name

        Raises:
            APIServerException: If no healthy instance serves the requested model name
        """
        registry = self.get_model_registry()
        registry_model_name = f"{model_name}@llm"
        model_instances = await registry.get_all_instances(
            registry_model_name, healthy_only=True
        )
        if not model_instances:
            all_instances = await registry.get_all_model_instances(healthy_only=True)
            models = [
                ins.model_name.split("@llm")[0]
                for ins in all_instances
                if ins.model_name.endswith("@llm")
            ]
            if models:
                models = "&&".join(models)
                message = f"Only {models} allowed now, your model {model_name}"
            else:
                message = f"No models allowed now, your model {model_name}"
            raise APIServerException(ErrorCode.INVALID_MODEL, message)
        return model_instances

    async def get_available_models(self) -> ModelList:
        """Return the available models

        Only LLM and embedding models are included.

        Returns:
            ModelList: The list of models.
        """
        registry = self.get_model_registry()
        model_instances = await registry.get_all_model_instances(healthy_only=True)
        model_name_set = set()
        for inst in model_instances:
            name, worker_type = WorkerType.parse_worker_key(inst.model_name)
            if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
                model_name_set.add(name)
        models = list(model_name_set)
        models.sort()
        # TODO: return real model permission details
        model_cards = []
        for m in models:
            model_cards.append(
                ModelCard(
                    id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
                )
            )
        return ModelList(data=model_cards)

    async def chat_completion_stream_generator(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> Generator[str, Any, None]:
        """Chat stream completion generator

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager = self.get_worker_manager()
        id = f"chatcmpl-{shortuuid.random()}"
        finish_stream_events = []
        for i in range(n):
            # First chunk with role
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(role="assistant"),
                finish_reason=None,
            )
            chunk = ChatCompletionStreamResponse(
                id=id, choices=[choice_data], model=model_name
            )
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

            previous_text = ""
            async for model_output in worker_manager.generate_stream(params):
                model_output: ModelOutput = model_output
                if model_output.error_code != 0:
                    yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = model_output.text.replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )

                if len(delta_text) == 0:
                    delta_text = None
                choice_data = ChatCompletionResponseStreamChoice(
                    index=i,
                    delta=DeltaMessage(content=delta_text),
                    finish_reason=model_output.finish_reason,
                )
                chunk = ChatCompletionStreamResponse(
                    id=id, choices=[choice_data], model=model_name
                )
                if delta_text is None:
                    if model_output.finish_reason is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # There is no "content" field in the last delta message, so use
        # exclude_none to drop the "content" field.
        for finish_chunk in finish_stream_events:
            yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    async def chat_completion_generate(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> ChatCompletionResponse:
        """Generate completion(s)

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager: WorkerManager = self.get_worker_manager()
        choices = []
        chat_completions = []
        for i in range(n):
            model_output = asyncio.create_task(worker_manager.generate(params))
            chat_completions.append(model_output)
        try:
            all_tasks = await asyncio.gather(*chat_completions)
        except Exception as e:
            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
        usage = UsageInfo()
        for i, model_output in enumerate(all_tasks):
            model_output: ModelOutput = model_output
            if model_output.error_code != 0:
                return create_error_response(model_output.error_code, model_output.text)
            choices.append(
                ChatCompletionResponseChoice(
                    index=i,
                    message=ChatMessage(role="assistant", content=model_output.text),
                    finish_reason=model_output.finish_reason or "stop",
                )
            )
            if model_output.usage:
                task_usage = UsageInfo.parse_obj(model_output.usage)
                for usage_key, usage_value in task_usage.dict().items():
                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
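
# --- Hedged example (not part of the original file) ---
# An abridged sketch of what a client sees on the wire from
# chat_completion_stream_generator (values are made up; exact fields follow
# fastchat's ChatCompletionStreamResponse schema):
#
#   data: {"id": "chatcmpl-abc123", "model": "vicuna-13b", "choices": [{"index": 0, "delta": {"role": "assistant"}, ...}]}
#
#   data: {"id": "chatcmpl-abc123", "model": "vicuna-13b", "choices": [{"index": 0, "delta": {"content": "Hello"}, ...}]}
#
#   data: {"id": "chatcmpl-abc123", "model": "vicuna-13b", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
#
#   data: [DONE]
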

def get_api_server() -> APIServer:
    api_server = global_system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, default_component=None
    )
    if not api_server:
        global_system_app.register(APIServer)
    return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)


router = APIRouter()


@router.get("/v1/models", dependencies=[Depends(check_api_key)])
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
    return await api_server.get_available_models()
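
# --- Hedged example (not part of the original file) ---
# Listing models with the legacy openai<1.0 Python client; the base URL and
# port are hypothetical and depend on how the server was launched.
def _example_list_models():
    import openai

    openai.api_key = "sk-test-1"  # any key configured via api_keys
    openai.api_base = "http://127.0.0.1:8100/api/v1"
    return openai.Model.list()
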

@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model)
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    params = {
        "model": request.model,
        "messages": ModelMessage.to_dict_list(
            ModelMessage.from_openai_messages(request.messages)
        ),
        "echo": False,
    }
    if request.temperature:
        params["temperature"] = request.temperature
    if request.top_p:
        params["top_p"] = request.top_p
    if request.max_tokens:
        params["max_new_tokens"] = request.max_tokens
    if request.stop:
        params["stop"] = request.stop
    if request.user:
        params["user"] = request.user

    # TODO check token length
    if request.stream:
        generator = api_server.chat_completion_stream_generator(
            request.model, params, request.n
        )
        return StreamingResponse(generator, media_type="text/event-stream")
    return await api_server.chat_completion_generate(request.model, params, request.n)
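
# --- Hedged example (not part of the original file) ---
# A minimal chat completion call with the legacy openai<1.0 client, in both
# non-streaming and streaming form. Model name, host, and port are
# hypothetical; use a model reported by /api/v1/models.
def _example_chat_completion(stream: bool = False):
    import openai

    openai.api_key = "sk-test-1"
    openai.api_base = "http://127.0.0.1:8100/api/v1"
    resp = openai.ChatCompletion.create(
        model="vicuna-13b-v1.5",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=stream,
    )
    if not stream:
        return resp.choices[0].message.content
    # Streaming: concatenate the delta chunks emitted by the server.
    return "".join(chunk.choices[0].delta.get("content", "") for chunk in resp)
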

def _initialize_all(controller_addr: str, system_app: SystemApp):
    from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient
    from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory

    if not system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    ):
        # Register the model registry if it does not exist
        registry = ModelRegistryClient(controller_addr)
        registry.name = ComponentType.MODEL_REGISTRY.value
        system_app.register_instance(registry)

    registry = system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    )
    worker_manager = RemoteWorkerManager(registry)

    # Register the worker manager component if it does not exist
    system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY,
        WorkerManagerFactory,
        or_register_component=_DefaultWorkerManagerFactory,
        worker_manager=worker_manager,
    )
    # Register the API server component if it does not exist
    system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
    )


def initialize_apiserver(
    controller_addr: str,
    app=None,
    system_app: SystemApp = None,
    host: str = None,
    port: int = None,
    api_keys: List[str] = None,
):
    global global_system_app
    global api_settings
    embedded_mod = True
    if not app:
        embedded_mod = False
        app = FastAPI()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
            allow_headers=["*"],
        )

    if not system_app:
        system_app = SystemApp(app)
    global_system_app = system_app

    if api_keys:
        api_settings.api_keys = api_keys

    app.include_router(router, prefix="/api", tags=["APIServer"])

    @app.exception_handler(APIServerException)
    async def validation_apiserver_exception_handler(request, exc: APIServerException):
        return create_error_response(exc.code, exc.message)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

    _initialize_all(controller_addr, system_app)

    if not embedded_mod:
        import uvicorn

        uvicorn.run(app, host=host, port=port, log_level="info")
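
# --- Hedged example (not part of the original file) ---
# Embedding the API server into an existing FastAPI app instead of letting it
# start its own uvicorn server; the controller address is an assumption.
def _example_embedded_setup():
    from fastapi import FastAPI

    existing_app = FastAPI()
    initialize_apiserver(
        "http://127.0.0.1:8000",  # model controller address (assumed)
        app=existing_app,
        api_keys=["sk-test-1"],
    )
    # With app passed in, embedded_mod stays True and uvicorn is not started.
    return existing_app
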

def run_apiserver():
    parser = EnvArgumentParser()
    env_prefix = "apiserver_"
    apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
        ModelAPIServerParameters,
        env_prefixes=[env_prefix],
    )
    setup_logging(
        "dbgpt",
        logging_level=apiserver_params.log_level,
        logger_filename=apiserver_params.log_file,
    )
    api_keys = None
    if apiserver_params.api_keys:
        api_keys = apiserver_params.api_keys.strip().split(",")

    initialize_apiserver(
        apiserver_params.controller_addr,
        host=apiserver_params.host,
        port=apiserver_params.port,
        api_keys=api_keys,
    )
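
# --- Hedged example (not part of the original file) ---
# run_apiserver() reads its parameters from the environment using the
# "apiserver_" prefix, so a launch might look like the following shell
# sketch (the exact variable names depend on the ModelAPIServerParameters
# fields and are assumptions here):
#
#   apiserver_controller_addr=http://127.0.0.1:8000 \
#   apiserver_api_keys=sk-key1,sk-key2 \
#   python -m dbgpt.model.cluster.apiserver.api
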

if __name__ == "__main__":
    run_apiserver()