mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelPermission,
|
||||
RelevanceRequest,
|
||||
RelevanceResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from dbgpt.model.base import ModelInstance
|
||||
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
|
||||
}
|
||||
return await worker_manager.embeddings(params)
|
||||
|
||||
async def relevance_generate(
    self, model: str, query: str, texts: List[str]
) -> List[float]:
    """Generate relevance scores of the given texts for a query.

    The request is dispatched through the worker manager's ``embeddings``
    call with an extra ``query`` entry in the parameters, and the first
    element of the worker result is returned.

    Args:
        model (str): Model name
        query (str): Query text
        texts (List[str]): Texts to score against the query

    Returns:
        List[float]: The first element of the worker result. Presumably one
            score per text, in input order (confirm against the worker
            implementation). Empty when ``texts`` is empty.
    """
    if not texts:
        # Nothing to score: skip the worker round trip and avoid indexing
        # into a possibly empty result below.
        return []
    worker_manager: WorkerManager = self.get_worker_manager()
    params = {
        "input": texts,
        "model": model,
        "query": query,
    }
    scores = await worker_manager.embeddings(params)
    # Only the first element of the returned list is handed back to callers.
    return scores[0]
|
||||
|
||||
|
||||
def get_api_server() -> APIServer:
|
||||
api_server = global_system_app.get_component(
|
||||
@@ -456,6 +481,26 @@ async def create_embeddings(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/v1/beta/relevance",
    dependencies=[Depends(check_api_key)],
    response_model=RelevanceResponse,
)
async def create_relevance(
    request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
):
    """Generate relevance scores for a query and a list of documents.

    Args:
        request (RelevanceRequest): Request carrying the model name, the
            query and the documents to score.
        api_server (APIServer): API server component, injected by FastAPI.

    Returns:
        dict: The ``RelevanceResponse`` converted with ``model_to_dict``,
            with ``None`` fields excluded.
    """
    # NOTE: this handler has its own name on purpose. Reusing the name of the
    # /v1/embeddings handler would rebind that module-level name to this
    # function; routing itself is by path and is unaffected by the name.

    # Check the requested model among "text2vec" workers before scoring;
    # the returned instances are not used here.
    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")

    scores = await api_server.relevance_generate(
        request.model, request.query, request.documents
    )
    return model_to_dict(
        RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
        exclude_none=True,
    )
|
||||
|
||||
|
||||
def _initialize_all(controller_addr: str, system_app: SystemApp):
|
||||
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
|
||||
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
|
||||
|
Reference in New Issue
Block a user