feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -3,6 +3,7 @@
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
import asyncio
import json
import logging
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
ModelCard,
ModelList,
ModelPermission,
RelevanceRequest,
RelevanceResponse,
UsageInfo,
)
from dbgpt.model.base import ModelInstance
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
}
return await worker_manager.embeddings(params)
async def relevance_generate(
self, model: str, query: str, texts: List[str]
) -> List[float]:
"""Generate embeddings
Args:
model (str): Model name
query (str): Query text
texts (List[str]): Texts to embed
Returns:
List[List[float]]: The embeddings of texts
"""
worker_manager: WorkerManager = self.get_worker_manager()
params = {
"input": texts,
"model": model,
"query": query,
}
scores = await worker_manager.embeddings(params)
return scores[0]
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
@@ -456,6 +481,26 @@ async def create_embeddings(
)
@router.post(
"/v1/beta/relevance",
dependencies=[Depends(check_api_key)],
response_model=RelevanceResponse,
)
async def create_embeddings(
request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
):
"""Generate relevance scores for a query and a list of documents."""
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
scores = await api_server.relevance_generate(
request.model, request.query, request.documents
)
return model_to_dict(
RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
exclude_none=True,
)
def _initialize_all(controller_addr: str, system_app: SystemApp):
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory