mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
openai[patch]: fix azure embedding length check (#19870)
This commit is contained in:
parent
d62e84c4f5
commit
be92cf57ca
@ -1,4 +1,5 @@
|
||||
"""Azure OpenAI embeddings wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@ -57,6 +58,8 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
|
||||
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
||||
validate_base_url: bool = True
|
||||
chunk_size: int = 2048
|
||||
"""Maximum number of texts to embed in each batch"""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -102,7 +105,11 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
# Azure OpenAI embedding models allow a maximum of 2048 texts
|
||||
# at a time in each batch
|
||||
# See: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#best-practices
|
||||
values["chunk_size"] = min(values["chunk_size"], 2048)
|
||||
if values["chunk_size"] > 2048:
|
||||
raise ValueError(
|
||||
"Azure OpenAI embeddings only allow a maximum of 2048 texts at a time "
|
||||
"in each batch."
|
||||
)
|
||||
# For backwards compatibility. Before openai v1, no distinction was made
|
||||
# between azure_endpoint and base_url (openai_api_base).
|
||||
openai_api_base = values["openai_api_base"]
|
||||
@ -126,12 +133,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
"api_version": values["openai_api_version"],
|
||||
"azure_endpoint": values["azure_endpoint"],
|
||||
"azure_deployment": values["deployment"],
|
||||
"api_key": values["openai_api_key"].get_secret_value()
|
||||
if values["openai_api_key"]
|
||||
else None,
|
||||
"azure_ad_token": values["azure_ad_token"].get_secret_value()
|
||||
if values["azure_ad_token"]
|
||||
else None,
|
||||
"api_key": (
|
||||
values["openai_api_key"].get_secret_value()
|
||||
if values["openai_api_key"]
|
||||
else None
|
||||
),
|
||||
"azure_ad_token": (
|
||||
values["azure_ad_token"].get_secret_value()
|
||||
if values["azure_ad_token"]
|
||||
else None
|
||||
),
|
||||
"azure_ad_token_provider": values["azure_ad_token_provider"],
|
||||
"organization": values["openai_organization"],
|
||||
"base_url": values["openai_api_base"],
|
||||
|
@ -60,8 +60,8 @@ def test_azure_openai_embedding_documents_chunk_size() -> None:
|
||||
embedding = _get_embeddings()
|
||||
embedding.embedding_ctx_length = 8191
|
||||
output = embedding.embed_documents(documents)
|
||||
# Max 16 chunks per batch on Azure OpenAI embeddings
|
||||
assert embedding.chunk_size == 16
|
||||
# Max 2048 chunks per batch on Azure OpenAI embeddings
|
||||
assert embedding.chunk_size == 2048
|
||||
assert len(output) == 20
|
||||
assert all([len(out) == 1536 for out in output])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user