feat: APIServer supports embeddings (#1256)
@@ -23,6 +23,8 @@ from fastchat.protocol.openai_api_protocol import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
     ModelCard,
     ModelList,
     ModelPermission,
@@ -51,6 +53,7 @@ class APIServerException(Exception):

 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
+    embedding_bach_size: int = 4


 api_settings = APISettings()
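Note: the new `embedding_bach_size` field (spelled this way in the source) caps how many input texts are sent to a worker per request. A sketch of the partitioning it produces, mirroring the slicing in `create_embeddings` further down:

# Sketch only: how a batch size of 4 splits the request input
# (same effective slicing as in create_embeddings below).
texts = ["t1", "t2", "t3", "t4", "t5", "t6"]
batch_size = 4
batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
assert batches == [["t1", "t2", "t3", "t4"], ["t5", "t6"]]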
@@ -181,27 +184,29 @@ class APIServer(BaseComponent):
         return controller

     async def get_model_instances_or_raise(
-        self, model_name: str
+        self, model_name: str, worker_type: str = "llm"
     ) -> List[ModelInstance]:
         """Get healthy model instances with request model name

         Args:
             model_name (str): Model name
+            worker_type (str, optional): Worker type. Defaults to "llm".

         Raises:
             APIServerException: If can't get healthy model instances with request model name
         """
         registry = self.get_model_registry()
-        registry_model_name = f"{model_name}@llm"
+        suffix = f"@{worker_type}"
+        registry_model_name = f"{model_name}{suffix}"
         model_instances = await registry.get_all_instances(
             registry_model_name, healthy_only=True
         )
         if not model_instances:
             all_instances = await registry.get_all_model_instances(healthy_only=True)
             models = [
-                ins.model_name.split("@llm")[0]
+                ins.model_name.split(suffix)[0]
                 for ins in all_instances
-                if ins.model_name.endswith("@llm")
+                if ins.model_name.endswith(suffix)
             ]
             if models:
                 models = "&&".join(models)
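The registry lookup key is just the model name plus an `@{worker_type}` suffix; `@llm` stays the default, and embedding models register under `@text2vec` (as the endpoint below shows). A minimal illustration; the model names are placeholders:

# Sketch of the "{model_name}@{worker_type}" registry-key scheme used above.
def registry_key(model_name: str, worker_type: str = "llm") -> str:
    return f"{model_name}@{worker_type}"

assert registry_key("vicuna-13b-v1.5") == "vicuna-13b-v1.5@llm"
assert registry_key("text2vec-large-chinese", "text2vec") == "text2vec-large-chinese@text2vec"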
@@ -336,6 +341,25 @@ class APIServer(BaseComponent):

         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

+    async def embeddings_generate(
+        self, model: str, texts: List[str]
+    ) -> List[List[float]]:
+        """Generate embeddings
+
+        Args:
+            model (str): Model name
+            texts (List[str]): Texts to embed
+
+        Returns:
+            List[List[float]]: The embeddings of texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+        }
+        return await worker_manager.embeddings(params)
+

 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
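A hedged usage sketch for the new method; it must run inside an event loop against an initialized server whose registry holds a text2vec worker, and the model name is a placeholder:

import asyncio

async def demo(api_server):
    # Placeholder model name; any registered text2vec model would do.
    embeddings = await api_server.embeddings_generate(
        "text2vec-large-chinese", ["hello", "world"]
    )
    assert len(embeddings) == 2  # one List[float] vector per input text
    return embeddings

# asyncio.run(demo(get_api_server()))  # requires an initialized SystemApp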
@@ -389,6 +413,40 @@ async def create_chat_completion(
     return await api_server.chat_completion_generate(request.model, params, request.n)


+@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
+async def create_embeddings(
+    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
+):
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    texts = request.input
+    if isinstance(texts, str):
+        texts = [texts]
+    batch_size = api_settings.embedding_bach_size
+    batches = [
+        texts[i : min(i + batch_size, len(texts))]
+        for i in range(0, len(texts), batch_size)
+    ]
+    data = []
+    async_tasks = []
+    for num_batch, batch in enumerate(batches):
+        async_tasks.append(api_server.embeddings_generate(request.model, batch))
+
+    # Request all embeddings in parallel
+    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
+    for num_batch, embeddings in enumerate(batch_embeddings):
+        data += [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": num_batch * batch_size + i,
+            }
+            for i, emb in enumerate(embeddings)
+        ]
+    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
+        exclude_none=True
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
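The endpoint is OpenAI-compatible and, given the `/api` router prefix wired up further down, is served at `/api/v1/embeddings`. A client sketch; host, port, API key, and model name are placeholders:

import requests  # assumes the `requests` package is installed

resp = requests.post(
    "http://127.0.0.1:8100/api/v1/embeddings",  # placeholder host/port
    headers={"Authorization": "Bearer sk-your-key"},  # only needed if api_keys is set
    json={"model": "text2vec-large-chinese", "input": ["hello", "world"]},
)
print(len(resp.json()["data"]))  # 2: one embedding object per input text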
@@ -427,6 +485,7 @@ def initialize_apiserver(
     host: str = None,
     port: int = None,
     api_keys: List[str] = None,
+    embedding_batch_size: Optional[int] = None,
 ):
     global global_system_app
     global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
     if not app:
         embedded_mod = False
         app = FastAPI()
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=["*"],
-            allow_credentials=True,
-            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-            allow_headers=["*"],
-        )

     if not system_app:
         system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
     if api_keys:
         api_settings.api_keys = api_keys

+    if embedding_batch_size:
+        api_settings.embedding_bach_size = embedding_batch_size
+
     app.include_router(router, prefix="/api", tags=["APIServer"])

     @app.exception_handler(APIServerException)
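One subtlety in the wiring above: the override is a truthiness check, so passing `embedding_batch_size=0` (or `None`) leaves the default of 4 in place:

# Truthiness check, as in initialize_apiserver above: 0 and None keep the default.
default = 4
for override in (None, 0, 8):
    effective = override if override else default
    print(override, "->", effective)  # None -> 4, 0 -> 4, 8 -> 8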
@@ -464,7 +519,15 @@ def initialize_apiserver(
     if not embedded_mod:
         import uvicorn

-        uvicorn.run(app, host=host, port=port, log_level="info")
+        # https://github.com/encode/starlette/issues/617
+        cors_app = CORSMiddleware(
+            app=app,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+            allow_headers=["*"],
+        )
+        uvicorn.run(cors_app, host=host, port=port, log_level="info")


 def run_apiserver():
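Per the linked Starlette issue, the commit stops registering CORSMiddleware on the FastAPI app up front and instead wraps the ASGI app at serve time, so the CORS layer is applied as the outermost wrapper regardless of when the app was created. The same pattern in isolation, as a sketch (origins and methods mirror the settings above):

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
cors_app = CORSMiddleware(
    app=app,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    uvicorn.run(cors_app, host="127.0.0.1", port=8100)  # placeholder host/port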
@@ -488,6 +551,7 @@ def run_apiserver():
         host=apiserver_params.host,
         port=apiserver_params.port,
         api_keys=api_keys,
+        embedding_batch_size=apiserver_params.embedding_batch_size,
     )
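Finally, the new option threads through from the CLI parameters into `initialize_apiserver`. A hedged sketch of starting the server programmatically; the module path, argument order, and controller address are assumptions rather than facts from this diff:

# Assumed import path and first positional argument (the model controller
# address); adjust both to your deployment.
from dbgpt.model.cluster.apiserver.api import initialize_apiserver

initialize_apiserver(
    "http://127.0.0.1:8000",  # model controller address (placeholder)
    host="0.0.0.0",
    port=8100,
    api_keys=["sk-example"],
    embedding_batch_size=8,  # overrides APISettings.embedding_bach_size
)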