feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -1,6 +1,6 @@
from typing import List
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.cluster.manager_base import WorkerManager
@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a single query text.

    Args:
        text: The query text to embed.

    Returns:
        The first element of the awaited result of
        ``self.aembed_documents([text])``.
    """
    # Subscription binds tighter than ``await``: ``await f()[0]`` would index
    # the un-awaited coroutine object and raise TypeError. Await the batch
    # call first, then pick the single result.
    embeddings = await self.aembed_documents([text])
    return embeddings[0]
class RemoteRerankEmbeddings(RerankEmbeddings):
    """Rerank scorer that sends its requests through a worker manager."""

    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
        """Remember the model name and the worker manager used for requests.

        Args:
            model_name: Value sent as the ``"model"`` field of each request.
            worker_manager: Object whose ``sync_embeddings`` / ``embeddings``
                methods are called to obtain the scores.
        """
        self.model_name = model_name
        self.worker_manager = worker_manager

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Score the candidates against the query (blocking call).

        Args:
            query: The query text.
            candidates: The texts to be scored.

        Returns:
            The first element of the worker manager's response.
        """
        request = {"model": self.model_name, "input": candidates, "query": query}
        # The ranker is reached through the embeddings interface.
        response = self.worker_manager.sync_embeddings(request)
        return response[0]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Score the candidates against the query (asynchronous call).

        Args:
            query: The query text.
            candidates: The texts to be scored.

        Returns:
            The first element of the worker manager's response.
        """
        request = {"model": self.model_name, "input": candidates, "query": query}
        # The ranker is reached through the embeddings interface; the scores
        # for the query are carried in the first element of the response.
        response = await self.worker_manager.embeddings(request)
        return response[0]