diff --git a/libs/community/langchain_community/embeddings/openai.py b/libs/community/langchain_community/embeddings/openai.py
index 126fe564ae5..5bedfc747ee 100644
--- a/libs/community/langchain_community/embeddings/openai.py
+++ b/libs/community/langchain_community/embeddings/openai.py
@@ -668,7 +668,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # NOTE: to keep things simple, we assume the list may contain texts longer
         # than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
-        return self._get_len_safe_embeddings(texts, engine=engine)
+        return self._get_len_safe_embeddings(
+            texts, engine=engine, chunk_size=chunk_size
+        )
 
     async def aembed_documents(
         self, texts: List[str], chunk_size: Optional[int] = 0
@@ -686,7 +688,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # NOTE: to keep things simple, we assume the list may contain texts longer
         # than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
-        return await self._aget_len_safe_embeddings(texts, engine=engine)
+        return await self._aget_len_safe_embeddings(
+            texts, engine=engine, chunk_size=chunk_size
+        )
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
diff --git a/libs/community/tests/unit_tests/embeddings/test_openai.py b/libs/community/tests/unit_tests/embeddings/test_openai.py
index e62bee5f945..8d349e67387 100644
--- a/libs/community/tests/unit_tests/embeddings/test_openai.py
+++ b/libs/community/tests/unit_tests/embeddings/test_openai.py
@@ -1,7 +1,12 @@
+import os
+from unittest.mock import patch
+
 import pytest
 
 from langchain_community.embeddings.openai import OpenAIEmbeddings
 
+os.environ["OPENAI_API_KEY"] = "foo"
+
 
 @pytest.mark.requires("openai")
 def test_openai_invalid_model_kwargs() -> None:
@@ -14,3 +19,34 @@ def test_openai_incorrect_field() -> None:
     with pytest.warns(match="not default parameter"):
         llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo")  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": "bar"}
+
+
+@pytest.mark.requires("openai")
+def test_embed_documents_with_custom_chunk_size() -> None:
+    embeddings = OpenAIEmbeddings(chunk_size=2)
+    texts = ["text1", "text2", "text3", "text4"]
+    custom_chunk_size = 3
+
+    with patch.object(embeddings.client, "create") as mock_create:
+        # With chunk_size=3 the first request carries 3 token lists and the
+        # second carries 1, so the mocked responses must match those sizes.
+        mock_create.side_effect = [
+            {
+                "data": [
+                    {"embedding": [0.1, 0.2]},
+                    {"embedding": [0.3, 0.4]},
+                    {"embedding": [0.5, 0.6]},
+                ]
+            },
+            {"data": [{"embedding": [0.7, 0.8]}]},
+        ]
+
+        embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
+
+        # The per-call chunk_size=3 must override the constructor's
+        # chunk_size=2: 4 texts -> one batch of 3, then one batch of 1.
+        mock_create.assert_any_call(
+            input=[[1342, 16], [1342, 17], [1342, 18]],
+            **embeddings._invocation_params,
+        )
+        mock_create.assert_any_call(
+            input=[[1342, 19]], **embeddings._invocation_params
+        )