feat(rerank): add Text Embeddings Inference (TEI) API support for reranking (#2516)

Co-authored-by: tam <tanwe@fulan.com.cn>
Tam 2025-03-24 23:16:03 +08:00 committed by GitHub
parent b715cdb131
commit 1ec855fd79


@@ -392,6 +392,107 @@ class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
        return scores


@dataclass
class TeiEmbeddingsParameters(OpenAPIRerankerDeployModelParameters):
    """Text Embeddings Inference Rerank Embeddings Parameters."""

    provider: str = "proxy/tei"

    api_url: str = field(
        default="http://localhost:8001/rerank",
        metadata={
            "help": _("The URL of the rerank API."),
        },
    )
    api_key: Optional[str] = field(
        default=None,
        metadata={
            "help": _("The API key for the rerank API."),
        },
    )


class TeiRerankEmbeddings(OpenAPIRerankEmbeddings):
    """Text Embeddings Inference Rerank Model.

    See `Text Embeddings Inference API
    <https://huggingface.github.io/text-embeddings-inference/>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the TeiRerankEmbeddings."""
        # If the API key is not provided, try to get it from the environment
        if "api_key" not in kwargs:
            kwargs["api_key"] = os.getenv("TEI_API_KEY")
        if "api_url" not in kwargs:
            raise ValueError("Please provide the api_url param")
        super().__init__(**kwargs)

    @classmethod
    def param_class(cls) -> Type[TeiEmbeddingsParameters]:
        """Get the parameter class."""
        return TeiEmbeddingsParameters

    def _parse_results(self, response: List[Dict[str, Any]]) -> List[float]:
        """Parse the response from the API.

        Args:
            response: The response from the API.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not isinstance(response, list) or len(response) == 0:
            raise RuntimeError("Results should be a non-empty list")
        # Sort by index so the i-th score corresponds to the i-th candidate
        results = sorted(response, key=lambda x: x.get("index", 0))
        scores = [float(result.get("score")) for result in results]
        return scores

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not candidates:
            return []
        headers = {}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        data = {"query": query, "texts": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout, headers=headers
        )
        response.raise_for_status()
        return self._parse_results(response.json())

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            data = {"query": query, "texts": candidates}
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                return self._parse_results(response_data)


register_embedding_adapter(
    CrossEncoderRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
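
For reference, _parse_results expects the TEI /rerank endpoint to return a JSON list of objects carrying an "index" (the position of the candidate in the request) and a "score"; sorting by "index" re-aligns the scores with the input order. A minimal sketch of that mapping, using made-up sample values:

    # Illustrative only: the response values below are invented sample data.
    sample_response = [
        {"index": 1, "score": 0.9134},
        {"index": 0, "score": 0.0217},
    ]
    # Sort by "index" so the i-th score belongs to the i-th candidate text
    ordered = sorted(sample_response, key=lambda x: x.get("index", 0))
    scores = [float(item["score"]) for item in ordered]
    print(scores)  # [0.0217, 0.9134]
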
@@ -401,3 +502,6 @@ register_embedding_adapter(
register_embedding_adapter(
    SiliconFlowRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
register_embedding_adapter(
    TeiRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
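
A minimal usage sketch, assuming a TEI instance serving a reranker model is reachable at the default http://localhost:8001/rerank; the model id, query, and candidate texts below are illustrative, and the class is the TeiRerankEmbeddings added in this diff:

    # Hypothetical usage example; start a TEI reranker first, e.g.:
    #   text-embeddings-router --model-id BAAI/bge-reranker-base --port 8001
    reranker = TeiRerankEmbeddings(api_url="http://localhost:8001/rerank")

    query = "What is Text Embeddings Inference?"
    candidates = [
        "TEI is a toolkit for serving text embedding and reranker models.",
        "Bananas are a good source of potassium.",
    ]

    scores = reranker.predict(query, candidates)  # one float score per candidate
    for text, score in sorted(zip(candidates, scores), key=lambda p: p[1], reverse=True):
        print(f"{score:.4f}  {text}")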