mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
langchain_community: OpenAIEmbeddings not respecting chunk_size argument (#30946)
This is a follow-on PR to go with the identical changes that were made in partners/openai. Previous PR: https://github.com/langchain-ai/langchain/pull/30757 When calling embed_documents and providing a chunk_size argument, that argument is ignored when OpenAIEmbeddings is instantiated with its default configuration (where check_embedding_ctx_length=True). _get_len_safe_embeddings specifies a chunk_size parameter but it's not being passed through in embed_documents, which is its only caller. This appears to be an oversight, especially given that the _get_len_safe_embeddings docstring states it should respect "the set embedding context length and chunk size." Developers typically expect method parameters to take effect (also, take precedence) when explicitly provided, especially when instantiating using defaults. I was confused as to why my API calls were being rejected regardless of the chunk size I provided.
This commit is contained in:
parent
b344f34635
commit
23f701b08e
@ -668,7 +668,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||||
# than the maximum context and use length-safe embedding function.
|
# than the maximum context and use length-safe embedding function.
|
||||||
engine = cast(str, self.deployment)
|
engine = cast(str, self.deployment)
|
||||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
return self._get_len_safe_embeddings(
|
||||||
|
texts, engine=engine, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
async def aembed_documents(
|
async def aembed_documents(
|
||||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||||
@ -686,7 +688,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||||
# than the maximum context and use length-safe embedding function.
|
# than the maximum context and use length-safe embedding function.
|
||||||
engine = cast(str, self.deployment)
|
engine = cast(str, self.deployment)
|
||||||
return await self._aget_len_safe_embeddings(texts, engine=engine)
|
return await self._aget_len_safe_embeddings(
|
||||||
|
texts, engine=engine, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_openai_invalid_model_kwargs() -> None:
|
def test_openai_invalid_model_kwargs() -> None:
|
||||||
@ -14,3 +19,20 @@ def test_openai_incorrect_field() -> None:
|
|||||||
with pytest.warns(match="not default parameter"):
|
with pytest.warns(match="not default parameter"):
|
||||||
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
|
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
|
||||||
assert llm.model_kwargs == {"foo": "bar"}
|
assert llm.model_kwargs == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_embed_documents_with_custom_chunk_size() -> None:
|
||||||
|
embeddings = OpenAIEmbeddings(chunk_size=2)
|
||||||
|
texts = ["text1", "text2", "text3", "text4"]
|
||||||
|
custom_chunk_size = 3
|
||||||
|
|
||||||
|
with patch.object(embeddings.client, "create") as mock_create:
|
||||||
|
mock_create.side_effect = [
|
||||||
|
{"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
|
||||||
|
{"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
|
||||||
|
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
|
||||||
|
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
|
||||||
|
Loading…
Reference in New Issue
Block a user