community: Azure Search Vector Store is missing Access Token Authentication (#24330)

Added Azure Search Access Token Authentication instead of API KEY auth.
Fixes Issue: https://github.com/langchain-ai/langchain/issues/24263
Dependencies: None
Twitter: @levalencia

@baskaryan

Could you please review? First time creating a PR that fixes some code.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Luis Valencia 2024-08-27 00:41:50 +02:00 committed by GitHub
parent 49b0bc7b5a
commit 99f9a664a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 34 additions and 11 deletions

View File

@ -16,7 +16,6 @@ cloudpickle>=2.0.0
cohere>=4,<6 cohere>=4,<6
databricks-vectorsearch>=0.21,<0.22 databricks-vectorsearch>=0.21,<0.22
datasets>=2.15.0,<3 datasets>=2.15.0,<3
dedoc>=2.2.6,<3
dgml-utils>=0.3.0,<0.4 dgml-utils>=0.3.0,<0.4
elasticsearch>=8.12.0,<9 elasticsearch>=8.12.0,<9
esprima>=4.0.1,<5 esprima>=4.0.1,<5

View File

@ -5,6 +5,7 @@ import base64
import itertools import itertools
import json import json
import logging import logging
import time
import uuid import uuid
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -79,8 +80,9 @@ MAX_UPLOAD_BATCH_SIZE = 1000
def _get_search_client( def _get_search_client(
endpoint: str, endpoint: str,
key: str,
index_name: str, index_name: str,
key: Optional[str] = None,
azure_ad_access_token: Optional[str] = None,
semantic_configuration_name: Optional[str] = None, semantic_configuration_name: Optional[str] = None,
fields: Optional[List[SearchField]] = None, fields: Optional[List[SearchField]] = None,
vector_search: Optional[VectorSearch] = None, vector_search: Optional[VectorSearch] = None,
@ -95,7 +97,7 @@ def _get_search_client(
async_: bool = False, async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None, additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]: ) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
from azure.core.exceptions import ResourceNotFoundError from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.search.documents import SearchClient from azure.search.documents import SearchClient
@ -119,13 +121,23 @@ def _get_search_client(
additional_search_client_options = additional_search_client_options or {} additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or [] default_fields = default_fields or []
if key is None: credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
credential = DefaultAzureCredential()
elif key.upper() == "INTERACTIVE": # Determine the appropriate credential to use
credential = InteractiveBrowserCredential() if key is not None:
credential.get_token("https://search.azure.com/.default") if key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
else:
credential = AzureKeyCredential(key)
elif azure_ad_access_token is not None:
credential = TokenCredential(
lambda *scopes, **kwargs: AccessToken(
azure_ad_access_token, int(time.time()) + 3600
)
)
else: else:
credential = AzureKeyCredential(key) credential = DefaultAzureCredential()
index_client: SearchIndexClient = SearchIndexClient( index_client: SearchIndexClient = SearchIndexClient(
endpoint=endpoint, credential=credential, user_agent=user_agent endpoint=endpoint, credential=credential, user_agent=user_agent
) )
@ -253,6 +265,7 @@ class AzureSearch(VectorStore):
self, self,
azure_search_endpoint: str, azure_search_endpoint: str,
azure_search_key: str, azure_search_key: str,
azure_ad_access_token: Optional[str],
index_name: str, index_name: str,
embedding_function: Union[Callable, Embeddings], embedding_function: Union[Callable, Embeddings],
search_type: str = "hybrid", search_type: str = "hybrid",
@ -321,8 +334,9 @@ class AzureSearch(VectorStore):
user_agent += " " + kwargs["user_agent"] user_agent += " " + kwargs["user_agent"]
self.client = _get_search_client( self.client = _get_search_client(
azure_search_endpoint, azure_search_endpoint,
azure_search_key,
index_name, index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name, semantic_configuration_name=semantic_configuration_name,
fields=fields, fields=fields,
vector_search=vector_search, vector_search=vector_search,
@ -336,8 +350,9 @@ class AzureSearch(VectorStore):
) )
self.async_client = _get_search_client( self.async_client = _get_search_client(
azure_search_endpoint, azure_search_endpoint,
azure_search_key,
index_name, index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name, semantic_configuration_name=semantic_configuration_name,
fields=fields, fields=fields,
vector_search=vector_search, vector_search=vector_search,
@ -1387,6 +1402,7 @@ class AzureSearch(VectorStore):
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "", azure_search_endpoint: str = "",
azure_search_key: str = "", azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index", index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None, fields: Optional[List[SearchField]] = None,
**kwargs: Any, **kwargs: Any,
@ -1395,6 +1411,7 @@ class AzureSearch(VectorStore):
azure_search = cls( azure_search = cls(
azure_search_endpoint, azure_search_endpoint,
azure_search_key, azure_search_key,
azure_ad_access_token,
index_name, index_name,
embedding, embedding,
fields=fields, fields=fields,
@ -1411,6 +1428,7 @@ class AzureSearch(VectorStore):
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "", azure_search_endpoint: str = "",
azure_search_key: str = "", azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index", index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None, fields: Optional[List[SearchField]] = None,
**kwargs: Any, **kwargs: Any,
@ -1419,6 +1437,7 @@ class AzureSearch(VectorStore):
azure_search = cls( azure_search = cls(
azure_search_endpoint, azure_search_endpoint,
azure_search_key, azure_search_key,
azure_ad_access_token,
index_name, index_name,
embedding, embedding,
fields=fields, fields=fields,

View File

@ -13,6 +13,7 @@ model = os.getenv("OPENAI_EMBEDDINGS_ENGINE_DOC", "text-embedding-ada-002")
# Vector store settings # Vector store settings
vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "") vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "")
vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "") vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "")
access_token: str = os.getenv("AZURE_SEARCH_ACCESS_TOKEN", "")
index_name: str = "embeddings-vector-store-test" index_name: str = "embeddings-vector-store-test"
@ -25,6 +26,7 @@ def similarity_search_test() -> None:
vector_store: AzureSearch = AzureSearch( vector_store: AzureSearch = AzureSearch(
azure_search_endpoint=vector_store_address, azure_search_endpoint=vector_store_address,
azure_search_key=vector_store_password, azure_search_key=vector_store_password,
azure_ad_access_token=access_token,
index_name=index_name, index_name=index_name,
embedding_function=embeddings.embed_query, embedding_function=embeddings.embed_query,
) )
@ -68,6 +70,7 @@ def test_semantic_hybrid_search() -> None:
vector_store: AzureSearch = AzureSearch( vector_store: AzureSearch = AzureSearch(
azure_search_endpoint=vector_store_address, azure_search_endpoint=vector_store_address,
azure_search_key=vector_store_password, azure_search_key=vector_store_password,
azure_ad_access_token=access_token,
index_name=index_name, index_name=index_name,
embedding_function=embeddings.embed_query, embedding_function=embeddings.embed_query,
semantic_configuration_name="default", semantic_configuration_name="default",

View File

@ -32,6 +32,7 @@ class FakeEmbeddingsWithDimension(FakeEmbeddings):
DEFAULT_INDEX_NAME = "langchain-index" DEFAULT_INDEX_NAME = "langchain-index"
DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net" DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net"
DEFAULT_KEY = "mykey" DEFAULT_KEY = "mykey"
DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension() DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
@ -127,6 +128,7 @@ def create_vector_store(
return AzureSearch( return AzureSearch(
azure_search_endpoint=DEFAULT_ENDPOINT, azure_search_endpoint=DEFAULT_ENDPOINT,
azure_search_key=DEFAULT_KEY, azure_search_key=DEFAULT_KEY,
azure_ad_access_token=DEFAULT_ACCESS_TOKEN,
index_name=DEFAULT_INDEX_NAME, index_name=DEFAULT_INDEX_NAME,
embedding_function=DEFAULT_EMBEDDING_MODEL, embedding_function=DEFAULT_EMBEDDING_MODEL,
additional_search_client_options=additional_search_client_options, additional_search_client_options=additional_search_client_options,