From ab205e738982fe9b4f756476c57ca1c1f0edd134 Mon Sep 17 00:00:00 2001 From: Fernando de Oliveira <5161098+fedeoliv@users.noreply.github.com> Date: Tue, 22 Oct 2024 17:43:06 -0400 Subject: [PATCH] partners/openai + community: Async Azure AD token provider support for Azure OpenAI (#27488) This PR introduces a new `azure_ad_async_token_provider` attribute to the `AzureOpenAI` and `AzureChatOpenAI` classes in `partners/openai` and `community` packages, given it's currently supported on `openai` package as [AsyncAzureADTokenProvider](https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L33) type. The reason for creating a new attribute is to avoid breaking changes. Let's say you have an existing code that uses a `AzureOpenAI` or `AzureChatOpenAI` instance to perform both sync and async operations. The `azure_ad_token_provider` will work exactly as it is today, while `azure_ad_async_token_provider` will override it for async requests. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --- .../chat_models/azure_openai.py | 16 +++++++++-- .../embeddings/azure_openai.py | 16 +++++++++-- .../langchain_community/llms/openai.py | 15 +++++++++- .../langchain_openai/chat_models/azure.py | 28 +++++++++++++++++-- .../langchain_openai/embeddings/azure.py | 16 +++++++++-- .../openai/langchain_openai/llms/azure.py | 16 +++++++++-- 6 files changed, 96 insertions(+), 11 deletions(-) diff --git a/libs/community/langchain_community/chat_models/azure_openai.py b/libs/community/langchain_community/chat_models/azure_openai.py index 83bc551d86a..182e10f6f3b 100644 --- a/libs/community/langchain_community/chat_models/azure_openai.py +++ b/libs/community/langchain_community/chat_models/azure_openai.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import os import warnings -from typing import Any, Callable, Dict, List, Union +from typing import Any, Awaitable, Callable, Dict, List, Union from langchain_core._api.deprecation import deprecated from langchain_core.outputs import ChatResult @@ -90,7 +90,13 @@ class AzureChatOpenAI(ChatOpenAI): azure_ad_token_provider: Union[Callable[[], str], None] = None """A function that returns an Azure Active Directory token. - Will be invoked on every request. + Will be invoked on every sync request. For async requests, + will be invoked if `azure_ad_async_token_provider` is not provided. + """ + azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + """A function that returns an Azure Active Directory token. + + Will be invoked on every async request. """ model_version: str = "" """Legacy, for openai<1.0.0 support.""" @@ -208,6 +214,12 @@ class AzureChatOpenAI(ChatOpenAI): "http_client": values["http_client"], } values["client"] = openai.AzureOpenAI(**client_params).chat.completions + + azure_ad_async_token_provider = values["azure_ad_async_token_provider"] + + if azure_ad_async_token_provider: + client_params["azure_ad_token_provider"] = azure_ad_async_token_provider + values["async_client"] = openai.AsyncAzureOpenAI( **client_params ).chat.completions diff --git a/libs/community/langchain_community/embeddings/azure_openai.py b/libs/community/langchain_community/embeddings/azure_openai.py index f760d796feb..b65422320a5 100644 --- a/libs/community/langchain_community/embeddings/azure_openai.py +++ b/libs/community/langchain_community/embeddings/azure_openai.py @@ -4,7 +4,7 @@ from __future__ import annotations import os import warnings -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Union from langchain_core._api.deprecation import deprecated from langchain_core.utils import get_from_dict_or_env @@ -49,7 +49,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): azure_ad_token_provider: Union[Callable[[], str], None] = None """A function that returns an Azure Active Directory token. - Will be invoked on every request. + Will be invoked on every sync request. For async requests, + will be invoked if `azure_ad_async_token_provider` is not provided. + """ + azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + """A function that returns an Azure Active Directory token. + + Will be invoked on every async request. """ openai_api_version: Optional[str] = Field(default=None, alias="api_version") """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" @@ -162,6 +168,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): "http_client": self.http_client, } self.client = openai.AzureOpenAI(**client_params).embeddings + + if self.azure_ad_async_token_provider: + client_params["azure_ad_token_provider"] = ( + self.azure_ad_async_token_provider + ) + self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings else: self.client = openai.Embedding diff --git a/libs/community/langchain_community/llms/openai.py b/libs/community/langchain_community/llms/openai.py index cc33d07ac4e..8a2b5d2e0a4 100644 --- a/libs/community/langchain_community/llms/openai.py +++ b/libs/community/langchain_community/llms/openai.py @@ -8,6 +8,7 @@ from typing import ( AbstractSet, Any, AsyncIterator, + Awaitable, Callable, Collection, Dict, @@ -804,7 +805,13 @@ class AzureOpenAI(BaseOpenAI): azure_ad_token_provider: Union[Callable[[], str], None] = None """A function that returns an Azure Active Directory token. - Will be invoked on every request. + Will be invoked on every sync request. For async requests, + will be invoked if `azure_ad_async_token_provider` is not provided. + """ + azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + """A function that returns an Azure Active Directory token. + + Will be invoked on every async request. """ openai_api_type: str = "" """Legacy, for openai<1.0.0 support.""" @@ -922,6 +929,12 @@ class AzureOpenAI(BaseOpenAI): "http_client": values["http_client"], } values["client"] = openai.AzureOpenAI(**client_params).completions + + azure_ad_async_token_provider = values["azure_ad_async_token_provider"] + + if azure_ad_async_token_provider: + client_params["azure_ad_token_provider"] = azure_ad_async_token_provider + values["async_client"] = openai.AsyncAzureOpenAI( **client_params ).completions diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index 0bfd220888e..2e1e5f8abfe 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -4,7 +4,18 @@ from __future__ import annotations import logging import os -from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, TypeVar, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Type, + TypedDict, + TypeVar, + Union, +) import openai from langchain_core.language_models.chat_models import LangSmithParams @@ -494,7 +505,14 @@ class AzureChatOpenAI(BaseChatOpenAI): azure_ad_token_provider: Union[Callable[[], str], None] = None """A function that returns an Azure Active Directory token. - Will be invoked on every request. + Will be invoked on every sync request. For async requests, + will be invoked if `azure_ad_async_token_provider` is not provided. + """ + + azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + """A function that returns an Azure Active Directory token. + + Will be invoked on every async request. """ model_version: str = "" @@ -633,6 +651,12 @@ class AzureChatOpenAI(BaseChatOpenAI): self.client = self.root_client.chat.completions if not self.async_client: async_specific = {"http_client": self.http_async_client} + + if self.azure_ad_async_token_provider: + client_params["azure_ad_token_provider"] = ( + self.azure_ad_async_token_provider + ) + self.root_async_client = openai.AsyncAzureOpenAI( **client_params, **async_specific, # type: ignore[arg-type] diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index 06349e36a51..0341c908fcd 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Callable, Optional, Union +from typing import Awaitable, Callable, Optional, Union import openai from langchain_core.utils import from_env, secret_from_env @@ -146,7 +146,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): azure_ad_token_provider: Union[Callable[[], str], None] = None """A function that returns an Azure Active Directory token. - Will be invoked on every request. + Will be invoked on every sync request. For async requests, + will be invoked if `azure_ad_async_token_provider` is not provided. + """ + azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + """A function that returns an Azure Active Directory token. + + Will be invoked on every async request. """ openai_api_type: Optional[str] = Field( default_factory=from_env("OPENAI_API_TYPE", default="azure") @@ -203,6 +209,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): ).embeddings if not self.async_client: async_specific: dict = {"http_client": self.http_async_client} + + if self.azure_ad_async_token_provider: + client_params["azure_ad_token_provider"] = ( + self.azure_ad_async_token_provider + ) + self.async_client = openai.AsyncAzureOpenAI( **client_params, # type: ignore[arg-type] **async_specific, diff --git a/libs/partners/openai/langchain_openai/llms/azure.py b/libs/partners/openai/langchain_openai/llms/azure.py index 90c20c9d4d7..731b3e567f0 100644 --- a/libs/partners/openai/langchain_openai/llms/azure.py +++ b/libs/partners/openai/langchain_openai/llms/azure.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union import openai from langchain_core.language_models import LangSmithParams @@ -73,7 +73,13 @@ class AzureOpenAI(BaseOpenAI): azure_ad_token_provider: Union[Callable[[], str], None] = None """A function that returns an Azure Active Directory token. - Will be invoked on every request. + Will be invoked on every sync request. For async requests, + will be invoked if `azure_ad_async_token_provider` is not provided. + """ + azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + """A function that returns an Azure Active Directory token. + + Will be invoked on every async request. """ openai_api_type: Optional[str] = Field( default_factory=from_env("OPENAI_API_TYPE", default="azure") @@ -158,6 +164,12 @@ class AzureOpenAI(BaseOpenAI): ).completions if not self.async_client: async_specific = {"http_client": self.http_async_client} + + if self.azure_ad_async_token_provider: + client_params["azure_ad_token_provider"] = ( + self.azure_ad_async_token_provider + ) + self.async_client = openai.AsyncAzureOpenAI( **client_params, **async_specific, # type: ignore[arg-type]