diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 252996dc8b3..4e2f7c74cfb 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -573,7 +573,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): chunk_size_ = chunk_size or self.chunk_size if not self.check_embedding_ctx_length: embeddings: List[List[float]] = [] - for i in range(0, len(texts), self.chunk_size): + for i in range(0, len(texts), chunk_size_): response = self.client.create( input=texts[i : i + chunk_size_], **self._invocation_params ) diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py index 3bb5a61e4a9..d464b9cbb5b 100644 --- a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py @@ -1,4 +1,5 @@ import os +from unittest.mock import patch import pytest @@ -16,3 +17,23 @@ def test_openai_incorrect_field() -> None: with pytest.warns(match="not default parameter"): llm = OpenAIEmbeddings(foo="bar") # type: ignore[call-arg] assert llm.model_kwargs == {"foo": "bar"} + + +def test_embed_documents_with_custom_chunk_size() -> None: + embeddings = OpenAIEmbeddings(chunk_size=2, check_embedding_ctx_length=False) + texts = ["text1", "text2", "text3", "text4"] + custom_chunk_size = 3 + + with patch.object(embeddings.client, "create") as mock_create: + mock_create.side_effect = [ + {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}, + {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]}, + ] + + result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size) + + mock_create.call_args + mock_create.assert_any_call(input=texts[0:3], **embeddings._invocation_params) + mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params) + + assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]