Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 07:35:18 +00:00)

Commit be92cf57ca (parent d62e84c4f5)
openai[patch]: fix azure embedding length check (#19870)
@ -1,4 +1,5 @@
|
|||||||
"""Azure OpenAI embeddings wrapper."""
|
"""Azure OpenAI embeddings wrapper."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -57,6 +58,8 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
     openai_api_version: Optional[str] = Field(default=None, alias="api_version")
     """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
     validate_base_url: bool = True
+    chunk_size: int = 2048
+    """Maximum number of texts to embed in each batch"""
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -102,7 +105,11 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
         # Azure OpenAI embedding models allow a maximum of 2048 texts
         # at a time in each batch
         # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#best-practices
-        values["chunk_size"] = min(values["chunk_size"], 2048)
+        if values["chunk_size"] > 2048:
+            raise ValueError(
+                "Azure OpenAI embeddings only allow a maximum of 2048 texts at a time "
+                "in each batch."
+            )
         # For backwards compatibility. Before openai v1, no distinction was made
         # between azure_endpoint and base_url (openai_api_base).
         openai_api_base = values["openai_api_base"]
@@ -126,12 +133,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
             "api_version": values["openai_api_version"],
             "azure_endpoint": values["azure_endpoint"],
             "azure_deployment": values["deployment"],
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
-            "azure_ad_token": values["azure_ad_token"].get_secret_value()
-            if values["azure_ad_token"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
+            "azure_ad_token": (
+                values["azure_ad_token"].get_secret_value()
+                if values["azure_ad_token"]
+                else None
+            ),
             "azure_ad_token_provider": values["azure_ad_token_provider"],
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
|
@@ -60,8 +60,8 @@ def test_azure_openai_embedding_documents_chunk_size() -> None:
     embedding = _get_embeddings()
     embedding.embedding_ctx_length = 8191
     output = embedding.embed_documents(documents)
-    # Max 16 chunks per batch on Azure OpenAI embeddings
-    assert embedding.chunk_size == 16
+    # Max 2048 chunks per batch on Azure OpenAI embeddings
+    assert embedding.chunk_size == 2048
     assert len(output) == 20
     assert all([len(out) == 1536 for out in output])
 
|
Loading…
Reference in New Issue
Block a user