From 8d9907088b843756b5aa3f49f11f51b451567fa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Panella?= Date: Mon, 23 Dec 2024 10:05:48 -0500 Subject: [PATCH] community(azuresearch): allow to use any valid credential (#28873) Add option to use any valid credential type. Differentiates async cases needed by Azure Search. This could replace the use of a static token --- .../vectorstores/azuresearch.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 6930c8319e4..d0aa15e2acb 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -42,6 +42,8 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger() if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from azure.search.documents import SearchClient, SearchItemPaged from azure.search.documents.aio import ( AsyncSearchItemPaged, @@ -96,10 +98,13 @@ def _get_search_client( cors_options: Optional[CorsOptions] = None, async_: bool = False, additional_search_client_options: Optional[Dict[str, Any]] = None, + azure_credential: Optional[TokenCredential] = None, + azure_async_credential: Optional[AsyncTokenCredential] = None, ) -> Union[SearchClient, AsyncSearchClient]: from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential + from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential from azure.search.documents import SearchClient from azure.search.documents.aio import SearchClient as AsyncSearchClient from azure.search.documents.indexes import SearchIndexClient @@ -143,12 +148,17 @@ def _get_search_client( if key.upper() == "INTERACTIVE": credential = InteractiveBrowserCredential() credential.get_token("https://search.azure.com/.default") + async_credential = credential else: credential = AzureKeyCredential(key) + async_credential = credential elif azure_ad_access_token is not None: credential = AzureBearerTokenCredential(azure_ad_access_token) + async_credential = credential else: - credential = DefaultAzureCredential() + credential = azure_credential or DefaultAzureCredential() + async_credential = azure_async_credential or AsyncDefaultAzureCredential() + index_client: SearchIndexClient = SearchIndexClient( endpoint=endpoint, credential=credential, @@ -266,7 +276,7 @@ def _get_search_client( return AsyncSearchClient( endpoint=endpoint, index_name=index_name, - credential=credential, + credential=async_credential, user_agent=user_agent, **additional_search_client_options, ) @@ -278,7 +288,7 @@ class AzureSearch(VectorStore): def __init__( self, azure_search_endpoint: str, - azure_search_key: str, + azure_search_key: Optional[str], index_name: str, embedding_function: Union[Callable, Embeddings], search_type: str = "hybrid", @@ -295,6 +305,8 @@ class AzureSearch(VectorStore): vector_search_dimensions: Optional[int] = None, additional_search_client_options: Optional[Dict[str, Any]] = None, azure_ad_access_token: Optional[str] = None, + azure_credential: Optional[TokenCredential] = None, + azure_async_credential: Optional[AsyncTokenCredential] = None, **kwargs: Any, ): try: @@ -361,6 +373,7 @@ class AzureSearch(VectorStore): user_agent=user_agent, cors_options=cors_options, additional_search_client_options=additional_search_client_options, + azure_credential=azure_credential, ) self.async_client = _get_search_client( azure_search_endpoint, @@ -377,6 +390,8 @@ class AzureSearch(VectorStore): user_agent=user_agent, cors_options=cors_options, async_=True, + azure_credential=azure_credential, + azure_async_credential=azure_async_credential, ) self.search_type = search_type self.semantic_configuration_name = semantic_configuration_name