feat: APIServer supports embeddings (#1256)

This commit is contained in:
Fangyin Cheng
2024-03-05 20:21:37 +08:00
committed by GitHub
parent 5f3ee35804
commit 74ec8e52cd
9 changed files with 414 additions and 40 deletions

View File

@@ -23,6 +23,8 @@ from fastchat.protocol.openai_api_protocol import (
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ModelCard,
ModelList,
ModelPermission,
@@ -51,6 +53,7 @@ class APIServerException(Exception):
# Runtime settings for the API server; a single module-level instance
# (`api_settings`) is created right after this class.
class APISettings(BaseModel):
# Accepted API keys; None by default.
# NOTE(review): presumably consulted by `check_api_key`, which is not visible here - confirm.
api_keys: Optional[List[str]] = None
# Number of input texts sent to the worker per call by the /v1/embeddings
# endpoint (default 4).
# NOTE: "bach" is a misspelling of "batch"; the name is kept as-is because
# other code reads and assigns this attribute under the current spelling.
embedding_bach_size: int = 4
api_settings = APISettings()
@@ -181,27 +184,29 @@ class APIServer(BaseComponent):
return controller
async def get_model_instances_or_raise(
self, model_name: str
self, model_name: str, worker_type: str = "llm"
) -> List[ModelInstance]:
"""Get healthy model instances with request model name
Args:
model_name (str): Model name
worker_type (str, optional): Worker type. Defaults to "llm".
Raises:
APIServerException: If can't get healthy model instances with request model name
"""
registry = self.get_model_registry()
registry_model_name = f"{model_name}@llm"
suffix = f"@{worker_type}"
registry_model_name = f"{model_name}{suffix}"
model_instances = await registry.get_all_instances(
registry_model_name, healthy_only=True
)
if not model_instances:
all_instances = await registry.get_all_model_instances(healthy_only=True)
models = [
ins.model_name.split("@llm")[0]
ins.model_name.split(suffix)[0]
for ins in all_instances
if ins.model_name.endswith("@llm")
if ins.model_name.endswith(suffix)
]
if models:
models = "&&".join(models)
@@ -336,6 +341,25 @@ class APIServer(BaseComponent):
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
async def embeddings_generate(
    self, model: str, texts: List[str]
) -> List[List[float]]:
    """Embed a list of texts with the given model.

    The request is forwarded to the worker manager as a parameter dict
    with the keys ``input`` and ``model``.

    Args:
        model (str): Model name
        texts (List[str]): Texts to embed

    Returns:
        List[List[float]]: Whatever the worker manager returns for the
            request, expected to be the embeddings of ``texts``
    """
    manager: WorkerManager = self.get_worker_manager()
    return await manager.embeddings({"input": texts, "model": model})
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
return await api_server.chat_completion_generate(request.model, params, request.n)
@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
):
    """Create embeddings for the input of an embeddings request.

    The input texts are split into batches of ``api_settings.embedding_bach_size``
    items, all batches are embedded concurrently, and the results are
    flattened into a single ``data`` list.

    Args:
        request (EmbeddingsRequest): The request; ``request.input`` may be a
            single string or a list, ``request.model`` is the model name.
        api_server (APIServer): The API server component (injected).

    Returns:
        dict: The ``EmbeddingsResponse`` serialized with ``exclude_none=True``.

    Raises:
        APIServerException: Propagated from ``get_model_instances_or_raise``
            when no healthy "text2vec" instance serves the requested model.
    """
    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
    texts = request.input
    if isinstance(texts, str):
        texts = [texts]

    # A zero batch size would make ``range`` raise and a negative one would
    # yield no batches at all (silently returning no embeddings), so clamp it.
    batch_size = max(1, api_settings.embedding_bach_size)
    # Slicing already clamps at the end of the list; no explicit bound needed.
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

    # Request all embeddings in parallel; gather returns results in the
    # order the awaitables were passed, so batch order is preserved.
    batch_embeddings: List[List[List[float]]] = await asyncio.gather(
        *(api_server.embeddings_generate(request.model, batch) for batch in batches)
    )

    data = []
    for num_batch, embeddings in enumerate(batch_embeddings):
        # Position of the text in the original input; assumes the worker
        # returns one embedding per input text of the batch.
        offset = num_batch * batch_size
        data.extend(
            {
                "object": "embedding",
                "embedding": emb,
                "index": offset + i,
            }
            for i, emb in enumerate(embeddings)
        )
    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
        exclude_none=True
    )
def _initialize_all(controller_addr: str, system_app: SystemApp):
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
@@ -427,6 +485,7 @@ def initialize_apiserver(
host: str = None,
port: int = None,
api_keys: List[str] = None,
embedding_batch_size: Optional[int] = None,
):
global global_system_app
global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
if not app:
embedded_mod = False
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
if not system_app:
system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
if api_keys:
api_settings.api_keys = api_keys
if embedding_batch_size:
api_settings.embedding_bach_size = embedding_batch_size
app.include_router(router, prefix="/api", tags=["APIServer"])
@app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ def initialize_apiserver(
if not embedded_mod:
import uvicorn
uvicorn.run(app, host=host, port=port, log_level="info")
# https://github.com/encode/starlette/issues/617
cors_app = CORSMiddleware(
app=app,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
uvicorn.run(cors_app, host=host, port=port, log_level="info")
def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
host=apiserver_params.host,
port=apiserver_params.port,
api_keys=api_keys,
embedding_batch_size=apiserver_params.embedding_batch_size,
)