From fb44e74ca4c64e2892959ffd9111e5a7275e20f0 Mon Sep 17 00:00:00 2001 From: clairebehue <87328722+clairebehue@users.noreply.github.com> Date: Mon, 16 Dec 2024 06:56:45 +0100 Subject: [PATCH] community: fix AzureSearch Oauth with azure_ad_access_token (#26995) **Description:** AzureSearch vector store: create a wrapper class on `azure.core.credentials.TokenCredential` (which is not-instantiable) to fix Oauth usage with `azure_ad_access_token` argument **Issue:** [the issue it fixes](https://github.com/langchain-ai/langchain/issues/26216) **Dependencies:** None - [x] **Lint and test** --------- Co-authored-by: Erick Friis --- .../vectorstores/azuresearch.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 6d19574e8ce..6748f81dc25 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -119,6 +119,21 @@ def _get_search_client( VectorSearchProfile, ) + class AzureBearerTokenCredential(TokenCredential): + def __init__(self, token: str): + # set the expiry to an hour from now. + self._token = AccessToken(token, int(time.time()) + 3600) + + def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + return self._token + additional_search_client_options = additional_search_client_options or {} default_fields = default_fields or [] credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential] @@ -131,11 +146,7 @@ def _get_search_client( else: credential = AzureKeyCredential(key) elif azure_ad_access_token is not None: - credential = TokenCredential( - lambda *scopes, **kwargs: AccessToken( - azure_ad_access_token, int(time.time()) + 3600 - ) - ) + credential = AzureBearerTokenCredential(azure_ad_access_token) else: credential = DefaultAzureCredential() index_client: SearchIndexClient = SearchIndexClient(