From 0b8837a0cc0b906837092d7df59d6fa13756416d Mon Sep 17 00:00:00 2001
From: ccurme
Date: Wed, 14 May 2025 09:14:40 -0400
Subject: [PATCH] openai: support runtime kwargs in embeddings (#31195)

---
 .../langchain_openai/embeddings/base.py       | 52 ++++++++++++-------
 .../tests/unit_tests/embeddings/test_base.py  | 39 ++++++++++++++
 2 files changed, 73 insertions(+), 18 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py
index 3a35eebab9d..e9e688983fd 100644
--- a/libs/partners/openai/langchain_openai/embeddings/base.py
+++ b/libs/partners/openai/langchain_openai/embeddings/base.py
@@ -447,7 +447,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     def _get_len_safe_embeddings(
-        self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
+        self,
+        texts: list[str],
+        *,
+        engine: str,
+        chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ) -> list[list[float]]:
         """
         Generate length-safe embeddings for a list of texts.
@@ -465,11 +470,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             List[List[float]]: A list of embeddings for each input text.
         """
         _chunk_size = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
         for i in _iter:
             response = self.client.create(
-                input=tokens[i : i + _chunk_size], **self._invocation_params
+                input=tokens[i : i + _chunk_size], **client_kwargs
             )
             if not isinstance(response, dict):
                 response = response.model_dump()
@@ -483,9 +489,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         def empty_embedding() -> list[float]:
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
-                average_embedded = self.client.create(
-                    input="", **self._invocation_params
-                )
+                average_embedded = self.client.create(input="", **client_kwargs)
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
                 _cached_empty_embedding = average_embedded["data"][0]["embedding"]
@@ -496,7 +500,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     async def _aget_len_safe_embeddings(
-        self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
+        self,
+        texts: list[str],
+        *,
+        engine: str,
+        chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ) -> list[list[float]]:
         """
         Asynchronously generate length-safe embeddings for a list of texts.
@@ -515,11 +524,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """
 
         _chunk_size = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
         for i in range(0, len(tokens), _chunk_size):
             response = await self.async_client.create(
-                input=tokens[i : i + _chunk_size], **self._invocation_params
+                input=tokens[i : i + _chunk_size], **client_kwargs
             )
 
             if not isinstance(response, dict):
@@ -535,7 +545,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
                 average_embedded = await self.async_client.create(
-                    input="", **self._invocation_params
+                    input="", **client_kwargs
                 )
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
@@ -545,7 +555,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         return [e if e is not None else await empty_embedding() for e in embeddings]
 
     def embed_documents(
-        self, texts: list[str], chunk_size: int | None = None
+        self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
     ) -> list[list[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.
 
@@ -553,16 +563,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             texts: The list of texts to embed.
             chunk_size: The chunk size of embeddings. If None, will use the chunk size
                 specified by the class.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             List of embeddings, one for each text.
         """
         chunk_size_ = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         if not self.check_embedding_ctx_length:
             embeddings: list[list[float]] = []
             for i in range(0, len(texts), chunk_size_):
                 response = self.client.create(
-                    input=texts[i : i + chunk_size_], **self._invocation_params
+                    input=texts[i : i + chunk_size_], **client_kwargs
                 )
                 if not isinstance(response, dict):
                     response = response.model_dump()
@@ -573,11 +585,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
         return self._get_len_safe_embeddings(
-            texts, engine=engine, chunk_size=chunk_size
+            texts, engine=engine, chunk_size=chunk_size, **kwargs
         )
 
     async def aembed_documents(
-        self, texts: list[str], chunk_size: int | None = None
+        self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
     ) -> list[list[float]]:
         """Call out to OpenAI's embedding endpoint async for embedding search docs.
 
@@ -585,16 +597,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             texts: The list of texts to embed.
             chunk_size: The chunk size of embeddings. If None, will use the chunk size
                 specified by the class.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             List of embeddings, one for each text.
         """
         chunk_size_ = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         if not self.check_embedding_ctx_length:
             embeddings: list[list[float]] = []
             for i in range(0, len(texts), chunk_size_):
                 response = await self.async_client.create(
-                    input=texts[i : i + chunk_size_], **self._invocation_params
+                    input=texts[i : i + chunk_size_], **client_kwargs
                 )
                 if not isinstance(response, dict):
                     response = response.model_dump()
@@ -605,28 +619,30 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
         return await self._aget_len_safe_embeddings(
-            texts, engine=engine, chunk_size=chunk_size
+            texts, engine=engine, chunk_size=chunk_size, **kwargs
         )
 
-    def embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
 
         Args:
             text: The text to embed.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             Embedding for the text.
         """
-        return self.embed_documents([text])[0]
+        return self.embed_documents([text], **kwargs)[0]
 
-    async def aembed_query(self, text: str) -> list[float]:
+    async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
         """Call out to OpenAI's embedding endpoint async for embedding query text.
 
         Args:
             text: The text to embed.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             Embedding for the text.
         """
-        embeddings = await self.aembed_documents([text])
+        embeddings = await self.aembed_documents([text], **kwargs)
         return embeddings[0]
diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
index 7714b89b12a..f87dc181a37 100644
--- a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
@@ -57,3 +57,42 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
         mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)
 
     assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
+
+
+def test_embed_with_kwargs() -> None:
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small", check_embedding_ctx_length=False
+    )
+    texts = ["text1", "text2"]
+    with patch.object(embeddings.client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
+        ]
+
+        result = embeddings.embed_documents(texts, dimensions=3)
+        mock_create.assert_any_call(
+            input=texts, dimensions=3, **embeddings._invocation_params
+        )
+
+    assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+
+async def test_embed_with_kwargs_async() -> None:
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small",
+        check_embedding_ctx_length=False,
+        dimensions=4,  # also check that runtime kwargs take precedence
+    )
+    texts = ["text1", "text2"]
+    with patch.object(embeddings.async_client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
+        ]
+
+        result = await embeddings.aembed_documents(texts, dimensions=3)
+        client_kwargs = embeddings._invocation_params.copy()
+        assert client_kwargs["dimensions"] == 4
+        client_kwargs["dimensions"] = 3
+        mock_create.assert_any_call(input=texts, **client_kwargs)
+
+    assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]