Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-10 13:27:36 +00:00

openai: support runtime kwargs in embeddings (#31195)

parent 4f41b54bcb
commit 0b8837a0cc
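
This change lets callers pass OpenAI embedding parameters per call instead of fixing them on the constructor: anything supplied as **kwargs is merged over self._invocation_params and forwarded to client.create. A minimal usage sketch (the model name and dimensions value here are illustrative, not taken from the diff):

    from langchain_openai import OpenAIEmbeddings

    embedder = OpenAIEmbeddings(model="text-embedding-3-small")

    # Request 256-dimensional vectors for this batch only; the
    # constructor-level settings stay untouched for later calls.
    vectors = embedder.embed_documents(["hello", "world"], dimensions=256)

    # embed_query forwards kwargs the same way.
    query_vec = embedder.embed_query("hello", dimensions=256)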
@@ -447,7 +447,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     def _get_len_safe_embeddings(
-        self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
+        self,
+        texts: list[str],
+        *,
+        engine: str,
+        chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ) -> list[list[float]]:
         """
         Generate length-safe embeddings for a list of texts.
@@ -465,11 +470,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             List[List[float]]: A list of embeddings for each input text.
         """
         _chunk_size = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
         for i in _iter:
             response = self.client.create(
-                input=tokens[i : i + _chunk_size], **self._invocation_params
+                input=tokens[i : i + _chunk_size], **client_kwargs
             )
             if not isinstance(response, dict):
                 response = response.model_dump()
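
Because **kwargs is unpacked after self._invocation_params, runtime values win on key collisions, per standard dict-merge semantics:

    params = {"model": "text-embedding-3-small", "dimensions": 4}
    overrides = {"dimensions": 3}
    merged = {**params, **overrides}
    assert merged["dimensions"] == 3  # the later unpacking takes precedence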
@@ -483,9 +489,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         def empty_embedding() -> list[float]:
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
-                average_embedded = self.client.create(
-                    input="", **self._invocation_params
-                )
+                average_embedded = self.client.create(input="", **client_kwargs)
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
                 _cached_empty_embedding = average_embedded["data"][0]["embedding"]
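
Note that the empty-string fallback is also switched to client_kwargs: _cached_empty_embedding is nonlocal to _get_len_safe_embeddings, so the cache lives for a single call and the fallback embedding is computed with the same runtime kwargs as the batch requests.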
@@ -496,7 +500,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     async def _aget_len_safe_embeddings(
-        self, texts: list[str], *, engine: str, chunk_size: Optional[int] = None
+        self,
+        texts: list[str],
+        *,
+        engine: str,
+        chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ) -> list[list[float]]:
         """
         Asynchronously generate length-safe embeddings for a list of texts.
@@ -515,11 +524,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """

         _chunk_size = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
         for i in range(0, len(tokens), _chunk_size):
             response = await self.async_client.create(
-                input=tokens[i : i + _chunk_size], **self._invocation_params
+                input=tokens[i : i + _chunk_size], **client_kwargs
             )

             if not isinstance(response, dict):
@@ -535,7 +545,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
                 average_embedded = await self.async_client.create(
-                    input="", **self._invocation_params
+                    input="", **client_kwargs
                 )
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
@@ -545,7 +555,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         return [e if e is not None else await empty_embedding() for e in embeddings]

     def embed_documents(
-        self, texts: list[str], chunk_size: int | None = None
+        self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
     ) -> list[list[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.

@@ -553,16 +563,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             texts: The list of texts to embed.
             chunk_size: The chunk size of embeddings. If None, will use the chunk size
                 specified by the class.
+            kwargs: Additional keyword arguments to pass to the embedding API.

         Returns:
             List of embeddings, one for each text.
         """
         chunk_size_ = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         if not self.check_embedding_ctx_length:
             embeddings: list[list[float]] = []
             for i in range(0, len(texts), chunk_size_):
                 response = self.client.create(
-                    input=texts[i : i + chunk_size_], **self._invocation_params
+                    input=texts[i : i + chunk_size_], **client_kwargs
                 )
                 if not isinstance(response, dict):
                     response = response.model_dump()
@@ -573,11 +585,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
         return self._get_len_safe_embeddings(
-            texts, engine=engine, chunk_size=chunk_size
+            texts, engine=engine, chunk_size=chunk_size, **kwargs
         )

     async def aembed_documents(
-        self, texts: list[str], chunk_size: int | None = None
+        self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any
     ) -> list[list[float]]:
         """Call out to OpenAI's embedding endpoint async for embedding search docs.

@@ -585,16 +597,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             texts: The list of texts to embed.
             chunk_size: The chunk size of embeddings. If None, will use the chunk size
                 specified by the class.
+            kwargs: Additional keyword arguments to pass to the embedding API.

         Returns:
             List of embeddings, one for each text.
         """
         chunk_size_ = chunk_size or self.chunk_size
+        client_kwargs = {**self._invocation_params, **kwargs}
         if not self.check_embedding_ctx_length:
             embeddings: list[list[float]] = []
             for i in range(0, len(texts), chunk_size_):
                 response = await self.async_client.create(
-                    input=texts[i : i + chunk_size_], **self._invocation_params
+                    input=texts[i : i + chunk_size_], **client_kwargs
                 )
                 if not isinstance(response, dict):
                     response = response.model_dump()
@@ -605,28 +619,30 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # than the maximum context and use length-safe embedding function.
         engine = cast(str, self.deployment)
         return await self._aget_len_safe_embeddings(
-            texts, engine=engine, chunk_size=chunk_size
+            texts, engine=engine, chunk_size=chunk_size, **kwargs
         )

-    def embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.

         Args:
             text: The text to embed.
+            kwargs: Additional keyword arguments to pass to the embedding API.

         Returns:
             Embedding for the text.
         """
-        return self.embed_documents([text])[0]
+        return self.embed_documents([text], **kwargs)[0]

-    async def aembed_query(self, text: str) -> list[float]:
+    async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
         """Call out to OpenAI's embedding endpoint async for embedding query text.

         Args:
             text: The text to embed.
+            kwargs: Additional keyword arguments to pass to the embedding API.

         Returns:
             Embedding for the text.
         """
-        embeddings = await self.aembed_documents([text])
+        embeddings = await self.aembed_documents([text], **kwargs)
         return embeddings[0]
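
Since embed_query and aembed_query delegate to embed_documents / aembed_documents with a single-element list, runtime kwargs flow through one code path for all four public methods, e.g. (illustrative):

    vec = await embedder.aembed_query("hello", dimensions=256)

The remaining hunk adds unit tests for that path.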
@@ -57,3 +57,42 @@ def test_embed_documents_with_custom_chunk_size_no_check_ctx_length() -> None:
     mock_create.assert_any_call(input=texts[3:4], **embeddings._invocation_params)

     assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
+
+
+def test_embed_with_kwargs() -> None:
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small", check_embedding_ctx_length=False
+    )
+    texts = ["text1", "text2"]
+    with patch.object(embeddings.client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
+        ]
+
+        result = embeddings.embed_documents(texts, dimensions=3)
+        mock_create.assert_any_call(
+            input=texts, dimensions=3, **embeddings._invocation_params
+        )
+
+    assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+
+async def test_embed_with_kwargs_async() -> None:
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small",
+        check_embedding_ctx_length=False,
+        dimensions=4,  # also check that runtime kwargs take precedence
+    )
+    texts = ["text1", "text2"]
+    with patch.object(embeddings.async_client, "create") as mock_create:
+        mock_create.side_effect = [
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]}
+        ]
+
+        result = await embeddings.aembed_documents(texts, dimensions=3)
+        client_kwargs = embeddings._invocation_params.copy()
+        assert client_kwargs["dimensions"] == 4
+        client_kwargs["dimensions"] = 3
+        mock_create.assert_any_call(input=texts, **client_kwargs)
+
+    assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
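
The async test also pins down precedence: the constructor sets dimensions=4, the call passes dimensions=3, and the assertion confirms the request went out with dimensions=3, i.e. runtime kwargs override constructor-level invocation params.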