partners/openai: OpenAIEmbeddings not respecting chunk_size argument (#30757)

When calling `embed_documents` and providing a `chunk_size` argument,
that argument is ignored when `OpenAIEmbeddings` is instantiated with
its default configuration (where `check_embedding_ctx_length=True`).

`_get_len_safe_embeddings` specifies a `chunk_size` parameter but it's
not being passed through in `embed_documents`, which is its only caller.
This appears to be an oversight, especially given that the
`_get_len_safe_embeddings` docstring states it should respect "the set
embedding context length and chunk size."

Developers typically expect method parameters to take effect (also, take
precedence) when explicitly provided, especially when instantiating
using defaults. I was confused as to why my API calls were being
rejected regardless of the chunk size I provided.

This bug also exists in langchain_community package. I can add that to
this PR if requested otherwise I will create a new one once this passes.
This commit is contained in:
Aubrey Ford 2025-04-18 12:27:27 -07:00 committed by GitHub
parent 017c8079e1
commit b344f34635
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 2 deletions

View File

@ -573,7 +573,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine)
return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
)
async def aembed_documents(
self, texts: list[str], chunk_size: int | None = None
@ -603,7 +605,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine)
return await self._aget_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
)
def embed_query(self, text: str) -> list[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.

View File

@ -20,6 +20,26 @@ def test_openai_incorrect_field() -> None:
def test_embed_documents_with_custom_chunk_size() -> None:
embeddings = OpenAIEmbeddings(chunk_size=2)
texts = ["text1", "text2", "text3", "text4"]
custom_chunk_size = 3
with patch.object(embeddings.client, "create") as mock_create:
mock_create.side_effect = [
{"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
{"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
]
result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
_, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
mock_create.call_args
mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)
assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
embeddings = OpenAIEmbeddings(chunk_size=2, check_embedding_ctx_length=False)
texts = ["text1", "text2", "text3", "text4"]
custom_chunk_size = 3