openai: support runtime kwargs in embeddings (#31195)

This commit is contained in:
ccurme 2025-05-14 09:14:40 -04:00 committed by GitHub
parent 4f41b54bcb
commit 0b8837a0cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 18 deletions

View File

@ -447,7 +447,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings(
self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
self,
texts: list[str],
*,
engine: str,
chunk_size: Optional[int] = None,
**kwargs: Any,
) -> list[list[float]]:
"""
Generate length-safe embeddings for a list of texts.
@ -465,11 +470,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
List[List[float]]: A list of embeddings for each input text.
"""
_chunk_size = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: list[list[float]] = []
for i in _iter:
response = self.client.create(
input=tokens[i : i + _chunk_size], **self._invocation_params
input=tokens[i : i + _chunk_size], **client_kwargs
)
if not isinstance(response, dict):
response = response.model_dump()
@ -483,9 +489,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def empty_embedding() -> list[float]:
nonlocal _cached_empty_embedding
if _cached_empty_embedding is None:
average_embedded = self.client.create(
input="", **self._invocation_params
)
average_embedded = self.client.create(input="", **client_kwargs)
if not isinstance(average_embedded, dict):
average_embedded = average_embedded.model_dump()
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
@ -496,7 +500,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
async def _aget_len_safe_embeddings(
self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
self,
texts: list[str],
*,
engine: str,
chunk_size: Optional[int] = None,
**kwargs: Any,
) -> list[list[float]]:
"""
Asynchronously generate length-safe embeddings for a list of texts.
@ -515,11 +524,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
_chunk_size = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: list[list[float]] = []
for i in range(0, len(tokens), _chunk_size):
response = await self.async_client.create(
input=tokens[i : i + _chunk_size], **self._invocation_params
input=tokens[i : i + _chunk_size], **client_kwargs
)
if not isinstance(response, dict):
@ -535,7 +545,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
nonlocal _cached_empty_embedding
if _cached_empty_embedding is None:
average_embedded = await self.async_client.create(
input="", **self._invocation_params
input="", **client_kwargs
)
if not isinstance(average_embedded, dict):
average_embedded = average_embedded.model_dump()
@ -545,7 +555,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return [e if e is not None else await empty_embedding() for e in embeddings]
def embed_documents(
self, texts: list[str], chunk_size: int | None = None
self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
@ -553,16 +563,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
kwargs: Additional keyword arguments to pass to the embedding API.
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
if not self.check_embedding_ctx_length:
embeddings: list[list[float]] = []
for i in range(0, len(texts), chunk_size_):
response = self.client.create(
input=texts[i : i + chunk_size_], **self._invocation_params
input=texts[i : i + chunk_size_], **client_kwargs
)
if not isinstance(response, dict):
response = response.model_dump()
@ -573,11 +585,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
texts, engine=engine, chunk_size=chunk_size, **kwargs
)
async def aembed_documents(
self, texts: list[str], chunk_size: int | None = None
self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
) -> list[list[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
@ -585,16 +597,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
kwargs: Additional keyword arguments to pass to the embedding API.
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
if not self.check_embedding_ctx_length:
embeddings: list[list[float]] = []
for i in range(0, len(texts), chunk_size_):
response = await self.async_client.create(
input=texts[i : i + chunk_size_], **self._invocation_params
input=texts[i : i + chunk_size_], **client_kwargs
)
if not isinstance(response, dict):
response = response.model_dump()
@ -605,28 +619,30 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
texts, engine=engine, chunk_size=chunk_size, **kwargs
)
def embed_query(self, text: str) -> list[float]:
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.
Args:
text: The text to embed.
kwargs: Additional keyword arguments to pass to the embedding API.
Returns:
Embedding for the text.
"""
return self.embed_documents([text])[0]
return self.embed_documents([text], **kwargs)[0]
async def aembed_query(self, text: str) -> list[float]:
async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
"""Call out to OpenAI's embedding endpoint async for embedding query text.
Args:
text: The text to embed.
kwargs: Additional keyword arguments to pass to the embedding API.
Returns:
Embedding for the text.
"""
embeddings = await self.aembed_documents([text])
embeddings = await self.aembed_documents([text], **kwargs)
return embeddings[0]

View File

@ -57,3 +57,42 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)
assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
def test_embed_with_kwargs() -> None:
embeddings = OpenAIEmbeddings(
model="text-embedding-3-small", check_embedding_ctx_length=False
)
texts = ["text1", "text2"]
with patch.object(embeddings.client, "create") as mock_create:
mock_create.side_effect = [
{"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
]
result = embeddings.embed_documents(texts, dimensions=3)
mock_create.assert_any_call(
input=texts, dimensions=3, **embeddings._invocation_params
)
assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
async def test_embed_with_kwargs_async() -> None:
embeddings = OpenAIEmbeddings(
model="text-embedding-3-small",
check_embedding_ctx_length=False,
dimensions=4, # also check that runtime kwargs take precedence
)
texts = ["text1", "text2"]
with patch.object(embeddings.async_client, "create") as mock_create:
mock_create.side_effect = [
{"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
]
result = await embeddings.aembed_documents(texts, dimensions=3)
client_kwargs = embeddings._invocation_params.copy()
assert client_kwargs["dimensions"] == 4
client_kwargs["dimensions"] = 3
mock_create.assert_any_call(input=texts, **client_kwargs)
assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]