Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-13 22:59:05 +00:00
openai: support runtime kwargs in embeddings (#31195)
This commit is contained in:
parent 4f41b54bcb
commit 0b8837a0cc
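
What the change enables, as a minimal usage sketch (not part of the diff; the model name and the dimensions value are illustrative, but the call shape matches the new unit tests below):

    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

    # Call-time kwargs are now forwarded to the underlying client.create(...)
    # requests, overriding constructor-level invocation params on conflict.
    doc_vectors = embeddings.embed_documents(["text1", "text2"], dimensions=256)
    query_vector = embeddings.embed_query("text1", dimensions=256)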
@@ -447,7 +447,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     def _get_len_safe_embeddings(
-        self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
+        self,
+        texts: list[str],
+        *,
+        engine: str,
+        chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ) -> list[list[float]]:
         """
         Generate length-safe embeddings for a list of texts.
@@ -465,11 +470,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             List[List[float]]: A list of embeddings for each input text.
         """
         _chunk_size = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
         for i in _iter:
             response = self.client.create(
-                input=tokens[i : i + _chunk_size], **self._invocation_params
+                input=tokens[i : i + _chunk_size], **client_kwargs
             )
             if not isinstance(response, dict):
                 response = response.model_dump()
@@ -483,9 +489,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         def empty_embedding() -> list[float]:
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
-                average_embedded = self.client.create(
-                    input="", **self._invocation_params
-                )
+                average_embedded = self.client.create(input="", **client_kwargs)
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
                 _cached_empty_embedding = average_embedded["data"][0]["embedding"]
@@ -496,7 +500,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     async def _aget_len_safe_embeddings(
-        self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
+        self,
+        texts: list[str],
+        *,
+        engine: str,
+        chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ) -> list[list[float]]:
         """
         Asynchronously generate length-safe embeddings for a list of texts.
@@ -515,11 +524,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """
 
         _chunk_size = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
         for i in range(0, len(tokens), _chunk_size):
             response = await self.async_client.create(
-                input=tokens[i : i + _chunk_size], **self._invocation_params
+                input=tokens[i : i + _chunk_size], **client_kwargs
             )
 
             if not isinstance(response, dict):
@@ -535,7 +545,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
                 average_embedded = await self.async_client.create(
-                    input="", **self._invocation_params
+                    input="", **client_kwargs
                 )
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
@@ -545,7 +555,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         return [e if e is not None else await empty_embedding() for e in embeddings]
 
     def embed_documents(
-        self, texts: list[str], chunk_size: int | None = None
+        self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
     ) -> list[list[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.
 
@@ -553,16 +563,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             texts: The list of texts to embed.
             chunk_size: The chunk size of embeddings. If None, will use the chunk size
                 specified by the class.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             List of embeddings, one for each text.
         """
         chunk_size_ = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         if not self.check_embedding_ctx_length:
             embeddings: list[list[float]] = []
             for i in range(0, len(texts), chunk_size_):
                 response = self.client.create(
-                    input=texts[i : i + chunk_size_], **self._invocation_params
+                    input=texts[i : i + chunk_size_], **client_kwargs
                 )
                 if not isinstance(response, dict):
                     response = response.model_dump()
@@ -573,11 +585,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         #       than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
         return self._get_len_safe_embeddings(
-            texts, engine=engine, chunk_size=chunk_size
+            texts, engine=engine, chunk_size=chunk_size, **kwargs
         )
 
     async def aembed_documents(
-        self, texts: list[str], chunk_size: int | None = None
+        self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
     ) -> list[list[float]]:
         """Call out to OpenAI's embedding endpoint async for embedding search docs.
 
@@ -585,16 +597,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             texts: The list of texts to embed.
             chunk_size: The chunk size of embeddings. If None, will use the chunk size
                 specified by the class.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             List of embeddings, one for each text.
         """
         chunk_size_ = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         if not self.check_embedding_ctx_length:
             embeddings: list[list[float]] = []
             for i in range(0, len(texts), chunk_size_):
                 response = await self.async_client.create(
-                    input=texts[i : i + chunk_size_], **self._invocation_params
+                    input=texts[i : i + chunk_size_], **client_kwargs
                 )
                 if not isinstance(response, dict):
                     response = response.model_dump()
@@ -605,28 +619,30 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         #       than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
         return await self._aget_len_safe_embeddings(
-            texts, engine=engine, chunk_size=chunk_size
+            texts, engine=engine, chunk_size=chunk_size, **kwargs
         )
 
-    def embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
 
         Args:
             text: The text to embed.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             Embedding for the text.
         """
-        return self.embed_documents([text])[0]
+        return self.embed_documents([text], **kwargs)[0]
 
-    async def aembed_query(self, text: str) -> list[float]:
+    async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Call out to OpenAI's embedding endpoint async for embedding query text.
 
         Args:
             text: The text to embed.
+            kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
             Embedding for the text.
         """
-        embeddings = await self.aembed_documents([text])
+        embeddings = await self.aembed_documents([text], **kwargs)
         return embeddings[0]
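
Note on the merge order used throughout: client_kwargs = {**self._invocation_params, **kwargs} builds the dict left to right, so call-time kwargs win on key conflicts. The async test in the test-file hunk below pins down exactly that precedence. A minimal sketch of the semantics (values are illustrative):

    invocation_params = {"model": "text-embedding-3-small", "dimensions": 4}
    runtime_kwargs = {"dimensions": 3}
    client_kwargs = {**invocation_params, **runtime_kwargs}
    assert client_kwargs["dimensions"] == 3  # runtime value overrides the constructor's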
@@ -57,3 +57,42 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
     mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)
 
     assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
+
+
+def test_embed_with_kwargs() -> None:
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small", check_embedding_ctx_length=False
+    )
+    texts = ["text1", "text2"]
+    with patch.object(embeddings.client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
+        ]
+
+        result = embeddings.embed_documents(texts, dimensions=3)
+        mock_create.assert_any_call(
+            input=texts, dimensions=3, **embeddings._invocation_params
+        )
+
+    assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+
+async def test_embed_with_kwargs_async() -> None:
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small",
+        check_embedding_ctx_length=False,
+        dimensions=4,  # also check that runtime kwargs take precedence
+    )
+    texts = ["text1", "text2"]
+    with patch.object(embeddings.async_client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
+        ]
+
+        result = await embeddings.aembed_documents(texts, dimensions=3)
+        client_kwargs = embeddings._invocation_params.copy()
+        assert client_kwargs["dimensions"] == 4
+        client_kwargs["dimensions"] = 3
+        mock_create.assert_any_call(input=texts, **client_kwargs)
+
+    assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]