"""A server that provides OpenAI-compatible RESTful APIs. It supports:
|
|
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
|
|
|
|
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
|
|
"""
|
|
|
|
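
# Example client usage (illustrative; host, port and model name are assumptions,
# and the "/api" prefix comes from app.include_router(...) in initialize_apiserver()):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://127.0.0.1:8100/api/v1", api_key="EMPTY")
#   resp = client.chat.completions.create(
#       model="vicuna-13b-v1.5",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)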

import asyncio
import json
import logging
from typing import Any, Dict, Generator, List, Optional

import shortuuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from dbgpt._private.pydantic import BaseModel, model_to_dict, model_to_json
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.core.schema.api import (
    APIChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ErrorCode,
    ErrorResponse,
    ModelCard,
    ModelList,
    ModelPermission,
    RelevanceRequest,
    RelevanceResponse,
    UsageInfo,
)
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
from dbgpt.util.fastapi import create_app
from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.util.tracer import initialize_tracer, root_tracer
from dbgpt.util.utils import setup_logging

logger = logging.getLogger(__name__)
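
# Set by initialize_apiserver(); read by get_api_server() below.
global_system_app: Optional[SystemApp] = None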


class APIServerException(Exception):
    def __init__(self, code: int, message: str):
        self.code = code
        self.message = message


class APISettings(BaseModel):
    api_keys: Optional[List[str]] = None
    embedding_batch_size: int = 4
    ignore_stop_exceeds_error: bool = False


api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> Optional[str]:
    if api_settings.api_keys:
        if auth is None or (token := auth.credentials) not in api_settings.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all
        return None


def create_error_response(code: int, message: str) -> JSONResponse:
    """Copy from fastchat.serve.openai_api_server.create_error_response

    We can't use fastchat.serve.openai_api_server because it has too many dependencies.
    """
    return JSONResponse(
        model_to_dict(ErrorResponse(message=message, code=code)), status_code=400
    )


def check_requests(request) -> Optional[JSONResponse]:
    """Copy from fastchat.serve.openai_api_server.check_requests

    We can't use fastchat.serve.openai_api_server because it has too many dependencies.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of range. Either set top_k to -1 or >= 1.",
        )
    if request.stop is not None and (
        not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )
    if request.stop and isinstance(request.stop, list) and len(request.stop) > 4:
        # https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
        if not api_settings.ignore_stop_exceeds_error:
            return create_error_response(
                ErrorCode.PARAM_OUT_OF_RANGE,
                f"Invalid 'stop': array too long. Expected an array with maximum "
                f"length 4, but got an array with length {len(request.stop)} instead.",
            )
        else:
            request.stop = request.stop[:4]

    return None


class APIServer(BaseComponent):
    name = ComponentType.MODEL_API_SERVER

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def get_worker_manager(self) -> WorkerManager:
        """Get the worker manager component instance

        Raises:
            APIServerException: If the worker manager component cannot be retrieved
        """
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
            )
        return worker_manager

    def get_model_registry(self) -> ModelRegistry:
        """Get the model registry component instance

        Raises:
            APIServerException: If the model registry component cannot be retrieved
        """

        controller = self.system_app.get_component(
            ComponentType.MODEL_REGISTRY, ModelRegistry
        )
        if not controller:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
            )
        return controller

    async def get_model_instances_or_raise(
        self, model_name: str, worker_type: str = "llm"
    ) -> List[ModelInstance]:
        """Get healthy model instances for the requested model name

        Args:
            model_name (str): Model name
            worker_type (str, optional): Worker type. Defaults to "llm".

        Raises:
            APIServerException: If no healthy model instance matches the requested
                model name
        """
        registry = self.get_model_registry()
        suffix = f"@{worker_type}"
        registry_model_name = f"{model_name}{suffix}"
        model_instances = await registry.get_all_instances(
            registry_model_name, healthy_only=True
        )
        if not model_instances:
            all_instances = await registry.get_all_model_instances(healthy_only=True)
            models = [
                ins.model_name.split(suffix)[0]
                for ins in all_instances
                if ins.model_name.endswith(suffix)
            ]
            if models:
                models = "&&".join(models)
                message = f"Only {models} allowed now, your model {model_name}"
            else:
                message = f"No models allowed now, your model {model_name}"
            raise APIServerException(ErrorCode.INVALID_MODEL, message)
        return model_instances

    async def get_available_models(self) -> ModelList:
        """Return the available models

        Only LLM and embedding (text2vec) models are included.

        Returns:
            ModelList: The list of models.
        """
        registry = self.get_model_registry()
        model_instances = await registry.get_all_model_instances(healthy_only=True)
        model_name_set = set()
        for inst in model_instances:
            name, worker_type = WorkerType.parse_worker_key(inst.model_name)
            if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
                model_name_set.add(name)
        models = list(model_name_set)
        models.sort()
        # TODO: return real model permission details
        model_cards = []
        for m in models:
            model_cards.append(
                ModelCard(
                    id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
                )
            )
        return ModelList(data=model_cards)

    async def chat_completion_stream_generator(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> Generator[str, Any, None]:
        """Chat completion stream generator

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager = self.get_worker_manager()
        id = f"chatcmpl-{shortuuid.random()}"
        finish_stream_events = []
        curr_usage = UsageInfo()
        last_usage = UsageInfo()
        for i in range(n):
            last_usage.prompt_tokens += curr_usage.prompt_tokens
            last_usage.completion_tokens += curr_usage.completion_tokens
            last_usage.total_tokens += curr_usage.total_tokens

            # First chunk with role
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(role="assistant"),
                finish_reason=None,
            )
            chunk = ChatCompletionStreamResponse(
                id=id,
                choices=[choice_data],
                model=model_name,
                usage=last_usage,
            )
            json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
            yield f"data: {json_data}\n\n"

            previous_text = ""
            async for model_output in worker_manager.generate_stream(params):
                model_output: ModelOutput = model_output
                if model_output.error_code != 0:
                    yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = model_output.text.replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )

                if len(delta_text) == 0:
                    delta_text = None
                choice_data = ChatCompletionResponseStreamChoice(
                    index=i,
                    delta=DeltaMessage(content=delta_text),
                    finish_reason=model_output.finish_reason,
                )
                has_usage = False
                if model_output.usage:
                    curr_usage = UsageInfo.model_validate(model_output.usage)
                    has_usage = True
                    usage = UsageInfo(
                        prompt_tokens=last_usage.prompt_tokens
                        + curr_usage.prompt_tokens,
                        total_tokens=last_usage.total_tokens + curr_usage.total_tokens,
                        completion_tokens=last_usage.completion_tokens
                        + curr_usage.completion_tokens,
                    )
                else:
                    has_usage = False
                    usage = UsageInfo()
                chunk = ChatCompletionStreamResponse(
                    id=id, choices=[choice_data], model=model_name, usage=usage
                )
                if delta_text is None:
                    if model_output.finish_reason is not None:
                        finish_stream_events.append(chunk)
                    if not has_usage:
                        continue
                json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
                yield f"data: {json_data}\n\n"

        # The last delta message has no "content" field; exclude_unset drops it.
        for finish_chunk in finish_stream_events:
            json_data = model_to_json(
                finish_chunk, exclude_unset=True, ensure_ascii=False
            )
            yield f"data: {json_data}\n\n"
        yield "data: [DONE]\n\n"

    async def chat_completion_generate(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> ChatCompletionResponse:
        """Generate a chat completion (non-streaming)

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager: WorkerManager = self.get_worker_manager()
        choices = []
        chat_completions = []
        for i in range(n):
            model_output = asyncio.create_task(worker_manager.generate(params))
            chat_completions.append(model_output)
        try:
            all_tasks = await asyncio.gather(*chat_completions)
        except Exception as e:
            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
        usage = UsageInfo()
        for i, model_output in enumerate(all_tasks):
            model_output: ModelOutput = model_output
            if model_output.error_code != 0:
                return create_error_response(model_output.error_code, model_output.text)
            choices.append(
                ChatCompletionResponseChoice(
                    index=i,
                    message=ChatMessage(role="assistant", content=model_output.text),
                    finish_reason=model_output.finish_reason or "stop",
                )
            )
            if model_output.usage:
                task_usage = UsageInfo.model_validate(model_output.usage)
                for usage_key, usage_value in model_to_dict(task_usage).items():
                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

    async def completion_stream_generator(
        self, request: CompletionRequest, params: Dict
    ):
        worker_manager = self.get_worker_manager()
        id = f"cmpl-{shortuuid.random()}"
        finish_stream_events = []
        params["span_id"] = root_tracer.get_current_span_id()
        curr_usage = UsageInfo()
        last_usage = UsageInfo()
        for text in request.prompt:
            for i in range(request.n):
                params["prompt"] = text
                previous_text = ""
                last_usage.prompt_tokens += curr_usage.prompt_tokens
                last_usage.completion_tokens += curr_usage.completion_tokens
                last_usage.total_tokens += curr_usage.total_tokens

                async for model_output in worker_manager.generate_stream(params):
                    model_output: ModelOutput = model_output
                    if model_output.error_code != 0:
                        yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
                        yield "data: [DONE]\n\n"
                        return
                    decoded_unicode = model_output.text.replace("\ufffd", "")
                    delta_text = decoded_unicode[len(previous_text) :]
                    previous_text = (
                        decoded_unicode
                        if len(decoded_unicode) > len(previous_text)
                        else previous_text
                    )

                    if len(delta_text) == 0:
                        delta_text = None

                    choice_data = CompletionResponseStreamChoice(
                        index=i,
                        text=delta_text or "",
                        # TODO: logprobs
                        logprobs=None,
                        finish_reason=model_output.finish_reason,
                    )
                    if model_output.usage:
                        curr_usage = UsageInfo.model_validate(model_output.usage)
                        usage = UsageInfo(
                            prompt_tokens=last_usage.prompt_tokens
                            + curr_usage.prompt_tokens,
                            total_tokens=last_usage.total_tokens
                            + curr_usage.total_tokens,
                            completion_tokens=last_usage.completion_tokens
                            + curr_usage.completion_tokens,
                        )
                    else:
                        usage = UsageInfo()
                    chunk = CompletionStreamResponse(
                        id=id,
                        object="text_completion",
                        choices=[choice_data],
                        model=request.model,
                        usage=UsageInfo.model_validate(usage),
                    )
                    if delta_text is None:
                        if model_output.finish_reason is not None:
                            finish_stream_events.append(chunk)
                        continue
                    json_data = model_to_json(
                        chunk, exclude_unset=True, ensure_ascii=False
                    )
                    yield f"data: {json_data}\n\n"
                last_usage = curr_usage
        # The last delta message has no "content" field; exclude_unset drops it.
        for finish_chunk in finish_stream_events:
            json_data = model_to_json(
                finish_chunk, exclude_unset=True, ensure_ascii=False
            )
            yield f"data: {json_data}\n\n"
        yield "data: [DONE]\n\n"

    async def completion_generate(
        self, request: CompletionRequest, params: Dict[str, Any]
    ):
        worker_manager: WorkerManager = self.get_worker_manager()
        choices = []
        completions = []
        for text in request.prompt:
            for i in range(request.n):
                params["prompt"] = text
                model_output = asyncio.create_task(worker_manager.generate(params))
                completions.append(model_output)
        try:
            all_tasks = await asyncio.gather(*completions)
        except Exception as e:
            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
        usage = UsageInfo()
        for i, model_output in enumerate(all_tasks):
            model_output: ModelOutput = model_output
            if model_output.error_code != 0:
                return create_error_response(model_output.error_code, model_output.text)
            choices.append(
                CompletionResponseChoice(
                    index=i,
                    text=model_output.text,
                    finish_reason=model_output.finish_reason,
                )
            )
            if model_output.usage:
                task_usage = UsageInfo.model_validate(model_output.usage)
                for usage_key, usage_value in model_to_dict(task_usage).items():
                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
        return CompletionResponse(
            model=request.model, choices=choices, usage=UsageInfo.model_validate(usage)
        )

    async def embeddings_generate(
        self,
        model: str,
        texts: List[str],
        span_id: Optional[str] = None,
    ) -> List[List[float]]:
        """Generate embeddings

        Args:
            model (str): Model name
            texts (List[str]): Texts to embed
            span_id (Optional[str], optional): The span id. Defaults to None.

        Returns:
            List[List[float]]: The embeddings of the texts
        """
        with root_tracer.start_span(
            "dbgpt.model.apiserver.generate_embeddings",
            parent_span_id=span_id,
            metadata={
                "model": model,
            },
        ):
            worker_manager: WorkerManager = self.get_worker_manager()
            params = {
                "input": texts,
                "model": model,
            }
            return await worker_manager.embeddings(params)

    async def relevance_generate(
        self, model: str, query: str, texts: List[str]
    ) -> List[float]:
        """Generate relevance scores between a query and a list of texts

        Args:
            model (str): Model name
            query (str): Query text
            texts (List[str]): Texts to score against the query

        Returns:
            List[float]: The relevance scores of the texts
        """
        worker_manager: WorkerManager = self.get_worker_manager()
        params = {
            "input": texts,
            "model": model,
            "query": query,
        }
        scores = await worker_manager.embeddings(params)
        return scores[0]


def get_api_server() -> APIServer:
    api_server = global_system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, default_component=None
    )
    if not api_server:
        global_system_app.register(APIServer)
    return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)


router = APIRouter()


@router.get("/v1/models", dependencies=[Depends(check_api_key)])
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
    return await api_server.get_available_models()


@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model)
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    params = {
        "model": request.model,
        "messages": ModelMessage.to_dict_list(
            ModelMessage.from_openai_messages(request.messages)
        ),
        "echo": False,
    }
    if request.temperature:
        params["temperature"] = request.temperature
    if request.top_p:
        params["top_p"] = request.top_p
    if request.max_tokens:
        params["max_new_tokens"] = request.max_tokens
    if request.stop:
        params["stop"] = request.stop
    if request.user:
        params["user"] = request.user

    # TODO: check token length
    trace_kwargs = {
        "operation_name": "dbgpt.model.apiserver.create_chat_completion",
        "metadata": {
            "model": request.model,
            "messages": request.messages,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "max_tokens": request.max_tokens,
            "stop": request.stop,
            "user": request.user,
        },
    }
    if request.stream:
        generator = api_server.chat_completion_stream_generator(
            request.model, params, request.n
        )
        trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
        return StreamingResponse(trace_generator, media_type="text/event-stream")
    else:
        with root_tracer.start_span(**trace_kwargs):
            return await api_server.chat_completion_generate(
                request.model, params, request.n
            )
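
# Illustrative streaming request against this endpoint (host, port, API key and
# model name are assumptions; the "/api" prefix is added in initialize_apiserver()):
#
#   curl http://127.0.0.1:8100/api/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer EMPTY" \
#     -d '{"model": "vicuna-13b-v1.5", "stream": true,
#          "messages": [{"role": "user", "content": "Hello"}]}'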


@router.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
    request: CompletionRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model)
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]
    elif not isinstance(request.prompt, list):
        return create_error_response(
            ErrorCode.VALIDATION_TYPE_ERROR,
            "prompt must be a string or a list of strings",
        )
    elif isinstance(request.prompt, list) and not isinstance(request.prompt[0], str):
        return create_error_response(
            ErrorCode.VALIDATION_TYPE_ERROR,
            "prompt must be a string or a list of strings",
        )

    params = {
        "model": request.model,
        "prompt": request.prompt,
        "chat_model": False,
        "temperature": request.temperature,
        "max_new_tokens": request.max_tokens,
        "stop": request.stop,
        "top_p": request.top_p,
        "top_k": request.top_k,
        "echo": request.echo,
        "presence_penalty": request.presence_penalty,
        "frequency_penalty": request.frequency_penalty,
        "user": request.user,
        # "use_beam_search": request.use_beam_search,
        # "beam_size": request.beam_size,
    }
    trace_kwargs = {
        "operation_name": "dbgpt.model.apiserver.create_completion",
        "metadata": {k: v for k, v in params.items() if v},
    }
    if request.stream:
        generator = api_server.completion_stream_generator(request, params)
        trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
        return StreamingResponse(trace_generator, media_type="text/event-stream")
    else:
        with root_tracer.start_span(**trace_kwargs):
            params["span_id"] = root_tracer.get_current_span_id()
            return await api_server.completion_generate(request, params)


@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
    texts = request.input
    if isinstance(texts, str):
        texts = [texts]
    batch_size = api_settings.embedding_batch_size
    batches = [
        texts[i : min(i + batch_size, len(texts))]
        for i in range(0, len(texts), batch_size)
    ]
    data = []
    async_tasks = []
    for num_batch, batch in enumerate(batches):
        async_tasks.append(
            api_server.embeddings_generate(
                request.model, batch, span_id=root_tracer.get_current_span_id()
            )
        )

    # Request all embeddings in parallel
    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
    for num_batch, embeddings in enumerate(batch_embeddings):
        data += [
            {
                "object": "embedding",
                "embedding": emb,
                "index": num_batch * batch_size + i,
            }
            for i, emb in enumerate(embeddings)
        ]
    return model_to_dict(
        EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()),
        exclude_none=True,
    )
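
# Illustrative request (host, port and model name are assumptions):
#
#   curl http://127.0.0.1:8100/api/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"model": "text2vec-large-chinese", "input": ["hello", "world"]}'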


@router.post(
    "/v1/beta/relevance",
    dependencies=[Depends(check_api_key)],
    response_model=RelevanceResponse,
)
async def create_relevance(
    request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
):
    """Generate relevance scores for a query and a list of documents."""
    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")

    with root_tracer.start_span(
        "dbgpt.model.apiserver.generate_relevance",
        metadata={
            "model": request.model,
            "query": request.query,
        },
    ):
        scores = await api_server.relevance_generate(
            request.model, request.query, request.documents
        )
    return model_to_dict(
        RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
        exclude_none=True,
    )
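
# Illustrative request (host, port and model name are assumptions):
#
#   curl http://127.0.0.1:8100/api/v1/beta/relevance \
#     -H "Content-Type: application/json" \
#     -d '{"model": "bge-reranker-base", "query": "What is DB-GPT?",
#          "documents": ["DB-GPT is an AI native data app development framework."]}'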


def _initialize_all(controller_addr: str, system_app: SystemApp):
    from dbgpt.model.cluster.controller.controller import ModelRegistryClient
    from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
    from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager

    if not system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    ):
        # Register the model registry if it does not exist
        registry = ModelRegistryClient(controller_addr)
        registry.name = ComponentType.MODEL_REGISTRY.value
        system_app.register_instance(registry)

    registry = system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    )
    worker_manager = RemoteWorkerManager(registry)

    # Register the worker manager component if it does not exist
    system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY,
        WorkerManagerFactory,
        or_register_component=_DefaultWorkerManagerFactory,
        worker_manager=worker_manager,
    )
    # Register the API server component if it does not exist
    system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
    )


def initialize_apiserver(
    controller_addr: str,
    apiserver_params: Optional[ModelAPIServerParameters] = None,
    app=None,
    system_app: Optional[SystemApp] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    api_keys: Optional[List[str]] = None,
    embedding_batch_size: Optional[int] = None,
    ignore_stop_exceeds_error: bool = False,
):
    import os

    from dbgpt.configs.model_config import LOGDIR

    global global_system_app
    global api_settings
    embedded_mod = True
    if not app:
        embedded_mod = False
        app = create_app()

    if not system_app:
        system_app = SystemApp(app)
    global_system_app = system_app

    if apiserver_params:
        initialize_tracer(
            os.path.join(LOGDIR, apiserver_params.tracer_file),
            system_app=system_app,
            root_operation_name="DB-GPT-APIServer",
            tracer_storage_cls=apiserver_params.tracer_storage_cls,
            enable_open_telemetry=apiserver_params.tracer_to_open_telemetry,
            otlp_endpoint=apiserver_params.otel_exporter_otlp_traces_endpoint,
            otlp_insecure=apiserver_params.otel_exporter_otlp_traces_insecure,
            otlp_timeout=apiserver_params.otel_exporter_otlp_traces_timeout,
        )

    if api_keys:
        api_settings.api_keys = api_keys

    if embedding_batch_size:
        api_settings.embedding_batch_size = embedding_batch_size
    api_settings.ignore_stop_exceeds_error = ignore_stop_exceeds_error

    app.include_router(router, prefix="/api", tags=["APIServer"])

    @app.exception_handler(APIServerException)
    async def validation_apiserver_exception_handler(request, exc: APIServerException):
        return create_error_response(exc.code, exc.message)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

    _initialize_all(controller_addr, system_app)

    if not embedded_mod:
        import uvicorn

        # https://github.com/encode/starlette/issues/617
        cors_app = CORSMiddleware(
            app=app,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
            allow_headers=["*"],
        )
        uvicorn.run(cors_app, host=host, port=port, log_level="info")


def run_apiserver():
    parser = EnvArgumentParser()
    env_prefix = "apiserver_"
    apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
        ModelAPIServerParameters,
        env_prefixes=[env_prefix],
    )
    setup_logging(
        "dbgpt",
        logging_level=apiserver_params.log_level,
        logger_filename=apiserver_params.log_file,
    )
    api_keys = None
    if apiserver_params.api_keys:
        api_keys = apiserver_params.api_keys.strip().split(",")

    initialize_apiserver(
        apiserver_params.controller_addr,
        apiserver_params,
        host=apiserver_params.host,
        port=apiserver_params.port,
        api_keys=api_keys,
        embedding_batch_size=apiserver_params.embedding_batch_size,
        ignore_stop_exceeds_error=apiserver_params.ignore_stop_exceeds_error,
    )


if __name__ == "__main__":
    run_apiserver()