mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
partners/openai: OpenAIEmbeddings not respecting chunk_size argument (#30757)
When calling `embed_documents` and providing a `chunk_size` argument, that argument is ignored when `OpenAIEmbeddings` is instantiated with its default configuration (where `check_embedding_ctx_length=True`). `_get_len_safe_embeddings` specifies a `chunk_size` parameter but it's not being passed through in `embed_documents`, which is its only caller. This appears to be an oversight, especially given that the `_get_len_safe_embeddings` docstring states it should respect "the set embedding context length and chunk size." Developers typically expect method parameters to take effect (also, take precedence) when explicitly provided, especially when instantiating using defaults. I was confused as to why my API calls were being rejected regardless of the chunk size I provided. This bug also exists in langchain_community package. I can add that to this PR if requested otherwise I will create a new one once this passes.
This commit is contained in:
parent
017c8079e1
commit
b344f34635
@ -573,7 +573,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
engine = cast(str, self.deployment)
|
||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
||||
return self._get_len_safe_embeddings(
|
||||
texts, engine=engine, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: list[str], chunk_size: int | None = None
|
||||
@ -603,7 +605,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
engine = cast(str, self.deployment)
|
||||
return await self._aget_len_safe_embeddings(texts, engine=engine)
|
||||
return await self._aget_len_safe_embeddings(
|
||||
texts, engine=engine, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
@ -20,6 +20,26 @@ def test_openai_incorrect_field() -> None:
|
||||
|
||||
|
||||
def test_embed_documents_with_custom_chunk_size() -> None:
    """A ``chunk_size`` passed to ``embed_documents`` overrides the instance default.

    The embeddings object is constructed with ``chunk_size=2``, but the call
    supplies ``chunk_size=3``; the mocked client must therefore be invoked with
    batches of 3 and 1 tokens (not 2+2), proving the per-call argument is
    forwarded to ``_get_len_safe_embeddings``.
    """
    embeddings = OpenAIEmbeddings(chunk_size=2)
    texts = ["text1", "text2", "text3", "text4"]
    custom_chunk_size = 3

    with patch.object(embeddings.client, "create") as mock_create:
        # Two batched API responses covering all four input texts.
        mock_create.side_effect = [
            {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
            {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
        ]

        result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
        # Re-tokenize with the same chunk size to compute the expected batches.
        _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
        # Batches must follow the overriding chunk_size (3), not the default (2).
        mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
        mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)

        assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
|
||||
|
||||
|
||||
def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
|
||||
embeddings = OpenAIEmbeddings(chunk_size=2, check_embedding_ctx_length=False)
|
||||
texts = ["text1", "text2", "text3", "text4"]
|
||||
custom_chunk_size = 3
|
||||
|
Loading…
Reference in New Issue
Block a user