feat(model): Support siliconflow rerank models (#2188)

This commit is contained in:
Fangyin Cheng
2024-12-11 18:36:44 +08:00
committed by GitHub
parent 23aedea092
commit abab4e3e65
5 changed files with 96 additions and 7 deletions

View File

@@ -17,7 +17,11 @@ from .embeddings import ( # noqa: F401
QianFanEmbeddings,
TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
from .rerank import ( # noqa: F401
CrossEncoderRerankEmbeddings,
OpenAPIRerankEmbeddings,
SiliconFlowRerankEmbeddings,
)
__ALL__ = [
"CrossEncoderRerankEmbeddings",
@@ -32,6 +36,7 @@ __ALL__ = [
"OllamaEmbeddings",
"OpenAPIEmbeddings",
"OpenAPIRerankEmbeddings",
"SiliconFlowRerankEmbeddings",
"QianFanEmbeddings",
"TongYiEmbeddings",
"WrappedEmbeddingFactory",

View File

@@ -1,5 +1,6 @@
"""Re-rank embeddings."""
import os
from typing import Any, Dict, List, Optional, cast
import aiohttp
@@ -104,6 +105,24 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
kwargs["session"] = session
super().__init__(**kwargs)
def _parse_results(self, response: Dict[str, Any]) -> List[float]:
"""Parse the response from the API.
Args:
response: The response from the API.
Returns:
List[float]: The rank scores of the candidates.
"""
data = response.get("data")
if not data:
if "detail" in response:
raise RuntimeError(response["detail"])
raise RuntimeError("Cannot find results in the response")
if not isinstance(data, list):
raise RuntimeError("Results should be a list")
return data
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates.
@@ -126,7 +145,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
self.api_url, json=data, timeout=self.timeout, headers=headers
)
response.raise_for_status()
return response.json()["data"]
return self._parse_results(response.json())
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates asynchronously."""
@@ -142,6 +161,50 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
async with session.post(self.api_url, json=data) as resp:
resp.raise_for_status()
response_data = await resp.json()
if "data" not in response_data:
raise RuntimeError(response_data["detail"])
return response_data["data"]
return self._parse_results(response_data)
class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
    """SiliconFlow Rerank Model.

    See `SiliconFlow API
    <https://docs.siliconflow.cn/api-reference/rerank/create-rerank>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the SiliconFlowRerankEmbeddings.

        Keyword arguments are forwarded to the parent constructor. The
        following ones are filled in when the caller does not pass them:

        - ``api_key``: value of the ``SILICON_FLOW_API_KEY`` environment
          variable (``None`` when it is not set).
        - ``api_url``: ``SILICON_FLOW_API_BASE`` with trailing slashes
          removed and ``/rerank`` appended, or
          ``https://api.siliconflow.cn/v1/rerank`` when that variable is
          unset or empty.
        - ``model_name``: ``BAAI/bge-reranker-v2-m3``.
        """
        # If the API key is not provided, try to get it from the environment
        if "api_key" not in kwargs:
            kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")

        if "api_url" not in kwargs:
            env_api_url = os.getenv("SILICON_FLOW_API_BASE")
            if env_api_url:
                # Drop trailing slashes so the joined URL has exactly one.
                env_api_url = env_api_url.rstrip("/")
                kwargs["api_url"] = env_api_url + "/rerank"
            else:
                kwargs["api_url"] = "https://api.siliconflow.cn/v1/rerank"

        if "model_name" not in kwargs:
            kwargs["model_name"] = "BAAI/bge-reranker-v2-m3"

        super().__init__(**kwargs)

    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
        """Parse the response from the API.

        Args:
            response: The decoded JSON body returned by the rerank endpoint.

        Returns:
            List[float]: The ``relevance_score`` of every entry in
                ``results``, ordered by the entry's ``index`` field.

        Raises:
            RuntimeError: If ``results`` is missing or empty, is not a list,
                or contains an entry without a usable ``relevance_score``.
        """
        results = response.get("results")
        if not results:
            raise RuntimeError("Cannot find results in the response")
        if not isinstance(results, list):
            raise RuntimeError("Results should be a list")

        try:
            # Sort by index, 0 in the first element. Entries without an
            # ``index`` are treated as index 0.
            # NOTE(review): presumably ``index`` is the position of the
            # document in the request -- confirm against the API reference.
            ordered = sorted(results, key=lambda item: item.get("index", 0))
            return [float(item["relevance_score"]) for item in ordered]
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            # A malformed entry (not a dict, score missing/null/non-numeric)
            # is reported like every other bad response instead of leaking
            # a bare TypeError/AttributeError to the caller.
            raise RuntimeError(
                "Invalid result entry in the response: " + repr(e)
            ) from e