From 1ec855fd792dfcb259c900f9eb1fd9d573357009 Mon Sep 17 00:00:00 2001
From: Tam
Date: Mon, 24 Mar 2025 23:16:03 +0800
Subject: [PATCH] feat(rerank): add Text Embeddings Inference (TEI) API support
 for reranking (#2516)

Co-authored-by: tam
---
 .../src/dbgpt/rag/embedding/rerank.py         | 117 ++++++++++++++++++
 1 file changed, 117 insertions(+)

diff --git a/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py b/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py
index 23c3da50c..ee57402b8 100644
--- a/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py
+++ b/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py
@@ -392,6 +392,120 @@ class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
         return scores
 
 
+@dataclass
+class TeiEmbeddingsParameters(OpenAPIRerankerDeployModelParameters):
+    """Text Embeddings Inference Rerank Embeddings Parameters."""
+
+    provider: str = "proxy/tei"
+
+    api_url: str = field(
+        default="http://localhost:8001/rerank",
+        metadata={
+            "help": _("The URL of the rerank API."),
+        },
+    )
+    api_key: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": _("The API key for the rerank API."),
+        },
+    )
+
+
+class TeiRerankEmbeddings(OpenAPIRerankEmbeddings):
+    """Text Embeddings Inference Rerank Model.
+
+    See `Text Embeddings Inference API
+    <https://huggingface.github.io/text-embeddings-inference/>`_ for more details.
+    """
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the TeiRerankEmbeddings."""
+        # If the API key is not provided, try to get it from the environment
+        if "api_key" not in kwargs:
+            kwargs["api_key"] = os.getenv("TEI_API_KEY")
+
+        if "api_url" not in kwargs:
+            raise ValueError("Please provide the api_url param")
+
+        super().__init__(**kwargs)
+
+    @classmethod
+    def param_class(cls) -> Type[TeiEmbeddingsParameters]:
+        """Get the parameter class."""
+        return TeiEmbeddingsParameters
+
+    def _parse_results(self, response: List[Dict[str, Any]]) -> List[float]:
+        """Parse the response from the API.
+
+        Args:
+            response: The response from the API.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        # NOTE: `or` (not `and`) — reject non-list payloads and empty lists;
+        # with `and`, an empty list was never rejected at all.
+        if not isinstance(response, list) or len(response) == 0:
+            raise RuntimeError("Results should be a not empty list")
+
+        # Sort by index, 0 in the first element
+        results = sorted(response, key=lambda x: x.get("index", 0))
+        scores = [float(result.get("score")) for result in results]
+        return scores
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates.
+
+        Args:
+            query: The query text.
+            candidates: The list of candidate texts.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        if not candidates:
+            return []
+        headers = {}
+        current_span_id = root_tracer.get_current_span_id()
+        if self.pass_trace_id and current_span_id:
+            # Set the trace ID if available
+            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
+        data = {"query": query, "texts": candidates}
+        response = self.session.post(  # type: ignore
+            self.api_url, json=data, timeout=self.timeout, headers=headers
+        )
+        response.raise_for_status()
+        return self._parse_results(response.json())
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates asynchronously.
+
+        Args:
+            query: The query text.
+            candidates: The list of candidate texts.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        if not candidates:
+            # Keep parity with the synchronous `predict` for empty input
+            return []
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        current_span_id = root_tracer.get_current_span_id()
+        if self.pass_trace_id and current_span_id:
+            # Set the trace ID if available
+            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            data = {"query": query, "texts": candidates}
+            async with session.post(self.api_url, json=data) as resp:
+                resp.raise_for_status()
+                response_data = await resp.json()
+                return self._parse_results(response_data)
+
+
 register_embedding_adapter(
     CrossEncoderRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
 )
@@ -401,3 +515,6 @@ register_embedding_adapter(
 register_embedding_adapter(
     SiliconFlowRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
 )
+register_embedding_adapter(
+    TeiRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
+)