feat(model): Support qianfan embedding and resolve an issue (#1819)

This commit is contained in:
dusens
2024-08-29 15:52:00 +08:00
committed by GitHub
parent 549f256524
commit 51b4327fdf
7 changed files with 126 additions and 2 deletions

View File

@@ -97,6 +97,14 @@ KNOWLEDGE_SEARCH_REWRITE=False
## qwen embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_tongyi
# proxy_tongyi_proxy_backend=text-embedding-v1
# proxy_tongyi_proxy_api_key={your-api-key}
## qianfan embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_qianfan
# proxy_qianfan_proxy_backend=bge-large-zh
# proxy_qianfan_proxy_api_key={your-api-key}
# proxy_qianfan_proxy_api_secret={your-secret-key}
## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi

View File

@@ -290,6 +290,7 @@ EMBEDDING_MODEL_CONFIG = {
"proxy_http_openapi": "proxy_http_openapi",
"proxy_ollama": "proxy_ollama",
"proxy_tongyi": "proxy_tongyi",
"proxy_qianfan": "proxy_qianfan",
# Rerank model, rerank mode is a special embedding model
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),

View File

@@ -58,6 +58,14 @@ class EmbeddingLoader:
if proxy_param.proxy_backend:
tongyi_param["model_name"] = proxy_param.proxy_backend
return TongYiEmbeddings(**tongyi_param)
elif model_name in ["proxy_qianfan"]:
from dbgpt.rag.embedding import QianFanEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
qianfan_param = {"api_key": proxy_param.proxy_api_key}
if proxy_param.proxy_backend:
qianfan_param["model_name"] = proxy_param.proxy_backend
qianfan_param["api_secret"] = proxy_param.proxy_api_secret
return QianFanEmbeddings(**qianfan_param)
elif model_name in ["proxy_ollama"]:
from dbgpt.rag.embedding import OllamaEmbeddings

View File

@@ -558,6 +558,13 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
},
)
proxy_api_secret: str = field(
default=None,
metadata={
"tags": "privacy",
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
}
)
proxy_api_version: Optional[str] = field(
default=None,
metadata={
@@ -603,7 +610,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}

View File

@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
OllamaEmbeddings,
OpenAPIEmbeddings,
TongYiEmbeddings,
QianFanEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
@@ -33,4 +34,5 @@ __ALL__ = [
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
"QianFanEmbeddings"
]

View File

@@ -922,3 +922,100 @@ class TongYiEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
return self.embed_documents([text])[0]
class QianFanEmbeddings(BaseModel, Embeddings):
    """Baidu Qianfan Embeddings embedding models.

    Wraps the ``qianfan.Embedding`` client behind the common
    :class:`Embeddings` interface.

    Embed:
        .. code-block:: python

            # embed the documents
            vectors = embeddings.embed_documents([text1, text2, ...])
            # embed the query
            vectors = embeddings.embed_query(text)
    """  # noqa: E501

    # The underlying qianfan.Embedding client; constructed in __init__.
    client: Any
    # Number of texts sent to the embedding API per request.
    chunk_size: int = 16
    endpoint: str = ""
    """Endpoint of the Qianfan Embedding, required if custom model used."""

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    api_secret: Optional[str] = Field(
        default=None, description="The Secret key for the embeddings API."
    )
    """Model name
    you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
    for now, we support Embedding-V1 and
    - Embedding-V1 (default model)
    - bge-large-en
    - bge-large-zh

    preset models are mapping to an endpoint.
    `model` will be ignored if `endpoint` is set
    """
    model_name: str = Field(
        # Fix: the previous default, "text-embedding-v1", is a TongYi model
        # name and is not valid for Qianfan; Qianfan's default is Embedding-V1
        # (see the model list above).
        default="Embedding-V1",
        description="The name of the model to use.",
    )
    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """init kwargs for qianfan client init, such as `query_per_second` which is
    associated with qianfan resource object to limit QPS"""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra params for model invoke using with `do`."""

    def __init__(self, **kwargs):
        """Initialize the QianFanEmbeddings.

        Args:
            **kwargs: Field values. ``api_key`` and ``api_secret`` are
                required; ``model_name`` falls back to the field default
                (``Embedding-V1``) when not provided.

        Raises:
            ValueError: If the ``qianfan`` package is not installed, or if
                the API key / API secret is missing.
        """
        try:
            import qianfan
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: qianfan. "
                "Please install qianfan by running `pip install qianfan`."
            ) from exc

        qianfan_ak = kwargs.get("api_key")
        qianfan_sk = kwargs.get("api_secret")
        if not qianfan_ak or not qianfan_sk:
            raise ValueError(
                "API key and API secret are required to initialize "
                "QianFanEmbeddings."
            )
        # Fix: the original required model_name in kwargs, which made the
        # declared field default unreachable; fall back to it instead.
        model_name = kwargs.get("model_name") or "Embedding-V1"
        kwargs["model_name"] = model_name

        # Initialize the qianfan.Embedding client.
        kwargs["client"] = qianfan.Embedding(
            model=model_name, ak=qianfan_ak, sk=qianfan_sk
        )
        super().__init__(**kwargs)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text: The query text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents via the Qianfan embedding API.

        Texts are sent in batches of ``chunk_size`` per request.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the
                input list. Each embedding is represented as a list of floats.
        """
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            batch = texts[start : start + self.chunk_size]
            resp = self.client.do(texts=batch, **self.model_kwargs)
            embeddings.extend(item["embedding"] for item in resp["data"])
        return embeddings

View File

@@ -687,6 +687,7 @@ def default_requires():
"chardet",
"sentencepiece",
"ollama",
"qianfan"
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["rag"]