mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
fix bug - Tongyiproxy embedding model error (#2018)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
8bae10851f
commit
6ce620f4ed
@ -888,11 +888,17 @@ class TongYiEmbeddings(BaseModel, Embeddings):
|
||||
super().__init__(**kwargs)
|
||||
self._api_key = kwargs.get("api_key")
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(
|
||||
self, texts: List[str], max_batch_chunks_size=25
|
||||
) -> List[List[float]]:
|
||||
"""Get the embeddings for a list of texts.
|
||||
|
||||
refer:https://help.aliyun.com/zh/model-studio/getting-started/models?
|
||||
spm=a2c4g.11186623.0.0.62524a77NlILDI#c05fe72732770
|
||||
|
||||
Args:
|
||||
texts (Documents): A list of texts to get embeddings for.
|
||||
max_batch_chunks_size: The max batch size for embedding.
|
||||
|
||||
Returns:
|
||||
Embedded texts as List[List[float]], where each inner List[float]
|
||||
@ -900,17 +906,27 @@ class TongYiEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
from dashscope import TextEmbedding
|
||||
|
||||
# 最多支持10条,每条最长支持2048tokens
|
||||
resp = TextEmbedding.call(
|
||||
model=self.model_name, input=texts, api_key=self._api_key
|
||||
)
|
||||
if "output" not in resp:
|
||||
raise RuntimeError(resp["message"])
|
||||
embeddings = []
|
||||
# batch size too longer may cause embedding error,eg: qwen online embedding
|
||||
# models must not be larger than 25
|
||||
# text-embedding-v3 embedding batch size should not be larger than 6
|
||||
if str(self.model_name) == "text-embedding-v3":
|
||||
max_batch_chunks_size = 6
|
||||
|
||||
embeddings = resp["output"]["embeddings"]
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"])
|
||||
for i in range(0, len(texts), max_batch_chunks_size):
|
||||
batch_texts = texts[i : i + max_batch_chunks_size]
|
||||
resp = TextEmbedding.call(
|
||||
model=self.model_name, input=batch_texts, api_key=self._api_key
|
||||
)
|
||||
if "output" not in resp:
|
||||
raise RuntimeError(resp["message"])
|
||||
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
# 提取并排序嵌入
|
||||
batch_embeddings = resp["output"]["embeddings"]
|
||||
sorted_embeddings = sorted(batch_embeddings, key=lambda e: e["text_index"])
|
||||
embeddings.extend([result["embedding"] for result in sorted_embeddings])
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a OpenAPI embedding model.
|
||||
|
Loading…
Reference in New Issue
Block a user