Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-16 09:48:04 +00:00
openai: embeddings: supported chunk_size when check_embedding_ctx_length is disabled (#23767)
Chunking of the input array, controlled by `self.chunk_size`, was being ignored when `self.check_embedding_ctx_length` is disabled: the chunk size was effectively assumed to equal 1 in that case, which is surprising. This PR takes the user-supplied `self.chunk_size` into account.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent 864020e592
commit 3e2cb4e8a4
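The core of the change is the slicing pattern visible in the hunks below. As a minimal standalone sketch of that batching (no API calls; `batch` is a hypothetical helper introduced here only for illustration):

# Sketch of the batching this PR introduces when check_embedding_ctx_length
# is disabled: slice the input list into chunks of `chunk_size` instead of
# sending one request per text.
from typing import List

def batch(texts: List[str], chunk_size: int) -> List[List[str]]:
    return [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]

assert batch(["a", "b", "c", "d", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]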
@@ -254,14 +254,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     retry_max_seconds: int = 20
     """Max number of seconds to wait between retries"""
     http_client: Union[Any, None] = None
     """Optional httpx.Client. Only used for sync invocations. Must specify
     http_async_client as well if you'd like a custom client for async invocations.
     """
     http_async_client: Union[Any, None] = None
     """Optional httpx.AsyncClient. Only used for async invocations. Must specify
     http_client as well if you'd like a custom client for sync invocations."""
     check_embedding_ctx_length: bool = True
     """Whether to check the token length of inputs and automatically split inputs
     longer than embedding_ctx_length."""

     model_config = ConfigDict(
@@ -558,7 +558,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         return [e if e is not None else await empty_embedding() for e in embeddings]

     def embed_documents(
-        self, texts: List[str], chunk_size: Optional[int] = 0
+        self, texts: List[str], chunk_size: int | None = None
     ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.

@@ -570,10 +570,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
+        chunk_size_ = chunk_size or self.chunk_size
         if not self.check_embedding_ctx_length:
             embeddings: List[List[float]] = []
-            for text in texts:
-                response = self.client.create(input=text, **self._invocation_params)
+            for i in range(0, len(texts), self.chunk_size):
+                response = self.client.create(
+                    input=texts[i : i + chunk_size_], **self._invocation_params
+                )
                 if not isinstance(response, dict):
                     response = response.dict()
                 embeddings.extend(r["embedding"] for r in response["data"])
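Note that in the sync hunk above the loop steps by `self.chunk_size` while the slice width is `chunk_size_`; the two only coincide when no per-call override is supplied (the async hunk below uses `chunk_size_` for both). A standalone arithmetic check of what diverging values would do, with hypothetical numbers not taken from the PR:

# If the step and the slice width disagree, batches skip elements
# (step > width) or overlap (step < width).
texts = ["t0", "t1", "t2", "t3"]
step, width = 2, 1  # e.g. instance chunk_size vs. a smaller per-call override
batches = [texts[i : i + width] for i in range(0, len(texts), step)]
assert batches == [["t0"], ["t2"]]  # "t1" and "t3" would never be embedded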
@@ -585,7 +588,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         return self._get_len_safe_embeddings(texts, engine=engine)

     async def aembed_documents(
-        self, texts: List[str], chunk_size: Optional[int] = 0
+        self, texts: List[str], chunk_size: int | None = None
     ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint async for embedding search docs.

@@ -597,11 +600,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
+        chunk_size_ = chunk_size or self.chunk_size
         if not self.check_embedding_ctx_length:
             embeddings: List[List[float]] = []
-            for text in texts:
+            for i in range(0, len(texts), chunk_size_):
                 response = await self.async_client.create(
-                    input=text, **self._invocation_params
+                    input=texts[i : i + chunk_size_], **self._invocation_params
                 )
                 if not isinstance(response, dict):
                     response = response.dict()
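In both methods the fallback is `chunk_size_ = chunk_size or self.chunk_size`, so any falsy per-call value, including the old default of `0`, resolves to the instance-level setting; that is why moving the signature default from `0` to `None` preserves behavior. A quick illustration (the instance value here is assumed for the example):

instance_chunk_size = 1000  # e.g. the value configured on the instance

assert (None or instance_chunk_size) == 1000  # new default falls back
assert (0 or instance_chunk_size) == 1000     # old default fell back the same way
assert (256 or instance_chunk_size) == 256    # an explicit override wins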
@@ -61,3 +61,12 @@ async def test_langchain_openai_embeddings_equivalent_to_raw_async() -> None:
         .embedding
     )
     assert np.isclose(lc_output, direct_output).all()
+
+
+def test_langchain_openai_embeddings_dimensions_large_num() -> None:
+    """Test openai embeddings."""
+    documents = [f"foo bar {i}" for i in range(2000)]
+    embedding = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=128)
+    output = embedding.embed_documents(documents)
+    assert len(output) == 2000
+    assert len(output[0]) == 128
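End to end, the behavior this PR fixes can be exercised as in the following sketch, assuming the standard langchain_openai package and an OPENAI_API_KEY in the environment; the request counts follow from the slicing shown above:

from langchain_openai import OpenAIEmbeddings

# With check_embedding_ctx_length=False, requests are now batched by
# chunk_size instead of one request per input text.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    check_embedding_ctx_length=False,
    chunk_size=500,
)
docs = [f"doc {i}" for i in range(1200)]
vectors = embeddings.embed_documents(docs)  # 3 requests: 500 + 500 + 200 texts
assert len(vectors) == 1200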
|
Loading…
Reference in New Issue
Block a user