community(azuresearch): allow to use any valid credential (#28873)

Add option to use any valid credential type.
Differentiates async cases needed by Azure Search.

This could replace the use of a static token
This commit is contained in:
Adrián Panella 2024-12-23 10:05:48 -05:00 committed by GitHub
parent 4b4d09f82b
commit 8d9907088b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -42,6 +42,8 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger()
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.search.documents import SearchClient, SearchItemPaged
from azure.search.documents.aio import (
AsyncSearchItemPaged,
@ -96,10 +98,13 @@ def _get_search_client(
cors_options: Optional[CorsOptions] = None,
async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
azure_credential: Optional[TokenCredential] = None,
azure_async_credential: Optional[AsyncTokenCredential] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.aio import SearchClient as AsyncSearchClient
from azure.search.documents.indexes import SearchIndexClient
@ -143,12 +148,17 @@ def _get_search_client(
if key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
async_credential = credential
else:
credential = AzureKeyCredential(key)
async_credential = credential
elif azure_ad_access_token is not None:
credential = AzureBearerTokenCredential(azure_ad_access_token)
async_credential = credential
else:
credential = DefaultAzureCredential()
credential = azure_credential or DefaultAzureCredential()
async_credential = azure_async_credential or AsyncDefaultAzureCredential()
index_client: SearchIndexClient = SearchIndexClient(
endpoint=endpoint,
credential=credential,
@ -266,7 +276,7 @@ def _get_search_client(
return AsyncSearchClient(
endpoint=endpoint,
index_name=index_name,
credential=credential,
credential=async_credential,
user_agent=user_agent,
**additional_search_client_options,
)
@ -278,7 +288,7 @@ class AzureSearch(VectorStore):
def __init__(
self,
azure_search_endpoint: str,
azure_search_key: str,
azure_search_key: Optional[str],
index_name: str,
embedding_function: Union[Callable, Embeddings],
search_type: str = "hybrid",
@ -295,6 +305,8 @@ class AzureSearch(VectorStore):
vector_search_dimensions: Optional[int] = None,
additional_search_client_options: Optional[Dict[str, Any]] = None,
azure_ad_access_token: Optional[str] = None,
azure_credential: Optional[TokenCredential] = None,
azure_async_credential: Optional[AsyncTokenCredential] = None,
**kwargs: Any,
):
try:
@ -361,6 +373,7 @@ class AzureSearch(VectorStore):
user_agent=user_agent,
cors_options=cors_options,
additional_search_client_options=additional_search_client_options,
azure_credential=azure_credential,
)
self.async_client = _get_search_client(
azure_search_endpoint,
@ -377,6 +390,8 @@ class AzureSearch(VectorStore):
user_agent=user_agent,
cors_options=cors_options,
async_=True,
azure_credential=azure_credential,
azure_async_credential=azure_async_credential,
)
self.search_type = search_type
self.semantic_configuration_name = semantic_configuration_name