fix bug - Tongyiproxy embedding model error (#2018)

Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
mzaispace 2024-09-18 21:58:17 +08:00 committed by GitHub
parent 8bae10851f
commit 6ce620f4ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -888,11 +888,17 @@ class TongYiEmbeddings(BaseModel, Embeddings):
super().__init__(**kwargs)
self._api_key = kwargs.get("api_key")
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(
self, texts: List[str], max_batch_chunks_size=25
) -> List[List[float]]:
"""Get the embeddings for a list of texts.
refer:https://help.aliyun.com/zh/model-studio/getting-started/models?
spm=a2c4g.11186623.0.0.62524a77NlILDI#c05fe72732770
Args:
texts (Documents): A list of texts to get embeddings for.
max_batch_chunks_size: The max batch size for embedding.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
@ -900,17 +906,27 @@ class TongYiEmbeddings(BaseModel, Embeddings):
"""
from dashscope import TextEmbedding
# 最多支持10条每条最长支持2048tokens
resp = TextEmbedding.call(
model=self.model_name, input=texts, api_key=self._api_key
)
if "output" not in resp:
raise RuntimeError(resp["message"])
embeddings = []
# batch size too longer may cause embedding error,eg: qwen online embedding
# models must not be larger than 25
# text-embedding-v3 embedding batch size should not be larger than 6
if str(self.model_name) == "text-embedding-v3":
max_batch_chunks_size = 6
embeddings = resp["output"]["embeddings"]
sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"])
for i in range(0, len(texts), max_batch_chunks_size):
batch_texts = texts[i : i + max_batch_chunks_size]
resp = TextEmbedding.call(
model=self.model_name, input=batch_texts, api_key=self._api_key
)
if "output" not in resp:
raise RuntimeError(resp["message"])
return [result["embedding"] for result in sorted_embeddings]
# 提取并排序嵌入
batch_embeddings = resp["output"]["embeddings"]
sorted_embeddings = sorted(batch_embeddings, key=lambda e: e["text_index"])
embeddings.extend([result["embedding"] for result in sorted_embeddings])
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model.