diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index e3d03f942b8..826f3563adf 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -573,7 +573,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. engine = cast(str, self.deployment) - return self._get_len_safe_embeddings(texts, engine=engine) + return self._get_len_safe_embeddings( + texts, engine=engine, chunk_size=chunk_size + ) async def aembed_documents( self, texts: list[str], chunk_size: int | None = None @@ -603,7 +605,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. engine = cast(str, self.deployment) - return await self._aget_len_safe_embeddings(texts, engine=engine) + return await self._aget_len_safe_embeddings( + texts, engine=engine, chunk_size=chunk_size + ) def embed_query(self, text: str) -> list[float]: """Call out to OpenAI's embedding endpoint for embedding query text. diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py index d464b9cbb5b..7714b89b12a 100644 --- a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py @@ -20,6 +20,26 @@ def test_openai_incorrect_field() -> None: def test_embed_documents_with_custom_chunk_size() -> None: + embeddings = OpenAIEmbeddings(chunk_size=2) + texts = ["text1", "text2", "text3", "text4"] + custom_chunk_size = 3 + + with patch.object(embeddings.client, "create") as mock_create: + mock_create.side_effect = [ + {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}, + {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]}, + ] + + result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size) + _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size) + mock_create.call_args + mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params) + mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params) + + assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]] + + +def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None: embeddings = OpenAIEmbeddings(chunk_size=2, check_embedding_ctx_length=False) texts = ["text1", "text2", "text3", "text4"] custom_chunk_size = 3