mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-19 09:30:15 +00:00
partners/openai + community: Async Azure AD token provider support for Azure OpenAI (#27488)
This PR introduces a new `azure_ad_async_token_provider` attribute to the `AzureOpenAI` and `AzureChatOpenAI` classes in `partners/openai` and `community` packages, given it's currently supported on `openai` package as [AsyncAzureADTokenProvider](https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L33) type. The reason for creating a new attribute is to avoid breaking changes. Let's say you have an existing code that uses a `AzureOpenAI` or `AzureChatOpenAI` instance to perform both sync and async operations. The `azure_ad_token_provider` will work exactly as it is today, while `azure_ad_async_token_provider` will override it for async requests. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
This commit is contained in:
parent
34684423bf
commit
ab205e7389
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, List, Union
|
from typing import Any, Awaitable, Callable, Dict, List, Union
|
||||||
|
|
||||||
from langchain_core._api.deprecation import deprecated
|
from langchain_core._api.deprecation import deprecated
|
||||||
from langchain_core.outputs import ChatResult
|
from langchain_core.outputs import ChatResult
|
||||||
@ -90,7 +90,13 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||||
"""A function that returns an Azure Active Directory token.
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
Will be invoked on every request.
|
Will be invoked on every sync request. For async requests,
|
||||||
|
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||||
|
"""
|
||||||
|
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||||
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
|
Will be invoked on every async request.
|
||||||
"""
|
"""
|
||||||
model_version: str = ""
|
model_version: str = ""
|
||||||
"""Legacy, for openai<1.0.0 support."""
|
"""Legacy, for openai<1.0.0 support."""
|
||||||
@ -208,6 +214,12 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
"http_client": values["http_client"],
|
"http_client": values["http_client"],
|
||||||
}
|
}
|
||||||
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
|
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
|
||||||
|
|
||||||
|
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
|
||||||
|
|
||||||
|
if azure_ad_async_token_provider:
|
||||||
|
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
|
||||||
|
|
||||||
values["async_client"] = openai.AsyncAzureOpenAI(
|
values["async_client"] = openai.AsyncAzureOpenAI(
|
||||||
**client_params
|
**client_params
|
||||||
).chat.completions
|
).chat.completions
|
||||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Awaitable, Callable, Dict, Optional, Union
|
||||||
|
|
||||||
from langchain_core._api.deprecation import deprecated
|
from langchain_core._api.deprecation import deprecated
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
@ -49,7 +49,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|||||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||||
"""A function that returns an Azure Active Directory token.
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
Will be invoked on every request.
|
Will be invoked on every sync request. For async requests,
|
||||||
|
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||||
|
"""
|
||||||
|
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||||
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
|
Will be invoked on every async request.
|
||||||
"""
|
"""
|
||||||
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
|
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
|
||||||
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
||||||
@ -162,6 +168,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|||||||
"http_client": self.http_client,
|
"http_client": self.http_client,
|
||||||
}
|
}
|
||||||
self.client = openai.AzureOpenAI(**client_params).embeddings
|
self.client = openai.AzureOpenAI(**client_params).embeddings
|
||||||
|
|
||||||
|
if self.azure_ad_async_token_provider:
|
||||||
|
client_params["azure_ad_token_provider"] = (
|
||||||
|
self.azure_ad_async_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
|
self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
|
||||||
else:
|
else:
|
||||||
self.client = openai.Embedding
|
self.client = openai.Embedding
|
||||||
|
@ -8,6 +8,7 @@ from typing import (
|
|||||||
AbstractSet,
|
AbstractSet,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
@ -804,7 +805,13 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||||
"""A function that returns an Azure Active Directory token.
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
Will be invoked on every request.
|
Will be invoked on every sync request. For async requests,
|
||||||
|
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||||
|
"""
|
||||||
|
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||||
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
|
Will be invoked on every async request.
|
||||||
"""
|
"""
|
||||||
openai_api_type: str = ""
|
openai_api_type: str = ""
|
||||||
"""Legacy, for openai<1.0.0 support."""
|
"""Legacy, for openai<1.0.0 support."""
|
||||||
@ -922,6 +929,12 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
"http_client": values["http_client"],
|
"http_client": values["http_client"],
|
||||||
}
|
}
|
||||||
values["client"] = openai.AzureOpenAI(**client_params).completions
|
values["client"] = openai.AzureOpenAI(**client_params).completions
|
||||||
|
|
||||||
|
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
|
||||||
|
|
||||||
|
if azure_ad_async_token_provider:
|
||||||
|
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
|
||||||
|
|
||||||
values["async_client"] = openai.AsyncAzureOpenAI(
|
values["async_client"] = openai.AsyncAzureOpenAI(
|
||||||
**client_params
|
**client_params
|
||||||
).completions
|
).completions
|
||||||
|
@ -4,7 +4,18 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, TypeVar, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypedDict,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from langchain_core.language_models.chat_models import LangSmithParams
|
from langchain_core.language_models.chat_models import LangSmithParams
|
||||||
@ -494,7 +505,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||||
"""A function that returns an Azure Active Directory token.
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
Will be invoked on every request.
|
Will be invoked on every sync request. For async requests,
|
||||||
|
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||||
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
|
Will be invoked on every async request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_version: str = ""
|
model_version: str = ""
|
||||||
@ -633,6 +651,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
self.client = self.root_client.chat.completions
|
self.client = self.root_client.chat.completions
|
||||||
if not self.async_client:
|
if not self.async_client:
|
||||||
async_specific = {"http_client": self.http_async_client}
|
async_specific = {"http_client": self.http_async_client}
|
||||||
|
|
||||||
|
if self.azure_ad_async_token_provider:
|
||||||
|
client_params["azure_ad_token_provider"] = (
|
||||||
|
self.azure_ad_async_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
self.root_async_client = openai.AsyncAzureOpenAI(
|
self.root_async_client = openai.AsyncAzureOpenAI(
|
||||||
**client_params,
|
**client_params,
|
||||||
**async_specific, # type: ignore[arg-type]
|
**async_specific, # type: ignore[arg-type]
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Callable, Optional, Union
|
from typing import Awaitable, Callable, Optional, Union
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from langchain_core.utils import from_env, secret_from_env
|
from langchain_core.utils import from_env, secret_from_env
|
||||||
@ -146,7 +146,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|||||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||||
"""A function that returns an Azure Active Directory token.
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
Will be invoked on every request.
|
Will be invoked on every sync request. For async requests,
|
||||||
|
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||||
|
"""
|
||||||
|
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||||
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
|
Will be invoked on every async request.
|
||||||
"""
|
"""
|
||||||
openai_api_type: Optional[str] = Field(
|
openai_api_type: Optional[str] = Field(
|
||||||
default_factory=from_env("OPENAI_API_TYPE", default="azure")
|
default_factory=from_env("OPENAI_API_TYPE", default="azure")
|
||||||
@ -203,6 +209,12 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|||||||
).embeddings
|
).embeddings
|
||||||
if not self.async_client:
|
if not self.async_client:
|
||||||
async_specific: dict = {"http_client": self.http_async_client}
|
async_specific: dict = {"http_client": self.http_async_client}
|
||||||
|
|
||||||
|
if self.azure_ad_async_token_provider:
|
||||||
|
client_params["azure_ad_token_provider"] = (
|
||||||
|
self.azure_ad_async_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
self.async_client = openai.AsyncAzureOpenAI(
|
self.async_client = openai.AsyncAzureOpenAI(
|
||||||
**client_params, # type: ignore[arg-type]
|
**client_params, # type: ignore[arg-type]
|
||||||
**async_specific,
|
**async_specific,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from langchain_core.language_models import LangSmithParams
|
from langchain_core.language_models import LangSmithParams
|
||||||
@ -73,7 +73,13 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||||
"""A function that returns an Azure Active Directory token.
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
Will be invoked on every request.
|
Will be invoked on every sync request. For async requests,
|
||||||
|
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||||
|
"""
|
||||||
|
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||||
|
"""A function that returns an Azure Active Directory token.
|
||||||
|
|
||||||
|
Will be invoked on every async request.
|
||||||
"""
|
"""
|
||||||
openai_api_type: Optional[str] = Field(
|
openai_api_type: Optional[str] = Field(
|
||||||
default_factory=from_env("OPENAI_API_TYPE", default="azure")
|
default_factory=from_env("OPENAI_API_TYPE", default="azure")
|
||||||
@ -158,6 +164,12 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
).completions
|
).completions
|
||||||
if not self.async_client:
|
if not self.async_client:
|
||||||
async_specific = {"http_client": self.http_async_client}
|
async_specific = {"http_client": self.http_async_client}
|
||||||
|
|
||||||
|
if self.azure_ad_async_token_provider:
|
||||||
|
client_params["azure_ad_token_provider"] = (
|
||||||
|
self.azure_ad_async_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
self.async_client = openai.AsyncAzureOpenAI(
|
self.async_client = openai.AsyncAzureOpenAI(
|
||||||
**client_params,
|
**client_params,
|
||||||
**async_specific, # type: ignore[arg-type]
|
**async_specific, # type: ignore[arg-type]
|
||||||
|
Loading…
Reference in New Issue
Block a user