mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
community: Azure Search Vector Store is missing Access Token Authentication (#24330)
Added Azure Search access token authentication as an alternative to API key authentication. Fixes issue: https://github.com/langchain-ai/langchain/issues/24263 Dependencies: None. Twitter: @levalencia. @baskaryan, could you please review? This is my first time creating a PR that fixes some code. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
49b0bc7b5a
commit
99f9a664a5
@ -16,7 +16,6 @@ cloudpickle>=2.0.0
|
||||
cohere>=4,<6
|
||||
databricks-vectorsearch>=0.21,<0.22
|
||||
datasets>=2.15.0,<3
|
||||
dedoc>=2.2.6,<3
|
||||
dgml-utils>=0.3.0,<0.4
|
||||
elasticsearch>=8.12.0,<9
|
||||
esprima>=4.0.1,<5
|
||||
|
@ -5,6 +5,7 @@ import base64
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -79,8 +80,9 @@ MAX_UPLOAD_BATCH_SIZE = 1000
|
||||
|
||||
def _get_search_client(
|
||||
endpoint: str,
|
||||
key: str,
|
||||
index_name: str,
|
||||
key: Optional[str] = None,
|
||||
azure_ad_access_token: Optional[str] = None,
|
||||
semantic_configuration_name: Optional[str] = None,
|
||||
fields: Optional[List[SearchField]] = None,
|
||||
vector_search: Optional[VectorSearch] = None,
|
||||
@ -95,7 +97,7 @@ def _get_search_client(
|
||||
async_: bool = False,
|
||||
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[SearchClient, AsyncSearchClient]:
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
|
||||
from azure.search.documents import SearchClient
|
||||
@ -119,13 +121,23 @@ def _get_search_client(
|
||||
|
||||
additional_search_client_options = additional_search_client_options or {}
|
||||
default_fields = default_fields or []
|
||||
if key is None:
|
||||
credential = DefaultAzureCredential()
|
||||
elif key.upper() == "INTERACTIVE":
|
||||
credential = InteractiveBrowserCredential()
|
||||
credential.get_token("https://search.azure.com/.default")
|
||||
credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
|
||||
|
||||
# Determine the appropriate credential to use
|
||||
if key is not None:
|
||||
if key.upper() == "INTERACTIVE":
|
||||
credential = InteractiveBrowserCredential()
|
||||
credential.get_token("https://search.azure.com/.default")
|
||||
else:
|
||||
credential = AzureKeyCredential(key)
|
||||
elif azure_ad_access_token is not None:
|
||||
credential = TokenCredential(
|
||||
lambda *scopes, **kwargs: AccessToken(
|
||||
azure_ad_access_token, int(time.time()) + 3600
|
||||
)
|
||||
)
|
||||
else:
|
||||
credential = AzureKeyCredential(key)
|
||||
credential = DefaultAzureCredential()
|
||||
index_client: SearchIndexClient = SearchIndexClient(
|
||||
endpoint=endpoint, credential=credential, user_agent=user_agent
|
||||
)
|
||||
@ -253,6 +265,7 @@ class AzureSearch(VectorStore):
|
||||
self,
|
||||
azure_search_endpoint: str,
|
||||
azure_search_key: str,
|
||||
azure_ad_access_token: Optional[str],
|
||||
index_name: str,
|
||||
embedding_function: Union[Callable, Embeddings],
|
||||
search_type: str = "hybrid",
|
||||
@ -321,8 +334,9 @@ class AzureSearch(VectorStore):
|
||||
user_agent += " " + kwargs["user_agent"]
|
||||
self.client = _get_search_client(
|
||||
azure_search_endpoint,
|
||||
azure_search_key,
|
||||
index_name,
|
||||
azure_search_key,
|
||||
azure_ad_access_token,
|
||||
semantic_configuration_name=semantic_configuration_name,
|
||||
fields=fields,
|
||||
vector_search=vector_search,
|
||||
@ -336,8 +350,9 @@ class AzureSearch(VectorStore):
|
||||
)
|
||||
self.async_client = _get_search_client(
|
||||
azure_search_endpoint,
|
||||
azure_search_key,
|
||||
index_name,
|
||||
azure_search_key,
|
||||
azure_ad_access_token,
|
||||
semantic_configuration_name=semantic_configuration_name,
|
||||
fields=fields,
|
||||
vector_search=vector_search,
|
||||
@ -1387,6 +1402,7 @@ class AzureSearch(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
azure_search_endpoint: str = "",
|
||||
azure_search_key: str = "",
|
||||
azure_ad_access_token: Optional[str] = None,
|
||||
index_name: str = "langchain-index",
|
||||
fields: Optional[List[SearchField]] = None,
|
||||
**kwargs: Any,
|
||||
@ -1395,6 +1411,7 @@ class AzureSearch(VectorStore):
|
||||
azure_search = cls(
|
||||
azure_search_endpoint,
|
||||
azure_search_key,
|
||||
azure_ad_access_token,
|
||||
index_name,
|
||||
embedding,
|
||||
fields=fields,
|
||||
@ -1411,6 +1428,7 @@ class AzureSearch(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
azure_search_endpoint: str = "",
|
||||
azure_search_key: str = "",
|
||||
azure_ad_access_token: Optional[str] = None,
|
||||
index_name: str = "langchain-index",
|
||||
fields: Optional[List[SearchField]] = None,
|
||||
**kwargs: Any,
|
||||
@ -1419,6 +1437,7 @@ class AzureSearch(VectorStore):
|
||||
azure_search = cls(
|
||||
azure_search_endpoint,
|
||||
azure_search_key,
|
||||
azure_ad_access_token,
|
||||
index_name,
|
||||
embedding,
|
||||
fields=fields,
|
||||
|
@ -13,6 +13,7 @@ model = os.getenv("OPENAI_EMBEDDINGS_ENGINE_DOC", "text-embedding-ada-002")
|
||||
# Vector store settings
|
||||
vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "")
|
||||
vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "")
|
||||
access_token: str = os.getenv("AZURE_SEARCH_ACCESS_TOKEN", "")
|
||||
index_name: str = "embeddings-vector-store-test"
|
||||
|
||||
|
||||
@ -25,6 +26,7 @@ def similarity_search_test() -> None:
|
||||
vector_store: AzureSearch = AzureSearch(
|
||||
azure_search_endpoint=vector_store_address,
|
||||
azure_search_key=vector_store_password,
|
||||
azure_ad_access_token=access_token,
|
||||
index_name=index_name,
|
||||
embedding_function=embeddings.embed_query,
|
||||
)
|
||||
@ -68,6 +70,7 @@ def test_semantic_hybrid_search() -> None:
|
||||
vector_store: AzureSearch = AzureSearch(
|
||||
azure_search_endpoint=vector_store_address,
|
||||
azure_search_key=vector_store_password,
|
||||
azure_ad_access_token=access_token,
|
||||
index_name=index_name,
|
||||
embedding_function=embeddings.embed_query,
|
||||
semantic_configuration_name="default",
|
||||
|
@ -32,6 +32,7 @@ class FakeEmbeddingsWithDimension(FakeEmbeddings):
|
||||
DEFAULT_INDEX_NAME = "langchain-index"
|
||||
DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net"
|
||||
DEFAULT_KEY = "mykey"
|
||||
DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
|
||||
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
|
||||
|
||||
|
||||
@ -127,6 +128,7 @@ def create_vector_store(
|
||||
return AzureSearch(
|
||||
azure_search_endpoint=DEFAULT_ENDPOINT,
|
||||
azure_search_key=DEFAULT_KEY,
|
||||
azure_ad_access_token=DEFAULT_ACCESS_TOKEN,
|
||||
index_name=DEFAULT_INDEX_NAME,
|
||||
embedding_function=DEFAULT_EMBEDDING_MODEL,
|
||||
additional_search_client_options=additional_search_client_options,
|
||||
|
Loading…
Reference in New Issue
Block a user