mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 04:28:58 +00:00
partners/openai + community: Async Azure AD token provider support for Azure OpenAI (#27488)
This PR introduces a new `azure_ad_async_token_provider` attribute to the `AzureOpenAI` and `AzureChatOpenAI` classes in `partners/openai` and `community` packages, given it's currently supported on `openai` package as [AsyncAzureADTokenProvider](https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L33) type. The reason for creating a new attribute is to avoid breaking changes. Let's say you have an existing code that uses a `AzureOpenAI` or `AzureChatOpenAI` instance to perform both sync and async operations. The `azure_ad_token_provider` will work exactly as it is today, while `azure_ad_async_token_provider` will override it for async requests. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
This commit is contained in:
committed by
GitHub
parent
34684423bf
commit
ab205e7389
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
@@ -49,7 +49,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||
"""A function that returns an Azure Active Directory token.
|
||||
|
||||
Will be invoked on every request.
|
||||
Will be invoked on every sync request. For async requests,
|
||||
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||
"""
|
||||
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||
"""A function that returns an Azure Active Directory token.
|
||||
|
||||
Will be invoked on every async request.
|
||||
"""
|
||||
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
|
||||
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
||||
@@ -162,6 +168,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
||||
"http_client": self.http_client,
|
||||
}
|
||||
self.client = openai.AzureOpenAI(**client_params).embeddings
|
||||
|
||||
if self.azure_ad_async_token_provider:
|
||||
client_params["azure_ad_token_provider"] = (
|
||||
self.azure_ad_async_token_provider
|
||||
)
|
||||
|
||||
self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
|
||||
else:
|
||||
self.client = openai.Embedding
|
||||
|
Reference in New Issue
Block a user