openai: embeddings: supported chunk_size when check_embedding_ctx_length is disabled (#23767)

Chunking of the input array controlled by `self.chunk_size` is being
ignored when `self.check_embedding_ctx_length` is disabled. Effectively,
the chunk size is assumed to be equal 1 in such a case. This is
suprising.

The PR takes into account `self.chunk_size` passed by the user.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Anton Dubovik
2024-09-21 01:58:45 +02:00
committed by GitHub
parent 864020e592
commit 3e2cb4e8a4
2 changed files with 22 additions and 9 deletions

View File

@@ -254,14 +254,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
retry_max_seconds: int = 20
"""Max number of seconds to wait between retries"""
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
check_embedding_ctx_length: bool = True
"""Whether to check the token length of inputs and automatically split inputs
"""Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length."""
model_config = ConfigDict(
@@ -558,7 +558,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return [e if e is not None else await empty_embedding() for e in embeddings]
def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
@@ -570,10 +570,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
response = self.client.create(input=text, **self._invocation_params)
for i in range(0, len(texts), self.chunk_size):
response = self.client.create(
input=texts[i : i + chunk_size_], **self._invocation_params
)
if not isinstance(response, dict):
response = response.dict()
embeddings.extend(r["embedding"] for r in response["data"])
@@ -585,7 +588,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return self._get_len_safe_embeddings(texts, engine=engine)
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
@@ -597,11 +600,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
for i in range(0, len(texts), chunk_size_):
response = await self.async_client.create(
input=text, **self._invocation_params
input=texts[i : i + chunk_size_], **self._invocation_params
)
if not isinstance(response, dict):
response = response.dict()