fix bug - Tongyiproxy embedding model error (#2018)

Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
mzaispace 2024-09-18 21:58:17 +08:00 committed by GitHub
parent 8bae10851f
commit 6ce620f4ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -888,11 +888,17 @@ class TongYiEmbeddings(BaseModel, Embeddings):
super().__init__(**kwargs) super().__init__(**kwargs)
self._api_key = kwargs.get("api_key") self._api_key = kwargs.get("api_key")
def embed_documents(
    self, texts: List[str], max_batch_chunks_size: int = 25
) -> List[List[float]]:
    """Get the embeddings for a list of texts.

    refer: https://help.aliyun.com/zh/model-studio/getting-started/models?
    spm=a2c4g.11186623.0.0.62524a77NlILDI#c05fe72732770

    Args:
        texts (Documents): A list of texts to get embeddings for.
        max_batch_chunks_size: The max batch size for one embedding request.

    Returns:
        Embedded texts as List[List[float]], where each inner List[float]
        corresponds to a single input text, in the same order as ``texts``.

    Raises:
        RuntimeError: If the DashScope API response contains no "output".
    """
    from dashscope import TextEmbedding

    embeddings = []
    # batch size too long may cause embedding errors, e.g. qwen online embedding
    # models must not be larger than 25
    # text-embedding-v3 embedding batch size should not be larger than 6
    if str(self.model_name) == "text-embedding-v3":
        max_batch_chunks_size = 6
    for i in range(0, len(texts), max_batch_chunks_size):
        batch_texts = texts[i : i + max_batch_chunks_size]
        resp = TextEmbedding.call(
            model=self.model_name, input=batch_texts, api_key=self._api_key
        )
        if "output" not in resp:
            raise RuntimeError(resp["message"])
        # Sort by text_index so embeddings line up with the batch's input order.
        batch_embeddings = resp["output"]["embeddings"]
        sorted_embeddings = sorted(batch_embeddings, key=lambda e: e["text_index"])
        embeddings.extend([result["embedding"] for result in sorted_embeddings])
    return embeddings
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model. """Compute query embeddings using a OpenAPI embedding model.