mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-27 12:49:34 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.core import Embeddings, RerankEmbeddings
|
||||
from dbgpt.model.cluster.manager_base import WorkerManager
|
||||
|
||||
|
||||
@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
return await self.aembed_documents([text])[0]
|
||||
|
||||
|
||||
class RemoteRerankEmbeddings(RerankEmbeddings):
|
||||
def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
|
||||
self.model_name = model_name
|
||||
self.worker_manager = worker_manager
|
||||
|
||||
def predict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the scores of the candidates."""
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"input": candidates,
|
||||
"query": query,
|
||||
}
|
||||
return self.worker_manager.sync_embeddings(params)[0]
|
||||
|
||||
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Asynchronously predict the scores of the candidates."""
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"input": candidates,
|
||||
"query": query,
|
||||
}
|
||||
# Use embeddings interface to get scores of ranker
|
||||
scores = await self.worker_manager.embeddings(params)
|
||||
# The first element is the scores of the query
|
||||
return scores[0]
|
||||
|
||||
Reference in New Issue
Block a user