mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 03:41:43 +00:00
feat(model): Support tongyi embedding (#1552)
Co-authored-by: 无剑 <zhuzhenchao@changeme.com> Co-authored-by: csunny <cfqsunny@163.com> Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
47d205f676
commit
fda1a56ef4
@ -92,6 +92,11 @@ KNOWLEDGE_SEARCH_REWRITE=False
|
||||
# proxy_openai_proxy_api_key={your-openai-sk}
|
||||
# proxy_openai_proxy_backend=text-embedding-ada-002
|
||||
|
||||
|
||||
## qwen embedding model, See dbgpt/model/parameter.py
|
||||
# EMBEDDING_MODEL=proxy_tongyi
|
||||
# proxy_tongyi_proxy_backend=text-embedding-v1
|
||||
|
||||
## Common HTTP embedding model
|
||||
# EMBEDDING_MODEL=proxy_http_openapi
|
||||
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
|
||||
|
@ -263,6 +263,7 @@ EMBEDDING_MODEL_CONFIG = {
|
||||
# Common HTTP embedding model
|
||||
"proxy_http_openapi": "proxy_http_openapi",
|
||||
"proxy_ollama": "proxy_ollama",
|
||||
"proxy_tongyi": "proxy_tongyi",
|
||||
# Rerank model, rerank mode is a special embedding model
|
||||
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
|
||||
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
||||
|
@ -50,6 +50,14 @@ class EmbeddingLoader:
|
||||
if proxy_param.proxy_backend:
|
||||
openapi_param["model_name"] = proxy_param.proxy_backend
|
||||
return OpenAPIEmbeddings(**openapi_param)
|
||||
elif model_name in ["proxy_tongyi"]:
|
||||
from dbgpt.rag.embedding import TongYiEmbeddings
|
||||
|
||||
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||
tongyi_param = {"api_key": proxy_param.proxy_api_key}
|
||||
if proxy_param.proxy_backend:
|
||||
tongyi_param["model_name"] = proxy_param.proxy_backend
|
||||
return TongYiEmbeddings(**tongyi_param)
|
||||
elif model_name in ["proxy_ollama"]:
|
||||
from dbgpt.rag.embedding import OllamaEmbeddings
|
||||
|
||||
|
@ -665,8 +665,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
|
||||
|
||||
# Maps each embedding-parameter class to the comma-separated list of model
# names it configures. NOTE: the stale pre-tongyi entry was removed — as
# rendered, the dict had the same `ProxyEmbeddingParameters` key twice, and
# the first (without "proxy_tongyi") was silently overwritten by the second.
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
    "proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
}
|
||||
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||
|
@ -14,6 +14,7 @@ from .embeddings import ( # noqa: F401
|
||||
JinaEmbeddings,
|
||||
OllamaEmbeddings,
|
||||
OpenAPIEmbeddings,
|
||||
TongYiEmbeddings,
|
||||
)
|
||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||
|
||||
@ -29,6 +30,7 @@ __ALL__ = [
|
||||
"DefaultEmbeddingFactory",
|
||||
"EmbeddingFactory",
|
||||
"WrappedEmbeddingFactory",
|
||||
"TongYiEmbeddings",
|
||||
"CrossEncoderRerankEmbeddings",
|
||||
"OpenAPIRerankEmbeddings",
|
||||
]
|
||||
|
@ -826,3 +826,83 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
return embedding["embedding"]
|
||||
except ollama.ResponseError as e:
|
||||
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
|
||||
|
||||
|
||||
class TongYiEmbeddings(BaseModel, Embeddings):
    """The TongYi (Alibaba DashScope) embeddings.

    Example of the underlying DashScope API usage:

    .. code-block:: python

        import dashscope
        from http import HTTPStatus
        from dashscope import TextEmbedding

        dashscope.api_key = ''
        def embed_with_list_of_str():
            resp = TextEmbedding.call(
                model=TextEmbedding.Models.text_embedding_v1,
                # At most 10 items per call, each item up to 2048 tokens.
                input=['风急天高猿啸哀', '渚清沙白鸟飞回', '无边落木萧萧下', '不尽长江滚滚来']
            )
            if resp.status_code == HTTPStatus.OK:
                print(resp)
            else:
                print(resp)

        if __name__ == '__main__':
            embed_with_list_of_str()
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
    # API key for the DashScope embeddings API; also settable via kwargs.
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    # DashScope embedding model identifier.
    model_name: str = Field(
        default="text-embedding-v1", description="The name of the model to use."
    )

    def __init__(self, **kwargs):
        """Initialize the TongYiEmbeddings.

        Raises:
            ValueError: If the ``dashscope`` package is not installed.
        """
        try:
            import dashscope  # type: ignore
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: dashscope "
                "Please install dashscope by command `pip install dashscope`"
            ) from exc
        # Set the key on the client class as well, so calls that omit an
        # explicit api_key still authenticate.
        dashscope.TextEmbedding.api_key = kwargs.get("api_key")
        super().__init__(**kwargs)
        self._api_key = kwargs.get("api_key")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        The DashScope API accepts at most 10 texts per call (each up to
        2048 tokens — see the class docstring), so the input is split into
        batches and the results are concatenated in input order.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.

        Raises:
            RuntimeError: If the API response contains no ``output`` field.
        """
        from dashscope import TextEmbedding

        # Service-side limit per call; larger inputs are chunked rather
        # than passed through in one (failing) request.
        max_batch_size = 10
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), max_batch_size):
            batch = texts[start : start + max_batch_size]
            resp = TextEmbedding.call(
                model=self.model_name, input=batch, api_key=self._api_key
            )
            if "output" not in resp:
                raise RuntimeError(resp["message"])
            # text_index is relative to the current batch; sort to restore
            # the input order within the batch.
            batch_results = sorted(
                resp["output"]["embeddings"], key=lambda e: e["text_index"]
            )
            embeddings.extend(item["embedding"] for item in batch_results)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using the DashScope embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
|
||||
|
Loading…
Reference in New Issue
Block a user