From b344f34635c210780d532b0c44a6231f932edc8e Mon Sep 17 00:00:00 2001
From: Aubrey Ford
Date: Fri, 18 Apr 2025 12:27:27 -0700
Subject: [PATCH] partners/openai: OpenAIEmbeddings not respecting chunk_size
 argument (#30757)

When calling `embed_documents` with a `chunk_size` argument, that argument
is ignored when `OpenAIEmbeddings` is instantiated with its default
configuration (where `check_embedding_ctx_length=True`).

`_get_len_safe_embeddings` accepts a `chunk_size` parameter, but it is not
passed through from `embed_documents`, which is its only caller. This
appears to be an oversight, especially given that the
`_get_len_safe_embeddings` docstring states it should respect "the set
embedding context length and chunk size."

Developers typically expect method parameters to take effect (and to take
precedence) when explicitly provided, especially when the object was
instantiated with defaults. I was confused as to why my API calls were
being rejected regardless of the chunk size I provided.

This bug also exists in the langchain_community package. I can add that fix
to this PR if requested; otherwise, I will open a separate PR once this one
passes.
---
 .../langchain_openai/embeddings/base.py       |  8 ++++++--
 .../tests/unit_tests/embeddings/test_base.py  | 20 +++++++++++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py
index e3d03f942b8..826f3563adf 100644
--- a/libs/partners/openai/langchain_openai/embeddings/base.py
+++ b/libs/partners/openai/langchain_openai/embeddings/base.py
@@ -573,7 +573,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # NOTE: to keep things simple, we assume the list may contain texts longer
         #       than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
-        return self._get_len_safe_embeddings(texts, engine=engine)
+        return self._get_len_safe_embeddings(
+            texts, engine=engine, chunk_size=chunk_size
+        )
 
     async def aembed_documents(
         self, texts: list[str], chunk_size: int | None = None
@@ -603,7 +605,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # NOTE: to keep things simple, we assume the list may contain texts longer
         #       than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
-        return await self._aget_len_safe_embeddings(texts, engine=engine)
+        return await self._aget_len_safe_embeddings(
+            texts, engine=engine, chunk_size=chunk_size
+        )
 
     def embed_query(self, text: str) -> list[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
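A minimal sketch of the user-visible behavior this change affects. The model
name, document count, and batch sizes below are illustrative assumptions,
not values taken from the patch itself:

    # Sketch only: assumes a configured OPENAI_API_KEY and the assumed
    # default chunk_size of 1000 on OpenAIEmbeddings.
    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    texts = [f"document {i}" for i in range(2500)]

    # Before this patch (with the default check_embedding_ctx_length=True),
    # the per-call chunk_size below was silently ignored and requests were
    # batched by embeddings.chunk_size instead. After this patch, the
    # explicit per-call value takes precedence and requests go out in
    # batches of 100.
    vectors = embeddings.embed_documents(texts, chunk_size=100)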
diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
index d464b9cbb5b..7714b89b12a 100644
--- a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
@@ -20,6 +20,26 @@ def test_openai_incorrect_field() -> None:
 
 
 def test_embed_documents_with_custom_chunk_size() -> None:
+    embeddings = OpenAIEmbeddings(chunk_size=2)
+    texts = ["text1", "text2", "text3", "text4"]
+    custom_chunk_size = 3
+
+    with patch.object(embeddings.client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
+            {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
+        ]
+
+        result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
+        _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
+        assert mock_create.call_count == 2
+        mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
+        mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
+
+        assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
+
+
+def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
     embeddings = OpenAIEmbeddings(chunk_size=2, check_embedding_ctx_length=False)
     texts = ["text1", "text2", "text3", "text4"]
    custom_chunk_size = 3
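The diff above only exercises the synchronous path, while the fix also
touches `aembed_documents`. A sketch of an analogous async test, not part
of this patch: it assumes the async OpenAI client is exposed as
`async_client` on `OpenAIEmbeddings` and that pytest-asyncio is available
for the `asyncio` marker.

    import pytest
    from unittest.mock import AsyncMock, patch

    from langchain_openai import OpenAIEmbeddings


    @pytest.mark.asyncio
    async def test_aembed_documents_with_custom_chunk_size() -> None:
        # Mirrors the synchronous test: the explicit per-call chunk_size (3)
        # should take precedence over the instance-level chunk_size (2).
        embeddings = OpenAIEmbeddings(chunk_size=2)
        texts = ["text1", "text2", "text3", "text4"]
        custom_chunk_size = 3

        with patch.object(
            embeddings.async_client, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.side_effect = [
                {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
                {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
            ]

            result = await embeddings.aembed_documents(
                texts, chunk_size=custom_chunk_size
            )
            _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
            mock_create.assert_any_call(
                input=tokens[0:3], **embeddings._invocation_params
            )
            mock_create.assert_any_call(
                input=tokens[3:4], **embeddings._invocation_params
            )

        assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]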