mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat(model): Support qianfan embedding and resolves an issue (#1819)
This commit is contained in:
@@ -97,6 +97,14 @@ KNOWLEDGE_SEARCH_REWRITE=False
|
|||||||
## qwen embedding model, See dbgpt/model/parameter.py
|
## qwen embedding model, See dbgpt/model/parameter.py
|
||||||
# EMBEDDING_MODEL=proxy_tongyi
|
# EMBEDDING_MODEL=proxy_tongyi
|
||||||
# proxy_tongyi_proxy_backend=text-embedding-v1
|
# proxy_tongyi_proxy_backend=text-embedding-v1
|
||||||
|
# proxy_tongyi_proxy_api_key={your-api-key}
|
||||||
|
|
||||||
|
## qianfan embedding model, See dbgpt/model/parameter.py
|
||||||
|
# EMBEDDING_MODEL=proxy_qianfan
|
||||||
|
# proxy_qianfan_proxy_backend=bge-large-zh
|
||||||
|
# proxy_qianfan_proxy_api_key={your-api-key}
|
||||||
|
# proxy_qianfan_proxy_api_secret={your-secret-key}
|
||||||
|
|
||||||
|
|
||||||
## Common HTTP embedding model
|
## Common HTTP embedding model
|
||||||
# EMBEDDING_MODEL=proxy_http_openapi
|
# EMBEDDING_MODEL=proxy_http_openapi
|
||||||
@@ -309,4 +317,4 @@ DBGPT_LOG_LEVEL=INFO
|
|||||||
#*******************************************************************#
|
#*******************************************************************#
|
||||||
#** FINANCIAL CHAT Config **#
|
#** FINANCIAL CHAT Config **#
|
||||||
#*******************************************************************#
|
#*******************************************************************#
|
||||||
# FIN_REPORT_MODEL=/app/models/bge-large-zh
|
# FIN_REPORT_MODEL=/app/models/bge-large-zh
|
||||||
|
@@ -290,6 +290,7 @@ EMBEDDING_MODEL_CONFIG = {
|
|||||||
"proxy_http_openapi": "proxy_http_openapi",
|
"proxy_http_openapi": "proxy_http_openapi",
|
||||||
"proxy_ollama": "proxy_ollama",
|
"proxy_ollama": "proxy_ollama",
|
||||||
"proxy_tongyi": "proxy_tongyi",
|
"proxy_tongyi": "proxy_tongyi",
|
||||||
|
"proxy_qianfan": "proxy_qianfan",
|
||||||
# Rerank model, rerank mode is a special embedding model
|
# Rerank model, rerank mode is a special embedding model
|
||||||
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
|
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
|
||||||
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
||||||
|
@@ -58,6 +58,14 @@ class EmbeddingLoader:
|
|||||||
if proxy_param.proxy_backend:
|
if proxy_param.proxy_backend:
|
||||||
tongyi_param["model_name"] = proxy_param.proxy_backend
|
tongyi_param["model_name"] = proxy_param.proxy_backend
|
||||||
return TongYiEmbeddings(**tongyi_param)
|
return TongYiEmbeddings(**tongyi_param)
|
||||||
|
elif model_name in ["proxy_qianfan"]:
|
||||||
|
from dbgpt.rag.embedding import QianFanEmbeddings
|
||||||
|
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||||
|
qianfan_param = {"api_key": proxy_param.proxy_api_key}
|
||||||
|
if proxy_param.proxy_backend:
|
||||||
|
qianfan_param["model_name"] = proxy_param.proxy_backend
|
||||||
|
qianfan_param["api_secret"] = proxy_param.proxy_api_secret
|
||||||
|
return QianFanEmbeddings(**qianfan_param)
|
||||||
elif model_name in ["proxy_ollama"]:
|
elif model_name in ["proxy_ollama"]:
|
||||||
from dbgpt.rag.embedding import OllamaEmbeddings
|
from dbgpt.rag.embedding import OllamaEmbeddings
|
||||||
|
|
||||||
|
@@ -558,6 +558,13 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
|||||||
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
|
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
proxy_api_secret: str = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"tags": "privacy",
|
||||||
|
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
|
||||||
|
}
|
||||||
|
)
|
||||||
proxy_api_version: Optional[str] = field(
|
proxy_api_version: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -603,7 +610,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
|||||||
|
|
||||||
|
|
||||||
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
||||||
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
|
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||||
|
@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
|
|||||||
OllamaEmbeddings,
|
OllamaEmbeddings,
|
||||||
OpenAPIEmbeddings,
|
OpenAPIEmbeddings,
|
||||||
TongYiEmbeddings,
|
TongYiEmbeddings,
|
||||||
|
QianFanEmbeddings,
|
||||||
)
|
)
|
||||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||||
|
|
||||||
@@ -33,4 +34,5 @@ __ALL__ = [
|
|||||||
"TongYiEmbeddings",
|
"TongYiEmbeddings",
|
||||||
"CrossEncoderRerankEmbeddings",
|
"CrossEncoderRerankEmbeddings",
|
||||||
"OpenAPIRerankEmbeddings",
|
"OpenAPIRerankEmbeddings",
|
||||||
|
"QianFanEmbeddings"
|
||||||
]
|
]
|
||||||
|
@@ -922,3 +922,100 @@ class TongYiEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
return self.embed_documents([text])[0]
|
return self.embed_documents([text])[0]
|
||||||
|
|
||||||
|
|
||||||
|
class QianFanEmbeddings(BaseModel, Embeddings):
    """Baidu Qianfan embedding model client.

    Wraps a ``qianfan.Embedding`` SDK client behind the common
    ``Embeddings`` interface.

    Embed:
        .. code-block:: python

            # embed the documents
            vectors = embeddings.embed_documents([text1, text2, ...])

            # embed the query
            vectors = embeddings.embed_query(text)

    """  # noqa: E501

    # Underlying qianfan.Embedding client, created in __init__.
    client: Any
    # Maximum number of texts sent to the API per request.
    chunk_size: int = 16
    # Endpoint of the Qianfan Embedding, required if a custom model is used.
    endpoint: str = ""

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    api_secret: Optional[str] = Field(
        default=None, description="The Secret key for the embeddings API."
    )
    # Model name; see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
    # Supported presets include Embedding-V1 (the default family),
    # bge-large-en and bge-large-zh. Preset models map to an endpoint, and
    # `model_name` is ignored if `endpoint` is set.
    model_name: str = Field(
        default="text-embedding-v1", description="The name of the model to use."
    )
    # Init kwargs for qianfan client init, such as `query_per_second` which is
    # associated with the qianfan resource object to limit QPS.
    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
    # Extra params for model invocation using `do`.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    def __init__(self, **kwargs):
        """Initialize the QianFanEmbeddings.

        Args:
            **kwargs: Field values; `api_key` and `api_secret` are required,
                `model_name` falls back to the field default when omitted.

        Raises:
            ValueError: If the `qianfan` package is not installed, or the
                API key / API secret is missing.
        """
        try:
            import qianfan
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: qianfan. "
                "Please install qianfan by running `pip install qianfan`."
            ) from exc

        qianfan_ak = kwargs.get("api_key")
        qianfan_sk = kwargs.get("api_secret")
        # Fall back to the declared field default so callers may omit the
        # model name (the loader only passes it when a backend is configured).
        model_name = kwargs.get("model_name") or "text-embedding-v1"

        if not qianfan_ak or not qianfan_sk:
            raise ValueError(
                "API key and API secret are required to initialize "
                "QianFanEmbeddings."
            )

        params = {
            "model": model_name,
            "ak": qianfan_ak,
            "sk": qianfan_sk,
        }

        # Initialize the qianfan.Embedding client
        kwargs["client"] = qianfan.Embedding(**params)
        super().__init__(**kwargs)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: Embedding for the text.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents with the Qianfan embedding API.

        Texts are sent in batches of at most `chunk_size` items because the
        API limits the number of inputs per request.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the
                input list. Each embedding is represented as a list of float
                values.
        """
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        embeddings: List[List[float]] = []
        for chunk in text_in_chunks:
            resp = self.client.do(texts=chunk, **self.model_kwargs)
            embeddings.extend([res["embedding"] for res in resp["data"]])
        return embeddings
|
1
setup.py
1
setup.py
@@ -687,6 +687,7 @@ def default_requires():
|
|||||||
"chardet",
|
"chardet",
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
"ollama",
|
"ollama",
|
||||||
|
"qianfan"
|
||||||
]
|
]
|
||||||
setup_spec.extras["default"] += setup_spec.extras["framework"]
|
setup_spec.extras["default"] += setup_spec.extras["framework"]
|
||||||
setup_spec.extras["default"] += setup_spec.extras["rag"]
|
setup_spec.extras["default"] += setup_spec.extras["rag"]
|
||||||
|
Reference in New Issue
Block a user