mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 04:08:10 +00:00
Fix bug: TongYi proxy embedding model error (#2018)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
8bae10851f
commit
6ce620f4ed
@ -888,11 +888,17 @@ class TongYiEmbeddings(BaseModel, Embeddings):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._api_key = kwargs.get("api_key")
|
self._api_key = kwargs.get("api_key")
|
||||||
|
|
||||||
def embed_documents(
    self, texts: List[str], max_batch_chunks_size: int = 25
) -> List[List[float]]:
    """Get the embeddings for a list of texts.

    Texts are embedded in batches because the DashScope online API rejects
    oversized requests: at most 25 chunks per call for most models, and at
    most 6 for ``text-embedding-v3``.

    refer:https://help.aliyun.com/zh/model-studio/getting-started/models?
    spm=a2c4g.11186623.0.0.62524a77NlILDI#c05fe72732770

    Args:
        texts (Documents): A list of texts to get embeddings for.
        max_batch_chunks_size: The max batch size for embedding.

    Returns:
        Embedded texts as List[List[float]], where each inner List[float]
        corresponds to a single input text, in input order.

    Raises:
        RuntimeError: If the DashScope API response contains no ``output``
            (the API's error ``message`` is propagated).
    """
    from dashscope import TextEmbedding

    embeddings: List[List[float]] = []
    # A batch that is too large causes an embedding error on the online API.
    # text-embedding-v3 accepts at most 6 chunks per request; use min() so a
    # caller who explicitly asked for a smaller batch size is still honored
    # (the previous code unconditionally overwrote the caller's value).
    if str(self.model_name) == "text-embedding-v3":
        max_batch_chunks_size = min(max_batch_chunks_size, 6)

    for start in range(0, len(texts), max_batch_chunks_size):
        batch_texts = texts[start : start + max_batch_chunks_size]
        resp = TextEmbedding.call(
            model=self.model_name, input=batch_texts, api_key=self._api_key
        )
        if "output" not in resp:
            raise RuntimeError(resp["message"])
        batch_embeddings = resp["output"]["embeddings"]
        # The API may return batch results out of order; restore input order
        # via each item's text_index before collecting.
        sorted_embeddings = sorted(batch_embeddings, key=lambda e: e["text_index"])
        embeddings.extend(item["embedding"] for item in sorted_embeddings)
    return embeddings
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Compute query embeddings using a OpenAPI embedding model.
|
"""Compute query embeddings using a OpenAPI embedding model.
|
||||||
|
Loading…
Reference in New Issue
Block a user