mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat(model): Support tongyi embedding (#1552)
Co-authored-by: 无剑 <zhuzhenchao@changeme.com> Co-authored-by: csunny <cfqsunny@163.com> Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
47d205f676
commit
fda1a56ef4
@ -92,6 +92,11 @@ KNOWLEDGE_SEARCH_REWRITE=False
|
|||||||
# proxy_openai_proxy_api_key={your-openai-sk}
|
# proxy_openai_proxy_api_key={your-openai-sk}
|
||||||
# proxy_openai_proxy_backend=text-embedding-ada-002
|
# proxy_openai_proxy_backend=text-embedding-ada-002
|
||||||
|
|
||||||
|
|
||||||
|
## qwen embedding model, See dbgpt/model/parameter.py
|
||||||
|
# EMBEDDING_MODEL=proxy_tongyi
|
||||||
|
# proxy_tongyi_proxy_backend=text-embedding-v1
|
||||||
|
|
||||||
## Common HTTP embedding model
|
## Common HTTP embedding model
|
||||||
# EMBEDDING_MODEL=proxy_http_openapi
|
# EMBEDDING_MODEL=proxy_http_openapi
|
||||||
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
|
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
|
||||||
|
@ -263,6 +263,7 @@ EMBEDDING_MODEL_CONFIG = {
|
|||||||
# Common HTTP embedding model
|
# Common HTTP embedding model
|
||||||
"proxy_http_openapi": "proxy_http_openapi",
|
"proxy_http_openapi": "proxy_http_openapi",
|
||||||
"proxy_ollama": "proxy_ollama",
|
"proxy_ollama": "proxy_ollama",
|
||||||
|
"proxy_tongyi": "proxy_tongyi",
|
||||||
# Rerank model, rerank mode is a special embedding model
|
# Rerank model, rerank mode is a special embedding model
|
||||||
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
|
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
|
||||||
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
||||||
|
@ -50,6 +50,14 @@ class EmbeddingLoader:
|
|||||||
if proxy_param.proxy_backend:
|
if proxy_param.proxy_backend:
|
||||||
openapi_param["model_name"] = proxy_param.proxy_backend
|
openapi_param["model_name"] = proxy_param.proxy_backend
|
||||||
return OpenAPIEmbeddings(**openapi_param)
|
return OpenAPIEmbeddings(**openapi_param)
|
||||||
|
elif model_name in ["proxy_tongyi"]:
|
||||||
|
from dbgpt.rag.embedding import TongYiEmbeddings
|
||||||
|
|
||||||
|
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||||
|
tongyi_param = {"api_key": proxy_param.proxy_api_key}
|
||||||
|
if proxy_param.proxy_backend:
|
||||||
|
tongyi_param["model_name"] = proxy_param.proxy_backend
|
||||||
|
return TongYiEmbeddings(**tongyi_param)
|
||||||
elif model_name in ["proxy_ollama"]:
|
elif model_name in ["proxy_ollama"]:
|
||||||
from dbgpt.rag.embedding import OllamaEmbeddings
|
from dbgpt.rag.embedding import OllamaEmbeddings
|
||||||
|
|
||||||
|
@ -665,8 +665,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
|||||||
|
|
||||||
|
|
||||||
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
||||||
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
|
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
|
||||||
"proxy_ollama,rerank_proxy_http_openapi",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||||
|
@ -14,6 +14,7 @@ from .embeddings import ( # noqa: F401
|
|||||||
JinaEmbeddings,
|
JinaEmbeddings,
|
||||||
OllamaEmbeddings,
|
OllamaEmbeddings,
|
||||||
OpenAPIEmbeddings,
|
OpenAPIEmbeddings,
|
||||||
|
TongYiEmbeddings,
|
||||||
)
|
)
|
||||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ __ALL__ = [
|
|||||||
"DefaultEmbeddingFactory",
|
"DefaultEmbeddingFactory",
|
||||||
"EmbeddingFactory",
|
"EmbeddingFactory",
|
||||||
"WrappedEmbeddingFactory",
|
"WrappedEmbeddingFactory",
|
||||||
|
"TongYiEmbeddings",
|
||||||
"CrossEncoderRerankEmbeddings",
|
"CrossEncoderRerankEmbeddings",
|
||||||
"OpenAPIRerankEmbeddings",
|
"OpenAPIRerankEmbeddings",
|
||||||
]
|
]
|
||||||
|
@ -826,3 +826,83 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
return embedding["embedding"]
|
return embedding["embedding"]
|
||||||
except ollama.ResponseError as e:
|
except ollama.ResponseError as e:
|
||||||
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
|
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class TongYiEmbeddings(BaseModel, Embeddings):
    """The TongYi embeddings.

    Wraps Alibaba Cloud DashScope's ``TextEmbedding`` API as an
    :class:`Embeddings` implementation.

    Example usage of the underlying SDK:

    .. code-block:: python

        import dashscope
        from http import HTTPStatus
        from dashscope import TextEmbedding

        dashscope.api_key = ""


        def embed_with_list_of_str():
            resp = TextEmbedding.call(
                model=TextEmbedding.Models.text_embedding_v1,
                # At most 10 texts per call, each up to 2048 tokens
                input=["风急天高猿啸哀", "渚清沙白鸟飞回", "无边落木萧萧下", "不尽长江滚滚来"],
            )
            if resp.status_code == HTTPStatus.OK:
                print(resp)
            else:
                print(resp)


        if __name__ == "__main__":
            embed_with_list_of_str()
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
    # API key for the DashScope embeddings service.
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    # DashScope embedding model to call, e.g. "text-embedding-v1".
    model_name: str = Field(
        default="text-embedding-v1", description="The name of the model to use."
    )

    # The API accepts at most this many texts per call (each up to 2048 tokens);
    # established by the upstream SDK limit noted in the original implementation.
    _MAX_BATCH_SIZE = 10

    def __init__(self, **kwargs):
        """Initialize the TongYiEmbeddings.

        Raises:
            ValueError: If the ``dashscope`` package is not installed.
        """
        try:
            import dashscope  # type: ignore
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: dashscope "
                "Please install dashscope by command `pip install dashscope`"
            ) from exc
        # Configure the SDK-level key before pydantic validation runs.
        dashscope.TextEmbedding.api_key = kwargs.get("api_key")
        super().__init__(**kwargs)
        self._api_key = kwargs.get("api_key")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        The DashScope API accepts at most 10 texts per call, so the input is
        embedded in batches of 10 and the results are concatenated in the
        original input order.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.

        Raises:
            RuntimeError: If the API response contains no ``output`` payload.
        """
        from dashscope import TextEmbedding

        results: List[List[float]] = []
        for start in range(0, len(texts), self._MAX_BATCH_SIZE):
            batch = texts[start : start + self._MAX_BATCH_SIZE]
            resp = TextEmbedding.call(
                model=self.model_name, input=batch, api_key=self._api_key
            )
            if "output" not in resp:
                raise RuntimeError(resp["message"])

            embeddings = resp["output"]["embeddings"]
            # The API may return items out of order; restore input order
            # within the batch before appending.
            sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"])
            results.extend(item["embedding"] for item in sorted_embeddings)

        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using the TongYi embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
|
Loading…
Reference in New Issue
Block a user