mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 01:37:59 +00:00
openai: embeddings: supported chunk_size when check_embedding_ctx_length is disabled (#23767)
Chunking of the input array controlled by `self.chunk_size` is being ignored when `self.check_embedding_ctx_length` is disabled. Effectively, the chunk size is assumed to be equal 1 in such a case. This is suprising. The PR takes into account `self.chunk_size` passed by the user. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
864020e592
commit
3e2cb4e8a4
@ -254,14 +254,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
retry_max_seconds: int = 20
|
||||
"""Max number of seconds to wait between retries"""
|
||||
http_client: Union[Any, None] = None
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
http_async_client as well if you'd like a custom client for async invocations.
|
||||
"""
|
||||
http_async_client: Union[Any, None] = None
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
check_embedding_ctx_length: bool = True
|
||||
"""Whether to check the token length of inputs and automatically split inputs
|
||||
"""Whether to check the token length of inputs and automatically split inputs
|
||||
longer than embedding_ctx_length."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@ -558,7 +558,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
return [e if e is not None else await empty_embedding() for e in embeddings]
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
self, texts: List[str], chunk_size: int | None = None
|
||||
) -> List[List[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
|
||||
@ -570,10 +570,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
chunk_size_ = chunk_size or self.chunk_size
|
||||
if not self.check_embedding_ctx_length:
|
||||
embeddings: List[List[float]] = []
|
||||
for text in texts:
|
||||
response = self.client.create(input=text, **self._invocation_params)
|
||||
for i in range(0, len(texts), self.chunk_size):
|
||||
response = self.client.create(
|
||||
input=texts[i : i + chunk_size_], **self._invocation_params
|
||||
)
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
embeddings.extend(r["embedding"] for r in response["data"])
|
||||
@ -585,7 +588,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
self, texts: List[str], chunk_size: int | None = None
|
||||
) -> List[List[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
|
||||
|
||||
@ -597,11 +600,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
chunk_size_ = chunk_size or self.chunk_size
|
||||
if not self.check_embedding_ctx_length:
|
||||
embeddings: List[List[float]] = []
|
||||
for text in texts:
|
||||
for i in range(0, len(texts), chunk_size_):
|
||||
response = await self.async_client.create(
|
||||
input=text, **self._invocation_params
|
||||
input=texts[i : i + chunk_size_], **self._invocation_params
|
||||
)
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
|
@ -61,3 +61,12 @@ async def test_langchain_openai_embeddings_equivalent_to_raw_async() -> None:
|
||||
.embedding
|
||||
)
|
||||
assert np.isclose(lc_output, direct_output).all()
|
||||
|
||||
|
||||
def test_langchain_openai_embeddings_dimensions_large_num() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = [f"foo bar {i}" for i in range(2000)]
|
||||
embedding = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=128)
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 2000
|
||||
assert len(output[0]) == 128
|
||||
|
Loading…
Reference in New Issue
Block a user