From ed26149a29671f58971b5cc964a0b0908f428e21 Mon Sep 17 00:00:00 2001 From: YISH Date: Fri, 26 Apr 2024 00:45:52 +0800 Subject: [PATCH] openai[patch]: Allow disabling safe_len_embeddings(OpenAIEmbeddings) (#19743) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit OpenAI API compatible server may not support `safe_len_embedding`; set `check_embedding_ctx_length=False` to disable it. --------- Co-authored-by: Bagatur --- .../langchain_openai/embeddings/base.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index b64388055bd..8ba966c47e2 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -129,6 +129,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): http_async_client: Union[Any, None] = None """Optional httpx.AsyncClient. Only used for async invocations. Must specify http_client as well if you'd like a custom client for sync invocations.""" + check_embedding_ctx_length: bool = True + """Whether to check the token length of inputs and automatically split inputs + longer than embedding_ctx_length.""" class Config: """Configuration for this pydantic object.""" @@ -511,6 +514,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + if not self.check_embedding_ctx_length: + embeddings: List[List[float]] = [] + for text in texts: + response = self.client.create( + input=text, + **self._invocation_params, + ) + if not isinstance(response, dict): + response = response.dict() + embeddings.extend(r["embedding"] for r in response["data"]) + return embeddings + # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. 
engine = cast(str, self.deployment) @@ -529,6 +544,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + if not self.check_embedding_ctx_length: + embeddings: List[List[float]] = [] + for text in texts: + response = await self.async_client.create( + input=text, + **self._invocation_params, + ) + if not isinstance(response, dict): + response = response.dict() + embeddings.extend(r["embedding"] for r in response["data"]) + return embeddings + # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. engine = cast(str, self.deployment)