mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 16:50:24 +00:00
56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
from typing import List
|
|
|
|
from dbgpt.core import Embeddings, RerankEmbeddings
|
|
from dbgpt.model.cluster.manager_base import WorkerManager
|
|
|
|
|
|
class RemoteEmbeddings(Embeddings):
    """Embeddings implementation that delegates every call to a worker manager.

    Each method builds a request dict with the ``"model"`` and ``"input"`` keys
    and hands it to the worker manager; nothing is computed locally.
    """

    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
        """Remember which model to request and which manager to send requests to.

        Args:
            model_name: Value sent under the ``"model"`` key of every request.
            worker_manager: Object whose ``sync_embeddings`` (blocking) and
                ``embeddings`` (awaitable) methods serve the requests.
        """
        self.model_name = model_name
        self.worker_manager = worker_manager

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: The texts to embed, sent as the ``"input"`` of the request.

        Returns:
            Whatever ``worker_manager.sync_embeddings`` returns for the request.
        """
        params = {"model": self.model_name, "input": texts}
        return self.worker_manager.sync_embeddings(params)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        The query is embedded as a one-element batch and the first (only)
        result is returned.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: The texts to embed, sent as the ``"input"`` of the request.

        Returns:
            Whatever awaiting ``worker_manager.embeddings`` yields for the request.
        """
        params = {"model": self.model_name, "input": texts}
        return await self.worker_manager.embeddings(params)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        The query is embedded as a one-element batch and the first (only)
        result is returned.
        """
        # Subscription binds tighter than ``await``: writing
        # ``await self.aembed_documents([text])[0]`` would index the coroutine
        # object itself and raise TypeError. Await first, then index.
        vectors = await self.aembed_documents([text])
        return vectors[0]
|
|
|
|
|
|
class RemoteRerankEmbeddings(RerankEmbeddings):
    """Rerank scorer that forwards its requests to a worker manager.

    Scoring requests go through the worker manager's embeddings entry points,
    with the query passed alongside the candidates.
    """

    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
        """Keep the model name and worker manager used for every request."""
        self.worker_manager = worker_manager
        self.model_name = model_name

    def _rank_request(self, query: str, candidates: List[str]) -> dict:
        """Build the request dict shared by ``predict`` and ``apredict``."""
        return {
            "model": self.model_name,
            "input": candidates,
            "query": query,
        }

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the scores of the candidates."""
        request = self._rank_request(query, candidates)
        results = self.worker_manager.sync_embeddings(request)
        # The first element holds the scores for this query.
        return results[0]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Asynchronously predict the scores of the candidates."""
        request = self._rank_request(query, candidates)
        # The ranker is served through the embeddings interface; the first
        # element of its result holds the scores for this query.
        return (await self.worker_manager.embeddings(request))[0]
|