mirror of https://github.com/csunny/DB-GPT.git
feat(model): Support siliconflow rerank models (#2188)
commit abab4e3e65
parent 23aedea092
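
At a glance, this commit registers a new proxy rerank model name, wires it through the embedding loader and parameter config, and adds a SiliconFlowRerankEmbeddings class. Before the hunks, a minimal usage sketch of the new class (a sketch, not project documentation: it assumes the SILICON_FLOW_API_KEY environment variable is set, so api_url and model_name fall back to the defaults introduced in the diff):

from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings

# Assumes SILICON_FLOW_API_KEY is exported; api_url then defaults to
# https://api.siliconflow.cn/v1/rerank and model_name to BAAI/bge-reranker-v2-m3.
reranker = SiliconFlowRerankEmbeddings()
scores = reranker.predict(
    query="What is DB-GPT?",
    candidates=[
        "DB-GPT is an AI-native data app development framework.",
        "SiliconFlow hosts rerank models such as BAAI/bge-reranker-v2-m3.",
    ],
)
print(scores)  # one relevance score per candidate, as returned by the API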
@@ -322,6 +322,7 @@ EMBEDDING_MODEL_CONFIG = {
     "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
     # Proxy rerank model
     "rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
+    "rerank_proxy_silicon_flow": "rerank_proxy_silicon_flow",
 }
@@ -109,6 +109,18 @@ class EmbeddingLoader:
             if proxy_param.proxy_backend:
                 openapi_param["model_name"] = proxy_param.proxy_backend
             return OpenAPIRerankEmbeddings(**openapi_param)
+        elif model_name in ["rerank_proxy_silicon_flow"]:
+            from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings
+
+            proxy_param = cast(ProxyEmbeddingParameters, param)
+            openapi_param = {}
+            if proxy_param.proxy_server_url:
+                openapi_param["api_url"] = proxy_param.proxy_server_url
+            if proxy_param.proxy_api_key:
+                openapi_param["api_key"] = proxy_param.proxy_api_key
+            if proxy_param.proxy_backend:
+                openapi_param["model_name"] = proxy_param.proxy_backend
+            return SiliconFlowRerankEmbeddings(**openapi_param)
         else:
             from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
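
Illustrative only: how the new loader branch maps the ProxyEmbeddingParameters fields onto the reranker's constructor kwargs. The values below are example settings such a parameter object might carry (the key is a placeholder); any omitted key falls back to the class defaults shown later in the diff.

from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings

openapi_param = {
    "api_url": "https://api.siliconflow.cn/v1/rerank",  # from proxy_param.proxy_server_url
    "api_key": "sk-placeholder",  # from proxy_param.proxy_api_key
    "model_name": "BAAI/bge-reranker-v2-m3",  # from proxy_param.proxy_backend
}
reranker = SiliconFlowRerankEmbeddings(**openapi_param)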
@@ -613,7 +613,16 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):


 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
+    ProxyEmbeddingParameters: [
+        "proxy_openai",
+        "proxy_azure",
+        "proxy_http_openapi",
+        "proxy_ollama",
+        "proxy_tongyi",
+        "proxy_qianfan",
+        "rerank_proxy_http_openapi",
+        "rerank_proxy_silicon_flow",
+    ]
 }

 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
@@ -622,7 +631,6 @@ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
 def _update_embedding_config():
     global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
     for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
-        models = [m.strip() for m in models.split(",")]
         for model in models:
             if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
                 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
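
The comma-separated string was replaced with a list, so _update_embedding_config no longer needs to split it. A self-contained mini-version of that registration logic (a sketch using a stand-in class, not the project module):

class ProxyEmbeddingParameters:  # stand-in for the real parameter class
    pass

_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: [
        "proxy_openai",
        "rerank_proxy_http_openapi",
        "rerank_proxy_silicon_flow",
    ]
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
    for model in models:  # values are already lists, no .split(",") needed
        if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls

assert EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG["rerank_proxy_silicon_flow"] is ProxyEmbeddingParameters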
@@ -17,7 +17,11 @@ from .embeddings import ( # noqa: F401
     QianFanEmbeddings,
     TongYiEmbeddings,
 )
-from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
+from .rerank import ( # noqa: F401
+    CrossEncoderRerankEmbeddings,
+    OpenAPIRerankEmbeddings,
+    SiliconFlowRerankEmbeddings,
+)

 __ALL__ = [
     "CrossEncoderRerankEmbeddings",
@@ -32,6 +36,7 @@ __ALL__ = [
     "OllamaEmbeddings",
     "OpenAPIEmbeddings",
     "OpenAPIRerankEmbeddings",
+    "SiliconFlowRerankEmbeddings",
     "QianFanEmbeddings",
     "TongYiEmbeddings",
     "WrappedEmbeddingFactory",
@@ -1,5 +1,6 @@
 """Re-rank embeddings."""

+import os
 from typing import Any, Dict, List, Optional, cast

 import aiohttp
@@ -104,6 +105,24 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
             kwargs["session"] = session
         super().__init__(**kwargs)

+    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
+        """Parse the response from the API.
+
+        Args:
+            response: The response from the API.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        data = response.get("data")
+        if not data:
+            if "detail" in response:
+                raise RuntimeError(response["detail"])
+            raise RuntimeError("Cannot find results in the response")
+        if not isinstance(data, list):
+            raise RuntimeError("Results should be a list")
+        return data
+
     def predict(self, query: str, candidates: List[str]) -> List[float]:
         """Predict the rank scores of the candidates.

@@ -126,7 +145,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
             self.api_url, json=data, timeout=self.timeout, headers=headers
         )
         response.raise_for_status()
-        return response.json()["data"]
+        return self._parse_results(response.json())

     async def apredict(self, query: str, candidates: List[str]) -> List[float]:
         """Predict the rank scores of the candidates asynchronously."""
@@ -142,6 +161,50 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
             async with session.post(self.api_url, json=data) as resp:
                 resp.raise_for_status()
                 response_data = await resp.json()
-                if "data" not in response_data:
-                    raise RuntimeError(response_data["detail"])
-                return response_data["data"]
+                return self._parse_results(response_data)
+
+
+class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
+    """SiliconFlow Rerank Model.
+
+    See `SiliconFlow API
+    <https://docs.siliconflow.cn/api-reference/rerank/create-rerank>`_ for more details.
+    """
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the SiliconFlowRerankEmbeddings."""
+        # If the API key is not provided, try to get it from the environment
+        if "api_key" not in kwargs:
+            kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")
+
+        if "api_url" not in kwargs:
+            env_api_url = os.getenv("SILICON_FLOW_API_BASE")
+            if env_api_url:
+                env_api_url = env_api_url.rstrip("/")
+                kwargs["api_url"] = env_api_url + "/rerank"
+            else:
+                kwargs["api_url"] = "https://api.siliconflow.cn/v1/rerank"
+
+        if "model_name" not in kwargs:
+            kwargs["model_name"] = "BAAI/bge-reranker-v2-m3"
+
+        super().__init__(**kwargs)
+
+    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
+        """Parse the response from the API.
+
+        Args:
+            response: The response from the API.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        results = response.get("results")
+        if not results:
+            raise RuntimeError("Cannot find results in the response")
+        if not isinstance(results, list):
+            raise RuntimeError("Results should be a list")
+        # Sort by index, 0 in the first element
+        results = sorted(results, key=lambda x: x.get("index", 0))
+        scores = [float(result.get("relevance_score")) for result in results]
+        return scores
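
For reference, a hedged sketch of the response shape that SiliconFlowRerankEmbeddings._parse_results expects, based on the parsing logic above (the indexes and scores are made up; only parsing is exercised, no request is sent):

from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings

reranker = SiliconFlowRerankEmbeddings(api_key="sk-placeholder")
sample_response = {
    "results": [
        {"index": 1, "relevance_score": 0.12},
        {"index": 0, "relevance_score": 0.87},
    ]
}
print(reranker._parse_results(sample_response))  # [0.87, 0.12]

Sorting by index before extracting relevance_score keeps score i aligned with candidate i even if the API returns results ordered by relevance.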