partners/openai + community: Async Azure AD token provider support for Azure OpenAI (#27488)

This PR introduces a new `azure_ad_async_token_provider` attribute to
the `AzureOpenAI` and `AzureChatOpenAI` classes in `partners/openai` and
`community` packages, given it's currently supported on `openai` package
as
[AsyncAzureADTokenProvider](https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L33)
type.

The reason for creating a new attribute is to avoid breaking changes.
Let's say you have an existing code that uses a `AzureOpenAI` or
`AzureChatOpenAI` instance to perform both sync and async operations.
The `azure_ad_token_provider` will work exactly as it is today, while
`azure_ad_async_token_provider` will override it for async requests.


If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
This commit is contained in:
Fernando de Oliveira 2024-10-22 17:43:06 -04:00 committed by GitHub
parent 34684423bf
commit ab205e7389
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 96 additions and 11 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging import logging
import os import os
import warnings import warnings
from typing import Any, Callable, Dict, List, Union from typing import Any, Awaitable, Callable, Dict, List, Union
from langchain_core._api.deprecation import deprecated from langchain_core._api.deprecation import deprecated
from langchain_core.outputs import ChatResult from langchain_core.outputs import ChatResult
@ -90,7 +90,13 @@ class AzureChatOpenAI(ChatOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token. """A function that returns an Azure Active Directory token.
Will be invoked on every request. Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
""" """
model_version: str = "" model_version: str = ""
"""Legacy, for openai<1.0.0 support.""" """Legacy, for openai<1.0.0 support."""
@ -208,6 +214,12 @@ class AzureChatOpenAI(ChatOpenAI):
"http_client": values["http_client"], "http_client": values["http_client"],
} }
values["client"] = openai.AzureOpenAI(**client_params).chat.completions values["client"] = openai.AzureOpenAI(**client_params).chat.completions
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
if azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
values["async_client"] = openai.AsyncAzureOpenAI( values["async_client"] = openai.AsyncAzureOpenAI(
**client_params **client_params
).chat.completions ).chat.completions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import os import os
import warnings import warnings
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Awaitable, Callable, Dict, Optional, Union
from langchain_core._api.deprecation import deprecated from langchain_core._api.deprecation import deprecated
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
@ -49,7 +49,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
azure_ad_token_provider: Union[Callable[[], str], None] = None azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token. """A function that returns an Azure Active Directory token.
Will be invoked on every request. Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
""" """
openai_api_version: Optional[str] = Field(default=None, alias="api_version") openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
@ -162,6 +168,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"http_client": self.http_client, "http_client": self.http_client,
} }
self.client = openai.AzureOpenAI(**client_params).embeddings self.client = openai.AzureOpenAI(**client_params).embeddings
if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
else: else:
self.client = openai.Embedding self.client = openai.Embedding

View File

@ -8,6 +8,7 @@ from typing import (
AbstractSet, AbstractSet,
Any, Any,
AsyncIterator, AsyncIterator,
Awaitable,
Callable, Callable,
Collection, Collection,
Dict, Dict,
@ -804,7 +805,13 @@ class AzureOpenAI(BaseOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token. """A function that returns an Azure Active Directory token.
Will be invoked on every request. Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
""" """
openai_api_type: str = "" openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support.""" """Legacy, for openai<1.0.0 support."""
@ -922,6 +929,12 @@ class AzureOpenAI(BaseOpenAI):
"http_client": values["http_client"], "http_client": values["http_client"],
} }
values["client"] = openai.AzureOpenAI(**client_params).completions values["client"] = openai.AzureOpenAI(**client_params).completions
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
if azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
values["async_client"] = openai.AsyncAzureOpenAI( values["async_client"] = openai.AsyncAzureOpenAI(
**client_params **client_params
).completions ).completions

View File

@ -4,7 +4,18 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, TypeVar, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
import openai import openai
from langchain_core.language_models.chat_models import LangSmithParams from langchain_core.language_models.chat_models import LangSmithParams
@ -494,7 +505,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token. """A function that returns an Azure Active Directory token.
Will be invoked on every request. Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
""" """
model_version: str = "" model_version: str = ""
@ -633,6 +651,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
self.client = self.root_client.chat.completions self.client = self.root_client.chat.completions
if not self.async_client: if not self.async_client:
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.root_async_client = openai.AsyncAzureOpenAI( self.root_async_client = openai.AsyncAzureOpenAI(
**client_params, **client_params,
**async_specific, # type: ignore[arg-type] **async_specific, # type: ignore[arg-type]

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, Optional, Union from typing import Awaitable, Callable, Optional, Union
import openai import openai
from langchain_core.utils import from_env, secret_from_env from langchain_core.utils import from_env, secret_from_env
@ -146,7 +146,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
azure_ad_token_provider: Union[Callable[[], str], None] = None azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token. """A function that returns an Azure Active Directory token.
Will be invoked on every request. Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
""" """
openai_api_type: Optional[str] = Field( openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default="azure") default_factory=from_env("OPENAI_API_TYPE", default="azure")
@ -203,6 +209,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
).embeddings ).embeddings
if not self.async_client: if not self.async_client:
async_specific: dict = {"http_client": self.http_async_client} async_specific: dict = {"http_client": self.http_async_client}
if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.async_client = openai.AsyncAzureOpenAI( self.async_client = openai.AsyncAzureOpenAI(
**client_params, # type: ignore[arg-type] **client_params, # type: ignore[arg-type]
**async_specific, **async_specific,

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Callable, Dict, List, Mapping, Optional, Union from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
import openai import openai
from langchain_core.language_models import LangSmithParams from langchain_core.language_models import LangSmithParams
@ -73,7 +73,13 @@ class AzureOpenAI(BaseOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token. """A function that returns an Azure Active Directory token.
Will be invoked on every request. Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
""" """
openai_api_type: Optional[str] = Field( openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default="azure") default_factory=from_env("OPENAI_API_TYPE", default="azure")
@ -158,6 +164,12 @@ class AzureOpenAI(BaseOpenAI):
).completions ).completions
if not self.async_client: if not self.async_client:
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
self.async_client = openai.AsyncAzureOpenAI( self.async_client = openai.AsyncAzureOpenAI(
**client_params, **client_params,
**async_specific, # type: ignore[arg-type] **async_specific, # type: ignore[arg-type]