diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index f81fd3c0e..6735765a0 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -888,11 +888,17 @@ class TongYiEmbeddings(BaseModel, Embeddings): super().__init__(**kwargs) self._api_key = kwargs.get("api_key") - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents( + self, texts: List[str], max_batch_chunks_size=25 + ) -> List[List[float]]: """Get the embeddings for a list of texts. + refer:https://help.aliyun.com/zh/model-studio/getting-started/models? + spm=a2c4g.11186623.0.0.62524a77NlILDI#c05fe72732770 + Args: texts (Documents): A list of texts to get embeddings for. + max_batch_chunks_size: The max batch size for embedding. Returns: Embedded texts as List[List[float]], where each inner List[float] @@ -900,17 +906,27 @@ class TongYiEmbeddings(BaseModel, Embeddings): """ from dashscope import TextEmbedding - # 最多支持10条,每条最长支持2048tokens - resp = TextEmbedding.call( - model=self.model_name, input=texts, api_key=self._api_key - ) - if "output" not in resp: - raise RuntimeError(resp["message"]) + embeddings = [] + # batch size too longer may cause embedding error,eg: qwen online embedding + # models must not be larger than 25 + # text-embedding-v3 embedding batch size should not be larger than 6 + if str(self.model_name) == "text-embedding-v3": + max_batch_chunks_size = 6 - embeddings = resp["output"]["embeddings"] - sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"]) + for i in range(0, len(texts), max_batch_chunks_size): + batch_texts = texts[i : i + max_batch_chunks_size] + resp = TextEmbedding.call( + model=self.model_name, input=batch_texts, api_key=self._api_key + ) + if "output" not in resp: + raise RuntimeError(resp["message"]) - return [result["embedding"] for result in sorted_embeddings] + # 提取并排序嵌入 + batch_embeddings = resp["output"]["embeddings"] + sorted_embeddings = sorted(batch_embeddings, key=lambda e: e["text_index"]) + embeddings.extend([result["embedding"] for result in sorted_embeddings]) + + return embeddings def embed_query(self, text: str) -> List[float]: """Compute query embeddings using a OpenAPI embedding model.