feat(model): Support tongyi embedding (#1552)

Co-authored-by: 无剑 <zhuzhenchao@changeme.com>
Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Authored by 无剑 on 2024-06-25 19:50:38 +08:00, committed by GitHub
parent 47d205f676
commit fda1a56ef4
6 changed files with 97 additions and 2 deletions


@@ -92,6 +92,11 @@ KNOWLEDGE_SEARCH_REWRITE=False
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002
## Qwen (Tongyi) embedding model, see dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_tongyi
# proxy_tongyi_proxy_backend=text-embedding-v1
## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
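
A minimal sketch of how the new proxy_tongyi settings above would be filled in to actually enable the Tongyi embedding. The proxy_tongyi_proxy_api_key name is an assumption that mirrors the proxy_openai_proxy_api_key setting earlier in this template and is not part of this diff:

## Hedged example (not part of this diff): enable the Tongyi proxy embedding
EMBEDDING_MODEL=proxy_tongyi
proxy_tongyi_proxy_backend=text-embedding-v1
## Assumed key name, following the proxy_openai_proxy_api_key pattern above
proxy_tongyi_proxy_api_key={your-dashscope-api-key}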


@@ -263,6 +263,7 @@ EMBEDDING_MODEL_CONFIG = {
    # Common HTTP embedding model
    "proxy_http_openapi": "proxy_http_openapi",
    "proxy_ollama": "proxy_ollama",
    "proxy_tongyi": "proxy_tongyi",
    # Rerank model, rerank mode is a special embedding model
    "bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
    "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),


@@ -50,6 +50,14 @@ class EmbeddingLoader:
            if proxy_param.proxy_backend:
                openapi_param["model_name"] = proxy_param.proxy_backend
            return OpenAPIEmbeddings(**openapi_param)
        elif model_name in ["proxy_tongyi"]:
            from dbgpt.rag.embedding import TongYiEmbeddings

            proxy_param = cast(ProxyEmbeddingParameters, param)
            tongyi_param = {"api_key": proxy_param.proxy_api_key}
            if proxy_param.proxy_backend:
                tongyi_param["model_name"] = proxy_param.proxy_backend
            return TongYiEmbeddings(**tongyi_param)
        elif model_name in ["proxy_ollama"]:
            from dbgpt.rag.embedding import OllamaEmbeddings
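
A rough sketch of what the new proxy_tongyi branch constructs, for trying it outside the loader. TongYiEmbeddings comes from the embeddings.py change further down in this commit; the DASHSCOPE_API_KEY variable is only an illustrative placeholder, not something this diff reads:

import os

from dbgpt.rag.embedding import TongYiEmbeddings

# Mirrors the loader branch above: api_key from the proxy parameters,
# model_name from proxy_backend (falls back to the class default "text-embedding-v1").
embeddings = TongYiEmbeddings(
    api_key=os.getenv("DASHSCOPE_API_KEY"),  # placeholder source for the key
    model_name="text-embedding-v1",
)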


@@ -665,8 +665,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
    "proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
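
For context, the comma-separated string above is what maps each EMBEDDING_MODEL name back to ProxyEmbeddingParameters. The expansion below is only an illustrative sketch of that lookup, not the actual code in dbgpt/model/parameter.py:

# Illustrative sketch (assumption about how the mapping is consumed):
for param_cls, names in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
    for name in names.split(","):
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[name.strip()] = param_cls

# After this change, "proxy_tongyi" resolves to ProxyEmbeddingParameters as well.
assert EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG["proxy_tongyi"] is ProxyEmbeddingParameters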


@@ -14,6 +14,7 @@ from .embeddings import (  # noqa: F401
    JinaEmbeddings,
    OllamaEmbeddings,
    OpenAPIEmbeddings,
    TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings  # noqa: F401

@@ -29,6 +30,7 @@ __ALL__ = [
    "DefaultEmbeddingFactory",
    "EmbeddingFactory",
    "WrappedEmbeddingFactory",
    "TongYiEmbeddings",
    "CrossEncoderRerankEmbeddings",
    "OpenAPIRerankEmbeddings",
]


@@ -826,3 +826,83 @@ class OllamaEmbeddings(BaseModel, Embeddings):
            return embedding["embedding"]
        except ollama.ResponseError as e:
            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")


class TongYiEmbeddings(BaseModel, Embeddings):
    """The TongYi embeddings.

    Example of calling the underlying dashscope SDK directly:

        import dashscope
        from http import HTTPStatus
        from dashscope import TextEmbedding

        dashscope.api_key = ""


        def embed_with_list_of_str():
            resp = TextEmbedding.call(
                model=TextEmbedding.Models.text_embedding_v1,
                # Up to 10 texts per call, each at most 2048 tokens.
                input=["风急天高猿啸哀", "渚清沙白鸟飞回", "无边落木萧萧下", "不尽长江滚滚来"],
            )
            if resp.status_code == HTTPStatus.OK:
                print(resp)
            else:
                print(resp)


        if __name__ == "__main__":
            embed_with_list_of_str()
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="text-embedding-v1", description="The name of the model to use."
    )

    def __init__(self, **kwargs):
        """Initialize the TongYiEmbeddings."""
        try:
            import dashscope  # type: ignore
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: dashscope. "
                "Please install it with `pip install dashscope`."
            ) from exc
        dashscope.TextEmbedding.api_key = kwargs.get("api_key")
        super().__init__(**kwargs)
        self._api_key = kwargs.get("api_key")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.
        """
        from dashscope import TextEmbedding

        # The API supports at most 10 texts per call, each up to 2048 tokens.
        resp = TextEmbedding.call(
            model=self.model_name, input=texts, api_key=self._api_key
        )
        if "output" not in resp:
            raise RuntimeError(resp["message"])

        embeddings = resp["output"]["embeddings"]
        sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"])
        return [result["embedding"] for result in sorted_embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using the TongYi embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
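
A short usage sketch for the new class, assuming dashscope is installed and a valid DashScope API key is supplied (the key below is a placeholder):

from dbgpt.rag.embedding import TongYiEmbeddings

emb = TongYiEmbeddings(api_key="{your-dashscope-api-key}", model_name="text-embedding-v1")

# Batch embedding: the underlying API accepts at most 10 texts per call,
# each up to 2048 tokens.
doc_vectors = emb.embed_documents(["风急天高猿啸哀", "渚清沙白鸟飞回"])
query_vector = emb.embed_query("登高")

print(len(doc_vectors), len(doc_vectors[0]))  # number of documents, embedding dimension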