diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index b64388055bd..8ba966c47e2 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -129,6 +129,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): http_async_client: Union[Any, None] = None """Optional httpx.AsyncClient. Only used for async invocations. Must specify http_client as well if you'd like a custom client for sync invocations.""" + check_embedding_ctx_length: bool = True + """Whether to check the token length of inputs and automatically split inputs + longer than embedding_ctx_length.""" class Config: """Configuration for this pydantic object.""" @@ -511,6 +514,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + if not self.check_embedding_ctx_length: + embeddings: List[List[float]] = [] + for text in texts: + response = self.client.create( + input=text, + **self._invocation_params, + ) + if not isinstance(response, dict): + response = response.dict() + embeddings.extend(r["embedding"] for r in response["data"]) + return embeddings + # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. engine = cast(str, self.deployment) @@ -529,6 +544,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + if not self.check_embedding_ctx_length: + embeddings: List[List[float]] = [] + for text in texts: + response = await self.async_client.create( + input=text, + **self._invocation_params, + ) + if not isinstance(response, dict): + response = response.dict() + embeddings.extend(r["embedding"] for r in response["data"]) + return embeddings + # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. engine = cast(str, self.deployment)