community: Azure Search Vector Store is missing Access Token Authentication (#24330)

Added Azure AD access token authentication to the Azure Search vector store as an alternative to API key auth (API key auth is still supported).
Fixes Issue: https://github.com/langchain-ai/langchain/issues/24263
Dependencies: None
Twitter: @levalencia

@baskaryan

Could you please review? This is my first time creating a PR that fixes code.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Luis Valencia 2024-08-27 00:41:50 +02:00 committed by GitHub
parent 49b0bc7b5a
commit 99f9a664a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 34 additions and 11 deletions

View File

@ -16,7 +16,6 @@ cloudpickle>=2.0.0
cohere>=4,<6
databricks-vectorsearch>=0.21,<0.22
datasets>=2.15.0,<3
dedoc>=2.2.6,<3
dgml-utils>=0.3.0,<0.4
elasticsearch>=8.12.0,<9
esprima>=4.0.1,<5

View File

@ -5,6 +5,7 @@ import base64
import itertools
import json
import logging
import time
import uuid
from typing import (
TYPE_CHECKING,
@ -79,8 +80,9 @@ MAX_UPLOAD_BATCH_SIZE = 1000
def _get_search_client(
endpoint: str,
key: str,
index_name: str,
key: Optional[str] = None,
azure_ad_access_token: Optional[str] = None,
semantic_configuration_name: Optional[str] = None,
fields: Optional[List[SearchField]] = None,
vector_search: Optional[VectorSearch] = None,
@ -95,7 +97,7 @@ def _get_search_client(
async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.search.documents import SearchClient
@ -119,13 +121,23 @@ def _get_search_client(
additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or []
if key is None:
credential = DefaultAzureCredential()
elif key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
# Determine the appropriate credential to use
if key is not None:
if key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
else:
credential = AzureKeyCredential(key)
elif azure_ad_access_token is not None:
credential = TokenCredential(
lambda *scopes, **kwargs: AccessToken(
azure_ad_access_token, int(time.time()) + 3600
)
)
else:
credential = AzureKeyCredential(key)
credential = DefaultAzureCredential()
index_client: SearchIndexClient = SearchIndexClient(
endpoint=endpoint, credential=credential, user_agent=user_agent
)
@ -253,6 +265,7 @@ class AzureSearch(VectorStore):
self,
azure_search_endpoint: str,
azure_search_key: str,
azure_ad_access_token: Optional[str],
index_name: str,
embedding_function: Union[Callable, Embeddings],
search_type: str = "hybrid",
@ -321,8 +334,9 @@ class AzureSearch(VectorStore):
user_agent += " " + kwargs["user_agent"]
self.client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
@ -336,8 +350,9 @@ class AzureSearch(VectorStore):
)
self.async_client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
@ -1387,6 +1402,7 @@ class AzureSearch(VectorStore):
metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "",
azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None,
**kwargs: Any,
@ -1395,6 +1411,7 @@ class AzureSearch(VectorStore):
azure_search = cls(
azure_search_endpoint,
azure_search_key,
azure_ad_access_token,
index_name,
embedding,
fields=fields,
@ -1411,6 +1428,7 @@ class AzureSearch(VectorStore):
metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "",
azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None,
**kwargs: Any,
@ -1419,6 +1437,7 @@ class AzureSearch(VectorStore):
azure_search = cls(
azure_search_endpoint,
azure_search_key,
azure_ad_access_token,
index_name,
embedding,
fields=fields,

View File

@ -13,6 +13,7 @@ model = os.getenv("OPENAI_EMBEDDINGS_ENGINE_DOC", "text-embedding-ada-002")
# Vector store settings
vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "")
vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "")
access_token: str = os.getenv("AZURE_SEARCH_ACCESS_TOKEN", "")
index_name: str = "embeddings-vector-store-test"
@ -25,6 +26,7 @@ def similarity_search_test() -> None:
vector_store: AzureSearch = AzureSearch(
azure_search_endpoint=vector_store_address,
azure_search_key=vector_store_password,
azure_ad_access_token=access_token,
index_name=index_name,
embedding_function=embeddings.embed_query,
)
@ -68,6 +70,7 @@ def test_semantic_hybrid_search() -> None:
vector_store: AzureSearch = AzureSearch(
azure_search_endpoint=vector_store_address,
azure_search_key=vector_store_password,
azure_ad_access_token=access_token,
index_name=index_name,
embedding_function=embeddings.embed_query,
semantic_configuration_name="default",

View File

@ -32,6 +32,7 @@ class FakeEmbeddingsWithDimension(FakeEmbeddings):
DEFAULT_INDEX_NAME = "langchain-index"
DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net"
DEFAULT_KEY = "mykey"
DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
@ -127,6 +128,7 @@ def create_vector_store(
return AzureSearch(
azure_search_endpoint=DEFAULT_ENDPOINT,
azure_search_key=DEFAULT_KEY,
azure_ad_access_token=DEFAULT_ACCESS_TOKEN,
index_name=DEFAULT_INDEX_NAME,
embedding_function=DEFAULT_EMBEDDING_MODEL,
additional_search_client_options=additional_search_client_options,