langchain_community: OpenAIEmbeddings not respecting chunk_size argument (#30946)

This is a follow-on PR to go with the identical changes that were made
in partners/openai.

Previous PR:  https://github.com/langchain-ai/langchain/pull/30757

When calling embed_documents and providing a chunk_size argument, that
argument is ignored when OpenAIEmbeddings is instantiated with its
default configuration (where check_embedding_ctx_length=True).

_get_len_safe_embeddings specifies a chunk_size parameter but it's not
being passed through in embed_documents, which is its only caller. This
appears to be an oversight, especially given that the
_get_len_safe_embeddings docstring states it should respect "the set
embedding context length and chunk size."

Developers typically expect method parameters to take effect (and to take
precedence over instance defaults) when explicitly provided, especially when
the object was instantiated with its defaults. I was confused as to why my API
calls were being rejected regardless of the chunk size I provided.
This commit is contained in:
Aubrey Ford 2025-04-21 05:39:07 -07:00 committed by GitHub
parent b344f34635
commit 23f701b08e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 2 deletions

View File

@ -668,7 +668,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine)
return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
)
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
@ -686,7 +688,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine)
return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
)
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.

View File

@ -1,7 +1,12 @@
import os
from unittest.mock import patch
import pytest
from langchain_community.embeddings.openai import OpenAIEmbeddings
os.environ["OPENAI_API_KEY"] = "foo"
@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
@ -14,3 +19,20 @@ def test_openai_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("openai")
def test_embed_documents_with_custom_chunk_size() -> None:
    """Check that a per-call ``chunk_size`` overrides the instance-level one.

    The instance is created with ``chunk_size=2`` while ``embed_documents`` is
    called with ``chunk_size=3``. For four texts the override should produce
    two requests carrying three inputs and one input respectively.
    """
    embeddings = OpenAIEmbeddings(chunk_size=2)
    texts = ["text1", "text2", "text3", "text4"]
    custom_chunk_size = 3

    with patch.object(embeddings.client, "create") as mock_create:
        # One canned response per expected request.
        mock_create.side_effect = [
            {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
            {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
        ]

        embeddings.embed_documents(texts, chunk_size=custom_chunk_size)

        # The per-call chunk size, not the instance value of 2, must decide
        # how many inputs go into each request.
        batch_sizes = [
            len(call.kwargs["input"]) for call in mock_create.call_args_list
        ]
        assert batch_sizes == [custom_chunk_size, len(texts) - custom_chunk_size]

        # The trailing request carries a single tokenized input; the original
        # test asserted this call twice, so it is kept once here.
        mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)