Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-27 22:07:48 +00:00)
feat(rerank): add Text Embeddings Inference (TEI) API support for reranking (#2516)
Co-authored-by: tam <tanwe@fulan.com.cn>
parent b715cdb131
commit 1ec855fd79
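For context, the adapter added below exchanges JSON with a TEI /rerank endpoint. A minimal sketch of the payload shape the new code assumes (query, texts, and score values here are illustrative, not taken from the commit):

# Hypothetical illustration of the /rerank payload the new adapter exchanges with TEI.
# The request matches the data dict built in predict/apredict; the response is
# expected to be a JSON list of {"index", "score"} objects, which _parse_results
# sorts by "index" before extracting the scores.
request_body = {
    "query": "What is DB-GPT?",
    "texts": [
        "DB-GPT is an AI-native data app development framework.",
        "TEI serves text embedding and reranking models.",
    ],
}
example_response = [
    {"index": 1, "score": 0.12},  # illustrative scores
    {"index": 0, "score": 0.97},
]
# After sorting by index, the parsed scores would be [0.97, 0.12].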
@@ -392,6 +392,107 @@ class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
        return scores


@dataclass
class TeiEmbeddingsParameters(OpenAPIRerankerDeployModelParameters):
    """Text Embeddings Inference Rerank Embeddings Parameters."""

    provider: str = "proxy/tei"

    api_url: str = field(
        default="http://localhost:8001/rerank",
        metadata={
            "help": _("The URL of the rerank API."),
        },
    )
    api_key: Optional[str] = field(
        default=None,
        metadata={
            "help": _("The API key for the rerank API."),
        },
    )


class TeiRerankEmbeddings(OpenAPIRerankEmbeddings):
    """Text Embeddings Inference Rerank Model.

    See `Text Embeddings Inference API
    <https://huggingface.github.io/text-embeddings-inference/>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the TeiRerankEmbeddings."""
        # If the API key is not provided, try to get it from the environment
        if "api_key" not in kwargs:
            kwargs["api_key"] = os.getenv("TEI_API_KEY")

        if "api_url" not in kwargs:
            raise ValueError("Please provide the api_url param")

        super().__init__(**kwargs)

    @classmethod
    def param_class(cls) -> Type[TeiEmbeddingsParameters]:
        """Get the parameter class."""
        return TeiEmbeddingsParameters

    def _parse_results(self, response: List[Dict[str, Any]]) -> List[float]:
        """Parse the response from the API.

        Args:
            response: The response from the API.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not isinstance(response, list) or len(response) == 0:
            raise RuntimeError("Results should be a non-empty list")

        # Sort by index so the scores align with the original candidate order
        results = sorted(response, key=lambda x: x.get("index", 0))
        scores = [float(result.get("score")) for result in results]
        return scores

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not candidates:
            return []
        headers = {}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        data = {"query": query, "texts": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout, headers=headers
        )
        response.raise_for_status()
        return self._parse_results(response.json())

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            data = {"query": query, "texts": candidates}
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                return self._parse_results(response_data)


register_embedding_adapter(
    CrossEncoderRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
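For reference, a minimal usage sketch of the synchronous path, assuming a TEI reranker is serving at the default URL from TeiEmbeddingsParameters; the import path is illustrative and may differ from the actual module layout:

# Hypothetical usage sketch; the import path is illustrative.
from dbgpt.rag.embedding.rerank import TeiRerankEmbeddings

# api_key is optional: if omitted, __init__ falls back to the TEI_API_KEY env var.
# Other base-class fields (e.g. timeout) keep their defaults in this sketch.
reranker = TeiRerankEmbeddings(api_url="http://localhost:8001/rerank")

query = "how to deploy a reranker with TEI"
candidates = [
    "TEI can serve sequence-classification models as rerankers.",
    "DB-GPT supports several proxy rerank providers.",
]
scores = reranker.predict(query, candidates)  # List[float], one score per candidate
print(list(zip(candidates, scores)))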
@@ -401,3 +502,6 @@ register_embedding_adapter(
register_embedding_adapter(
    SiliconFlowRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)

register_embedding_adapter(
    TeiRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
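The asynchronous path follows the same call shape; a short sketch under the same assumptions:

# Hypothetical async usage sketch (same illustrative import path as above).
import asyncio

from dbgpt.rag.embedding.rerank import TeiRerankEmbeddings


async def main() -> None:
    reranker = TeiRerankEmbeddings(api_url="http://localhost:8001/rerank")
    scores = await reranker.apredict(
        "how to deploy a reranker with TEI",
        ["candidate document one", "candidate document two"],
    )
    print(scores)


asyncio.run(main())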