Mirror of https://github.com/csunny/DB-GPT.git
feat(rerank): add Text Embeddings Inference (TEI) API support for reranking (#2516)
Co-authored-by: tam <tanwe@fulan.com.cn>
This commit is contained in: parent b715cdb131, commit 1ec855fd79
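For context: TeiRerankEmbeddings targets the rerank endpoint of Hugging Face's Text Embeddings Inference (TEI) server. Below is a minimal sketch of the raw HTTP exchange the new class wraps; the URL is the default from TeiEmbeddingsParameters, the payload shape follows the predict() implementation in the diff, and the response fields follow what _parse_results() expects. The query and texts are illustrative only.

import requests

# Hypothetical TEI server; URL matches the default in TeiEmbeddingsParameters below.
TEI_RERANK_URL = "http://localhost:8001/rerank"

payload = {
    "query": "what is deep learning?",
    "texts": [
        "Deep learning is a subset of machine learning.",
        "Paris is the capital of France.",
    ],
}
resp = requests.post(TEI_RERANK_URL, json=payload, timeout=60)
resp.raise_for_status()

# Expected response shape (as consumed by _parse_results below):
# a JSON list of objects like {"index": 0, "score": 0.97}
for item in resp.json():
    print(item["index"], item["score"])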
@@ -392,6 +392,107 @@ class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
        return scores


@dataclass
class TeiEmbeddingsParameters(OpenAPIRerankerDeployModelParameters):
    """Text Embeddings Inference Rerank Embeddings Parameters."""

    provider: str = "proxy/tei"

    api_url: str = field(
        default="http://localhost:8001/rerank",
        metadata={
            "help": _("The URL of the rerank API."),
        },
    )
    api_key: Optional[str] = field(
        default=None,
        metadata={
            "help": _("The API key for the rerank API."),
        },
    )


class TeiRerankEmbeddings(OpenAPIRerankEmbeddings):
    """Text Embeddings Inference Rerank Model.

    See `Text Embeddings Inference API
    <https://huggingface.github.io/text-embeddings-inference/>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the TeiRerankEmbeddings."""
        # If the API key is not provided, try to get it from the environment
        if "api_key" not in kwargs:
            kwargs["api_key"] = os.getenv("TEI_API_KEY")

        if "api_url" not in kwargs:
            raise ValueError("Please provide the api_url parameter")

        super().__init__(**kwargs)

    @classmethod
    def param_class(cls) -> Type[TeiEmbeddingsParameters]:
        """Get the parameter class."""
        return TeiEmbeddingsParameters

    def _parse_results(self, response: List[Dict[str, Any]]) -> List[float]:
        """Parse the response from the API.

        Args:
            response: The response from the API.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not isinstance(response, list) or not response:
            raise RuntimeError("Results should be a non-empty list")

        # Sort by index so the scores line up with the original candidate order
        results = sorted(response, key=lambda x: x.get("index", 0))
        scores = [float(result.get("score")) for result in results]
        return scores

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not candidates:
            return []
        headers = {}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        data = {"query": query, "texts": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout, headers=headers
        )
        response.raise_for_status()
        return self._parse_results(response.json())

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously."""
        if not candidates:
            # Mirror the synchronous predict: nothing to rank
            return []
        headers = {"Authorization": f"Bearer {self.api_key}"}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            data = {"query": query, "texts": candidates}
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                return self._parse_results(response_data)


register_embedding_adapter(
    CrossEncoderRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
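For reference, a standalone sketch of the parsing step above: TEI may return results ordered by score rather than by input position, so _parse_results re-sorts by index to keep the scores aligned with the original candidate order. The sample response below is hypothetical.

# Hypothetical TEI-style response, ordered by score rather than by input index.
sample_response = [
    {"index": 2, "score": 0.92},
    {"index": 0, "score": 0.31},
    {"index": 1, "score": 0.07},
]
results = sorted(sample_response, key=lambda x: x.get("index", 0))
scores = [float(r.get("score")) for r in results]
print(scores)  # [0.31, 0.07, 0.92], so scores[i] belongs to candidates[i]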
@@ -401,3 +502,6 @@ register_embedding_adapter(
register_embedding_adapter(
    SiliconFlowRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
register_embedding_adapter(
    TeiRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
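A short usage sketch of the newly registered class. Constructor handling lives in OpenAPIRerankEmbeddings, so only the arguments visible in this diff are used here (api_url is required; api_key falls back to the TEI_API_KEY environment variable). The server URL and texts are illustrative.

# Assumes a TEI rerank server is running at the URL below.
reranker = TeiRerankEmbeddings(api_url="http://localhost:8001/rerank")
candidates = [
    "Deep learning is a subset of machine learning.",
    "Paris is the capital of France.",
]

# Synchronous scoring
scores = reranker.predict("what is deep learning?", candidates)
print(scores)

# Asynchronous scoring, e.g.:
# scores = asyncio.run(reranker.apredict("what is deep learning?", candidates))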