mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat(model): Support qianfan embedding and resolves an issue (#1819)
This commit is contained in:
@@ -97,6 +97,14 @@ KNOWLEDGE_SEARCH_REWRITE=False
|
||||
## qwen embedding model, See dbgpt/model/parameter.py
|
||||
# EMBEDDING_MODEL=proxy_tongyi
|
||||
# proxy_tongyi_proxy_backend=text-embedding-v1
|
||||
# proxy_tongyi_proxy_api_key={your-api-key}
|
||||
|
||||
## qianfan embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_qianfan
# proxy_qianfan_proxy_backend=bge-large-zh
# proxy_qianfan_proxy_api_key={your-api-key}
# proxy_qianfan_proxy_api_secret={your-secret-key}
|
||||
|
||||
|
||||
## Common HTTP embedding model
|
||||
# EMBEDDING_MODEL=proxy_http_openapi
|
||||
|
@@ -290,6 +290,7 @@ EMBEDDING_MODEL_CONFIG = {
|
||||
"proxy_http_openapi": "proxy_http_openapi",
|
||||
"proxy_ollama": "proxy_ollama",
|
||||
"proxy_tongyi": "proxy_tongyi",
|
||||
"proxy_qianfan": "proxy_qianfan",
|
||||
# Rerank model, rerank mode is a special embedding model
|
||||
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
|
||||
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
||||
|
@@ -58,6 +58,14 @@ class EmbeddingLoader:
|
||||
if proxy_param.proxy_backend:
|
||||
tongyi_param["model_name"] = proxy_param.proxy_backend
|
||||
return TongYiEmbeddings(**tongyi_param)
|
||||
elif model_name in ["proxy_qianfan"]:
|
||||
from dbgpt.rag.embedding import QianFanEmbeddings
|
||||
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||
qianfan_param = {"api_key": proxy_param.proxy_api_key}
|
||||
if proxy_param.proxy_backend:
|
||||
qianfan_param["model_name"] = proxy_param.proxy_backend
|
||||
qianfan_param["api_secret"] = proxy_param.proxy_api_secret
|
||||
return QianFanEmbeddings(**qianfan_param)
|
||||
elif model_name in ["proxy_ollama"]:
|
||||
from dbgpt.rag.embedding import OllamaEmbeddings
|
||||
|
||||
|
@@ -558,6 +558,13 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
|
||||
},
|
||||
)
|
||||
proxy_api_secret: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"tags": "privacy",
|
||||
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
|
||||
}
|
||||
)
|
||||
proxy_api_version: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@@ -603,7 +610,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
|
||||
|
||||
# Maps a parameter class to the comma-separated embedding model names that use
# it. `proxy_qianfan` is included so QianFanEmbeddings resolves to
# ProxyEmbeddingParameters like the other proxy embedding providers.
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
    "proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
}

# Reverse lookup (model name -> parameter class), populated at import time.
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||
|
@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
|
||||
OllamaEmbeddings,
|
||||
OpenAPIEmbeddings,
|
||||
TongYiEmbeddings,
|
||||
QianFanEmbeddings,
|
||||
)
|
||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||
|
||||
@@ -33,4 +34,5 @@ __ALL__ = [
|
||||
"TongYiEmbeddings",
|
||||
"CrossEncoderRerankEmbeddings",
|
||||
"OpenAPIRerankEmbeddings",
|
||||
"QianFanEmbeddings"
|
||||
]
|
||||
|
@@ -922,3 +922,100 @@ class TongYiEmbeddings(BaseModel, Embeddings):
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class QianFanEmbeddings(BaseModel, Embeddings):
    """Baidu Qianfan embedding model client.

    Wraps the ``qianfan.Embedding`` SDK client, created in ``__init__`` from
    ``api_key``/``api_secret``, to embed documents and queries.

    Embed:
        .. code-block:: python

            # embed the documents
            vectors = embeddings.embed_documents([text1, text2, ...])

            # embed the query
            vectors = embeddings.embed_query(text)

    """  # noqa: E501

    # Underlying qianfan.Embedding client; set by __init__ before validation.
    client: Any
    # Maximum number of texts sent to the API per request.
    chunk_size: int = 16
    # Endpoint of the Qianfan Embedding, required if a custom model is used.
    endpoint: str = ""
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    api_secret: Optional[str] = Field(
        default=None, description="The Secret key for the embeddings API."
    )
    # Model name — see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
    # Supported preset models include:
    #   - Embedding-V1 (the service default)
    #   - bge-large-en
    #   - bge-large-zh
    # Preset models map to an endpoint; `model` is ignored if `endpoint` is set.
    # NOTE(review): the default "text-embedding-v1" looks copied from the TongYi
    # embeddings and is not a documented Qianfan preset — confirm with the API.
    model_name: str = Field(
        default="text-embedding-v1", description="The name of the model to use."
    )
    # Init kwargs for qianfan client init, such as `query_per_second`, which is
    # associated with the qianfan resource object to limit QPS.
    # NOTE(review): declared but never passed to qianfan.Embedding below —
    # confirm whether it should be forwarded at client construction.
    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
    # Extra params for each model invocation via `do`.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    def __init__(self, **kwargs):
        """Initialize the QianFanEmbeddings.

        Builds the ``qianfan.Embedding`` client from the supplied credentials
        and model name, then hands everything to the pydantic constructor.

        Raises:
            ValueError: If the ``qianfan`` package is not installed, or if any
                of ``api_key``, ``api_secret`` or ``model_name`` is missing.
        """
        try:
            import qianfan
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: qianfan. "
                "Please install qianfan by running `pip install qianfan`."
            ) from exc

        qianfan_ak = kwargs.get("api_key")
        qianfan_sk = kwargs.get("api_secret")
        model_name = kwargs.get("model_name")

        # All three must be passed explicitly; the pydantic field defaults are
        # not consulted here because the client is built before super().__init__.
        if not qianfan_ak or not qianfan_sk or not model_name:
            raise ValueError(
                "API key, API secret, and model name are required to "
                "initialize QianFanEmbeddings."
            )

        params = {
            "model": model_name,
            "ak": qianfan_ak,
            "sk": qianfan_sk,
        }

        # Initialize the qianfan.Embedding client and stash it in kwargs so the
        # pydantic `client` field is populated.
        kwargs["client"] = qianfan.Embedding(**params)
        super().__init__(**kwargs)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding vector for the text.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents.

        Splits the input into chunks of at most ``chunk_size`` texts, sends
        each chunk to the Qianfan embedding API via ``client.do``, and
        concatenates the per-text embeddings in input order.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the
                input list. Each embedding is represented as a list of float
                values.
        """
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            chunk = texts[start : start + self.chunk_size]
            resp = self.client.do(texts=chunk, **self.model_kwargs)
            embeddings.extend(res["embedding"] for res in resp["data"])
        return embeddings
|
||||
|
Reference in New Issue
Block a user