diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index 2011f53e5..a677df1be 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -322,6 +322,7 @@ EMBEDDING_MODEL_CONFIG = {
     "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
     # Proxy rerank model
     "rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
+    "rerank_proxy_silicon_flow": "rerank_proxy_silicon_flow",
 }


diff --git a/dbgpt/model/adapter/embeddings_loader.py b/dbgpt/model/adapter/embeddings_loader.py
index 892469962..d26df89d5 100644
--- a/dbgpt/model/adapter/embeddings_loader.py
+++ b/dbgpt/model/adapter/embeddings_loader.py
@@ -109,6 +109,18 @@ class EmbeddingLoader:
                 if proxy_param.proxy_backend:
                     openapi_param["model_name"] = proxy_param.proxy_backend
                 return OpenAPIRerankEmbeddings(**openapi_param)
+            elif model_name in ["rerank_proxy_silicon_flow"]:
+                from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings
+
+                proxy_param = cast(ProxyEmbeddingParameters, param)
+                openapi_param = {}
+                if proxy_param.proxy_server_url:
+                    openapi_param["api_url"] = proxy_param.proxy_server_url
+                if proxy_param.proxy_api_key:
+                    openapi_param["api_key"] = proxy_param.proxy_api_key
+                if proxy_param.proxy_backend:
+                    openapi_param["model_name"] = proxy_param.proxy_backend
+                return SiliconFlowRerankEmbeddings(**openapi_param)
             else:
                 from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index 065f74124..7ed7e5b3a 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -613,7 +613,16 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):


 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
+    ProxyEmbeddingParameters: [
+        "proxy_openai",
+        "proxy_azure",
+        "proxy_http_openapi",
+        "proxy_ollama",
+        "proxy_tongyi",
+        "proxy_qianfan",
+        "rerank_proxy_http_openapi",
+        "rerank_proxy_silicon_flow",
+    ]
 }

 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
@@ -622,7 +631,6 @@ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
 def _update_embedding_config():
     global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
     for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
-        models = [m.strip() for m in models.split(",")]
         for model in models:
             if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
                 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py
index 435769098..70c4515be 100644
--- a/dbgpt/rag/embedding/__init__.py
+++ b/dbgpt/rag/embedding/__init__.py
@@ -17,7 +17,11 @@ from .embeddings import (  # noqa: F401
     QianFanEmbeddings,
     TongYiEmbeddings,
 )
-from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings  # noqa: F401
+from .rerank import (  # noqa: F401
+    CrossEncoderRerankEmbeddings,
+    OpenAPIRerankEmbeddings,
+    SiliconFlowRerankEmbeddings,
+)

 __ALL__ = [
     "CrossEncoderRerankEmbeddings",
@@ -32,6 +36,7 @@ __ALL__ = [
     "OllamaEmbeddings",
     "OpenAPIEmbeddings",
     "OpenAPIRerankEmbeddings",
+    "SiliconFlowRerankEmbeddings",
     "QianFanEmbeddings",
     "TongYiEmbeddings",
     "WrappedEmbeddingFactory",
diff --git a/dbgpt/rag/embedding/rerank.py b/dbgpt/rag/embedding/rerank.py
index b797d0fa7..61150e17b 100644
--- a/dbgpt/rag/embedding/rerank.py
+++ b/dbgpt/rag/embedding/rerank.py
@@ -1,5 +1,6 @@
 """Re-rank embeddings."""

+import os
 from typing import Any, Dict, List, Optional, cast

 import aiohttp
@@ -104,6 +105,24 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
         kwargs["session"] = session
         super().__init__(**kwargs)

+    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
+        """Parse the response from the API.
+
+        Args:
+            response: The response from the API.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        data = response.get("data")
+        if not data:
+            if "detail" in response:
+                raise RuntimeError(response["detail"])
+            raise RuntimeError("Cannot find results in the response")
+        if not isinstance(data, list):
+            raise RuntimeError("Results should be a list")
+        return data
+
     def predict(self, query: str, candidates: List[str]) -> List[float]:
         """Predict the rank scores of the candidates.

@@ -126,7 +145,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
             self.api_url, json=data, timeout=self.timeout, headers=headers
         )
         response.raise_for_status()
-        return response.json()["data"]
+        return self._parse_results(response.json())

     async def apredict(self, query: str, candidates: List[str]) -> List[float]:
         """Predict the rank scores of the candidates asynchronously."""
@@ -142,6 +161,50 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
             async with session.post(self.api_url, json=data) as resp:
                 resp.raise_for_status()
                 response_data = await resp.json()
-                if "data" not in response_data:
-                    raise RuntimeError(response_data["detail"])
-                return response_data["data"]
+                return self._parse_results(response_data)
+
+
+class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
+    """SiliconFlow Rerank Model.
+
+    See `SiliconFlow API
+    `_ for more details.
+    """
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the SiliconFlowRerankEmbeddings."""
+        # If the API key is not provided, try to get it from the environment
+        if "api_key" not in kwargs:
+            kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")
+
+        if "api_url" not in kwargs:
+            env_api_url = os.getenv("SILICON_FLOW_API_BASE")
+            if env_api_url:
+                env_api_url = env_api_url.rstrip("/")
+                kwargs["api_url"] = env_api_url + "/rerank"
+            else:
+                kwargs["api_url"] = "https://api.siliconflow.cn/v1/rerank"
+
+        if "model_name" not in kwargs:
+            kwargs["model_name"] = "BAAI/bge-reranker-v2-m3"
+
+        super().__init__(**kwargs)
+
+    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
+        """Parse the response from the API.
+
+        Args:
+            response: The response from the API.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        results = response.get("results")
+        if not results:
+            raise RuntimeError("Cannot find results in the response")
+        if not isinstance(results, list):
+            raise RuntimeError("Results should be a list")
+        # Sort by index, 0 in the first element
+        results = sorted(results, key=lambda x: x.get("index", 0))
+        scores = [float(result.get("relevance_score")) for result in results]
+        return scores
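
For reference, a minimal usage sketch of the new reranker (illustrative, not part of the patch). It assumes the SILICON_FLOW_API_KEY environment variable is set; api_url and model_name fall back to the defaults added in rerank.py (https://api.siliconflow.cn/v1/rerank and BAAI/bge-reranker-v2-m3). The query and candidate strings are made up for the example.

    from dbgpt.rag.embedding import SiliconFlowRerankEmbeddings

    # Reads SILICON_FLOW_API_KEY / SILICON_FLOW_API_BASE from the environment;
    # api_key, api_url and model_name can also be passed explicitly as kwargs.
    reranker = SiliconFlowRerankEmbeddings()

    # predict() sends the query and candidates to the configured rerank endpoint
    # and returns one relevance score per candidate; the SiliconFlow results are
    # sorted back by their "index" field, so scores follow the candidate order.
    scores = reranker.predict(
        query="What is DB-GPT?",
        candidates=[
            "DB-GPT is an AI native data app development framework.",
            "An unrelated sentence about the weather.",
        ],
    )
    print(scores)

When loading through EmbeddingLoader instead, the model name rerank_proxy_silicon_flow resolves to ProxyEmbeddingParameters, and proxy_server_url, proxy_api_key and proxy_backend are forwarded as api_url, api_key and model_name respectively.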