community: Azure Search Vector Store is missing Access Token Authentication (#24330)

Added Azure AD access token authentication for Azure Search as an alternative to API key auth (the API key, when provided, still takes precedence).
Fixes Issue: https://github.com/langchain-ai/langchain/issues/24263
Dependencies: None
Twitter: @levalencia

@baskaryan

Could you please review? This is my first time creating a PR that fixes code.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Luis Valencia
2024-08-27 00:41:50 +02:00
committed by GitHub
parent 49b0bc7b5a
commit 99f9a664a5
4 changed files with 34 additions and 11 deletions

View File

@@ -5,6 +5,7 @@ import base64
import itertools
import json
import logging
import time
import uuid
from typing import (
TYPE_CHECKING,
@@ -79,8 +80,9 @@ MAX_UPLOAD_BATCH_SIZE = 1000
def _get_search_client(
endpoint: str,
key: str,
index_name: str,
key: Optional[str] = None,
azure_ad_access_token: Optional[str] = None,
semantic_configuration_name: Optional[str] = None,
fields: Optional[List[SearchField]] = None,
vector_search: Optional[VectorSearch] = None,
@@ -95,7 +97,7 @@ def _get_search_client(
async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.search.documents import SearchClient
@@ -119,13 +121,23 @@ def _get_search_client(
additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or []
if key is None:
credential = DefaultAzureCredential()
elif key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
# Determine the appropriate credential to use
if key is not None:
if key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
else:
credential = AzureKeyCredential(key)
elif azure_ad_access_token is not None:
credential = TokenCredential(
lambda *scopes, **kwargs: AccessToken(
azure_ad_access_token, int(time.time()) + 3600
)
)
else:
credential = AzureKeyCredential(key)
credential = DefaultAzureCredential()
index_client: SearchIndexClient = SearchIndexClient(
endpoint=endpoint, credential=credential, user_agent=user_agent
)
@@ -253,6 +265,7 @@ class AzureSearch(VectorStore):
self,
azure_search_endpoint: str,
azure_search_key: str,
azure_ad_access_token: Optional[str],
index_name: str,
embedding_function: Union[Callable, Embeddings],
search_type: str = "hybrid",
@@ -321,8 +334,9 @@ class AzureSearch(VectorStore):
user_agent += " " + kwargs["user_agent"]
self.client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
@@ -336,8 +350,9 @@ class AzureSearch(VectorStore):
)
self.async_client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
@@ -1387,6 +1402,7 @@ class AzureSearch(VectorStore):
metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "",
azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None,
**kwargs: Any,
@@ -1395,6 +1411,7 @@ class AzureSearch(VectorStore):
azure_search = cls(
azure_search_endpoint,
azure_search_key,
azure_ad_access_token,
index_name,
embedding,
fields=fields,
@@ -1411,6 +1428,7 @@ class AzureSearch(VectorStore):
metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "",
azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None,
**kwargs: Any,
@@ -1419,6 +1437,7 @@ class AzureSearch(VectorStore):
azure_search = cls(
azure_search_endpoint,
azure_search_key,
azure_ad_access_token,
index_name,
embedding,
fields=fields,