From be92cf57ca50f34b4900d37337c5015a30844e06 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Mon, 1 Apr 2024 10:26:15 -0700 Subject: [PATCH] openai[patch]: fix azure embedding length check (#19870) --- .../langchain_openai/embeddings/azure.py | 25 +++++++++++++------ .../embeddings/test_azure.py | 4 +-- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index 0daabd40f92..f162c18ecd0 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -1,4 +1,5 @@ """Azure OpenAI embeddings wrapper.""" + from __future__ import annotations import os @@ -57,6 +58,8 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): openai_api_version: Optional[str] = Field(default=None, alias="api_version") """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" validate_base_url: bool = True + chunk_size: int = 2048 + """Maximum number of texts to embed in each batch""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -102,7 +105,11 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # Azure OpenAI embedding models allow a maximum of 2048 texts # at a time in each batch # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#best-practices - values["chunk_size"] = min(values["chunk_size"], 2048) + if values["chunk_size"] > 2048: + raise ValueError( + "Azure OpenAI embeddings only allow a maximum of 2048 texts at a time " + "in each batch." + ) # For backwards compatibility. Before openai v1, no distinction was made # between azure_endpoint and base_url (openai_api_base). openai_api_base = values["openai_api_base"] @@ -126,12 +133,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): "api_version": values["openai_api_version"], "azure_endpoint": values["azure_endpoint"], "azure_deployment": values["deployment"], - "api_key": values["openai_api_key"].get_secret_value() - if values["openai_api_key"] - else None, - "azure_ad_token": values["azure_ad_token"].get_secret_value() - if values["azure_ad_token"] - else None, + "api_key": ( + values["openai_api_key"].get_secret_value() + if values["openai_api_key"] + else None + ), + "azure_ad_token": ( + values["azure_ad_token"].get_secret_value() + if values["azure_ad_token"] + else None + ), "azure_ad_token_provider": values["azure_ad_token_provider"], "organization": values["openai_organization"], "base_url": values["openai_api_base"], diff --git a/libs/partners/openai/tests/integration_tests/embeddings/test_azure.py b/libs/partners/openai/tests/integration_tests/embeddings/test_azure.py index 3100f7fe08c..6f697c9a3b0 100644 --- a/libs/partners/openai/tests/integration_tests/embeddings/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/embeddings/test_azure.py @@ -60,8 +60,8 @@ def test_azure_openai_embedding_documents_chunk_size() -> None: embedding = _get_embeddings() embedding.embedding_ctx_length = 8191 output = embedding.embed_documents(documents) - # Max 16 chunks per batch on Azure OpenAI embeddings - assert embedding.chunk_size == 16 + # Max 2048 chunks per batch on Azure OpenAI embeddings + assert embedding.chunk_size == 2048 assert len(output) == 20 assert all([len(out) == 1536 for out in output])