mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
feat(model): Support siliconflow rerank models (#2188)
This commit is contained in:
@@ -17,7 +17,11 @@ from .embeddings import ( # noqa: F401
|
||||
QianFanEmbeddings,
|
||||
TongYiEmbeddings,
|
||||
)
|
||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||
from .rerank import ( # noqa: F401
|
||||
CrossEncoderRerankEmbeddings,
|
||||
OpenAPIRerankEmbeddings,
|
||||
SiliconFlowRerankEmbeddings,
|
||||
)
|
||||
|
||||
__ALL__ = [
|
||||
"CrossEncoderRerankEmbeddings",
|
||||
@@ -32,6 +36,7 @@ __ALL__ = [
|
||||
"OllamaEmbeddings",
|
||||
"OpenAPIEmbeddings",
|
||||
"OpenAPIRerankEmbeddings",
|
||||
"SiliconFlowRerankEmbeddings",
|
||||
"QianFanEmbeddings",
|
||||
"TongYiEmbeddings",
|
||||
"WrappedEmbeddingFactory",
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""Re-rank embeddings."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import aiohttp
|
||||
@@ -104,6 +105,24 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
kwargs["session"] = session
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _parse_results(self, response: Dict[str, Any]) -> List[float]:
|
||||
"""Parse the response from the API.
|
||||
|
||||
Args:
|
||||
response: The response from the API.
|
||||
|
||||
Returns:
|
||||
List[float]: The rank scores of the candidates.
|
||||
"""
|
||||
data = response.get("data")
|
||||
if not data:
|
||||
if "detail" in response:
|
||||
raise RuntimeError(response["detail"])
|
||||
raise RuntimeError("Cannot find results in the response")
|
||||
if not isinstance(data, list):
|
||||
raise RuntimeError("Results should be a list")
|
||||
return data
|
||||
|
||||
def predict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates.
|
||||
|
||||
@@ -126,7 +145,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
self.api_url, json=data, timeout=self.timeout, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["data"]
|
||||
return self._parse_results(response.json())
|
||||
|
||||
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates asynchronously."""
|
||||
@@ -142,6 +161,50 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
async with session.post(self.api_url, json=data) as resp:
|
||||
resp.raise_for_status()
|
||||
response_data = await resp.json()
|
||||
if "data" not in response_data:
|
||||
raise RuntimeError(response_data["detail"])
|
||||
return response_data["data"]
|
||||
return self._parse_results(response_data)
|
||||
|
||||
|
||||
class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
    """SiliconFlow Rerank Model.

    See `SiliconFlow API
    <https://docs.siliconflow.cn/api-reference/rerank/create-rerank>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the SiliconFlowRerankEmbeddings.

        Any of ``api_key``, ``api_url`` and ``model_name`` that the caller
        does not pass is filled in before delegating to the parent
        initializer:

        * ``api_key`` is read from the ``SILICON_FLOW_API_KEY`` environment
          variable.
        * ``api_url`` is ``SILICON_FLOW_API_BASE`` (trailing slashes removed)
          with ``/rerank`` appended, or the public SiliconFlow endpoint when
          that variable is unset or empty.
        * ``model_name`` defaults to ``BAAI/bge-reranker-v2-m3``.

        Args:
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        # If the API key is not provided, try to get it from the environment
        if "api_key" not in kwargs:
            kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")

        if "api_url" not in kwargs:
            env_api_url = os.getenv("SILICON_FLOW_API_BASE")
            if env_api_url:
                # Strip trailing slashes so the joined URL has exactly one
                # separator before "rerank".
                kwargs["api_url"] = env_api_url.rstrip("/") + "/rerank"
            else:
                kwargs["api_url"] = "https://api.siliconflow.cn/v1/rerank"

        if "model_name" not in kwargs:
            kwargs["model_name"] = "BAAI/bge-reranker-v2-m3"

        super().__init__(**kwargs)

    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
        """Parse the response from the API.

        Reads the ``"results"`` list from the response, orders its entries
        by their ``"index"`` field and returns each entry's
        ``"relevance_score"`` as a float.

        Args:
            response: The response from the API.

        Returns:
            List[float]: The rank scores of the candidates.

        Raises:
            RuntimeError: If ``"results"`` is missing, empty or not a list,
                or if an entry carries no ``"relevance_score"``.
        """
        results = response.get("results")
        if not results:
            raise RuntimeError("Cannot find results in the response")
        if not isinstance(results, list):
            raise RuntimeError("Results should be a list")

        scores: List[float] = []
        # Sort by index, 0 in the first element; entries without an index
        # are treated as index 0.
        for result in sorted(results, key=lambda item: item.get("index", 0)):
            score = result.get("relevance_score")
            if score is None:
                # Report a malformed entry the same way as the other
                # malformed-response cases, instead of letting float(None)
                # surface as an opaque TypeError.
                raise RuntimeError(
                    "Cannot find relevance_score in the rerank result"
                )
            scores.append(float(score))
        return scores
|
||||
|
Reference in New Issue
Block a user