DB-GPT/dbgpt/model/cluster/embedding/remote_embedding.py
2024-05-16 14:50:16 +08:00

56 lines
2.1 KiB
Python

from typing import List
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.cluster.manager_base import WorkerManager
class RemoteEmbeddings(Embeddings):
    """Embeddings implementation that delegates to a worker manager.

    Each call builds a request dict with the keys ``"model"`` and ``"input"``
    and hands it to the worker manager: ``sync_embeddings`` for the blocking
    methods and the awaitable ``embeddings`` for the async methods.
    """

    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
        """Store the model name and the worker manager used for requests.

        Args:
            model_name: Value sent under the ``"model"`` key of every request.
            worker_manager: Object whose ``sync_embeddings`` and
                ``embeddings`` methods are called to compute embeddings.
        """
        self.model_name = model_name
        self.worker_manager = worker_manager

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Texts sent under the ``"input"`` key of the request.

        Returns:
            Whatever ``worker_manager.sync_embeddings`` returns; annotated as
            one embedding vector per input text.
        """
        params = {"model": self.model_name, "input": texts}
        return self.worker_manager.sync_embeddings(params)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        The text is embedded as a single-element batch and the first
        (only) result is returned.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: Texts sent under the ``"input"`` key of the request.

        Returns:
            The awaited result of ``worker_manager.embeddings``; annotated as
            one embedding vector per input text.
        """
        params = {"model": self.model_name, "input": texts}
        return await self.worker_manager.embeddings(params)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        The text is embedded as a single-element batch and the first
        (only) result is returned.
        """
        # Await first, then index. Subscription binds tighter than ``await``,
        # so ``await self.aembed_documents([text])[0]`` would try to subscript
        # the coroutine object itself and raise TypeError.
        embeddings = await self.aembed_documents([text])
        return embeddings[0]
class RemoteRerankEmbeddings(RerankEmbeddings):
    """Rerank scorer that delegates to a worker manager.

    Ranking requests go through the worker manager's embeddings interface;
    the request carries the candidates under ``"input"`` and the query under
    ``"query"``.
    """

    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
        """Remember the model name and the worker manager to call."""
        self.worker_manager = worker_manager
        self.model_name = model_name

    def _build_params(self, query: str, candidates: List[str]) -> dict:
        """Assemble the request dict shared by ``predict`` and ``apredict``."""
        return {
            "model": self.model_name,
            "input": candidates,
            "query": query,
        }

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Score the candidates against the query (blocking).

        Returns the first element of the ``sync_embeddings`` response.
        """
        request = self._build_params(query, candidates)
        response = self.worker_manager.sync_embeddings(request)
        return response[0]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Score the candidates against the query (asynchronous).

        Returns the first element of the awaited ``embeddings`` response.
        """
        request = self._build_params(query, candidates)
        # The ranker is reached through the embeddings interface; the first
        # element of its response holds the scores for the query.
        return (await self.worker_manager.embeddings(request))[0]