mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
community: Azure Search Vector Store is missing Access Token Authentication (#24330)
Added Azure Search Access Token Authentication instead of API KEY auth. Fixes Issue: https://github.com/langchain-ai/langchain/issues/24263 Dependencies: None Twitter: @levalencia @baskaryan Could you please review? First time creating a PR that fixes some code. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
49b0bc7b5a
commit
99f9a664a5
@ -16,7 +16,6 @@ cloudpickle>=2.0.0
|
|||||||
cohere>=4,<6
|
cohere>=4,<6
|
||||||
databricks-vectorsearch>=0.21,<0.22
|
databricks-vectorsearch>=0.21,<0.22
|
||||||
datasets>=2.15.0,<3
|
datasets>=2.15.0,<3
|
||||||
dedoc>=2.2.6,<3
|
|
||||||
dgml-utils>=0.3.0,<0.4
|
dgml-utils>=0.3.0,<0.4
|
||||||
elasticsearch>=8.12.0,<9
|
elasticsearch>=8.12.0,<9
|
||||||
esprima>=4.0.1,<5
|
esprima>=4.0.1,<5
|
||||||
|
@ -5,6 +5,7 @@ import base64
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -79,8 +80,9 @@ MAX_UPLOAD_BATCH_SIZE = 1000
|
|||||||
|
|
||||||
def _get_search_client(
|
def _get_search_client(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
key: str,
|
|
||||||
index_name: str,
|
index_name: str,
|
||||||
|
key: Optional[str] = None,
|
||||||
|
azure_ad_access_token: Optional[str] = None,
|
||||||
semantic_configuration_name: Optional[str] = None,
|
semantic_configuration_name: Optional[str] = None,
|
||||||
fields: Optional[List[SearchField]] = None,
|
fields: Optional[List[SearchField]] = None,
|
||||||
vector_search: Optional[VectorSearch] = None,
|
vector_search: Optional[VectorSearch] = None,
|
||||||
@ -95,7 +97,7 @@ def _get_search_client(
|
|||||||
async_: bool = False,
|
async_: bool = False,
|
||||||
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||||
) -> Union[SearchClient, AsyncSearchClient]:
|
) -> Union[SearchClient, AsyncSearchClient]:
|
||||||
from azure.core.credentials import AzureKeyCredential
|
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
from azure.core.exceptions import ResourceNotFoundError
|
||||||
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
|
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
|
||||||
from azure.search.documents import SearchClient
|
from azure.search.documents import SearchClient
|
||||||
@ -119,13 +121,23 @@ def _get_search_client(
|
|||||||
|
|
||||||
additional_search_client_options = additional_search_client_options or {}
|
additional_search_client_options = additional_search_client_options or {}
|
||||||
default_fields = default_fields or []
|
default_fields = default_fields or []
|
||||||
if key is None:
|
credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
|
||||||
credential = DefaultAzureCredential()
|
|
||||||
elif key.upper() == "INTERACTIVE":
|
# Determine the appropriate credential to use
|
||||||
credential = InteractiveBrowserCredential()
|
if key is not None:
|
||||||
credential.get_token("https://search.azure.com/.default")
|
if key.upper() == "INTERACTIVE":
|
||||||
|
credential = InteractiveBrowserCredential()
|
||||||
|
credential.get_token("https://search.azure.com/.default")
|
||||||
|
else:
|
||||||
|
credential = AzureKeyCredential(key)
|
||||||
|
elif azure_ad_access_token is not None:
|
||||||
|
credential = TokenCredential(
|
||||||
|
lambda *scopes, **kwargs: AccessToken(
|
||||||
|
azure_ad_access_token, int(time.time()) + 3600
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
credential = AzureKeyCredential(key)
|
credential = DefaultAzureCredential()
|
||||||
index_client: SearchIndexClient = SearchIndexClient(
|
index_client: SearchIndexClient = SearchIndexClient(
|
||||||
endpoint=endpoint, credential=credential, user_agent=user_agent
|
endpoint=endpoint, credential=credential, user_agent=user_agent
|
||||||
)
|
)
|
||||||
@ -253,6 +265,7 @@ class AzureSearch(VectorStore):
|
|||||||
self,
|
self,
|
||||||
azure_search_endpoint: str,
|
azure_search_endpoint: str,
|
||||||
azure_search_key: str,
|
azure_search_key: str,
|
||||||
|
azure_ad_access_token: Optional[str],
|
||||||
index_name: str,
|
index_name: str,
|
||||||
embedding_function: Union[Callable, Embeddings],
|
embedding_function: Union[Callable, Embeddings],
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
@ -321,8 +334,9 @@ class AzureSearch(VectorStore):
|
|||||||
user_agent += " " + kwargs["user_agent"]
|
user_agent += " " + kwargs["user_agent"]
|
||||||
self.client = _get_search_client(
|
self.client = _get_search_client(
|
||||||
azure_search_endpoint,
|
azure_search_endpoint,
|
||||||
azure_search_key,
|
|
||||||
index_name,
|
index_name,
|
||||||
|
azure_search_key,
|
||||||
|
azure_ad_access_token,
|
||||||
semantic_configuration_name=semantic_configuration_name,
|
semantic_configuration_name=semantic_configuration_name,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
vector_search=vector_search,
|
vector_search=vector_search,
|
||||||
@ -336,8 +350,9 @@ class AzureSearch(VectorStore):
|
|||||||
)
|
)
|
||||||
self.async_client = _get_search_client(
|
self.async_client = _get_search_client(
|
||||||
azure_search_endpoint,
|
azure_search_endpoint,
|
||||||
azure_search_key,
|
|
||||||
index_name,
|
index_name,
|
||||||
|
azure_search_key,
|
||||||
|
azure_ad_access_token,
|
||||||
semantic_configuration_name=semantic_configuration_name,
|
semantic_configuration_name=semantic_configuration_name,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
vector_search=vector_search,
|
vector_search=vector_search,
|
||||||
@ -1387,6 +1402,7 @@ class AzureSearch(VectorStore):
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
azure_search_endpoint: str = "",
|
azure_search_endpoint: str = "",
|
||||||
azure_search_key: str = "",
|
azure_search_key: str = "",
|
||||||
|
azure_ad_access_token: Optional[str] = None,
|
||||||
index_name: str = "langchain-index",
|
index_name: str = "langchain-index",
|
||||||
fields: Optional[List[SearchField]] = None,
|
fields: Optional[List[SearchField]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -1395,6 +1411,7 @@ class AzureSearch(VectorStore):
|
|||||||
azure_search = cls(
|
azure_search = cls(
|
||||||
azure_search_endpoint,
|
azure_search_endpoint,
|
||||||
azure_search_key,
|
azure_search_key,
|
||||||
|
azure_ad_access_token,
|
||||||
index_name,
|
index_name,
|
||||||
embedding,
|
embedding,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
@ -1411,6 +1428,7 @@ class AzureSearch(VectorStore):
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
azure_search_endpoint: str = "",
|
azure_search_endpoint: str = "",
|
||||||
azure_search_key: str = "",
|
azure_search_key: str = "",
|
||||||
|
azure_ad_access_token: Optional[str] = None,
|
||||||
index_name: str = "langchain-index",
|
index_name: str = "langchain-index",
|
||||||
fields: Optional[List[SearchField]] = None,
|
fields: Optional[List[SearchField]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -1419,6 +1437,7 @@ class AzureSearch(VectorStore):
|
|||||||
azure_search = cls(
|
azure_search = cls(
|
||||||
azure_search_endpoint,
|
azure_search_endpoint,
|
||||||
azure_search_key,
|
azure_search_key,
|
||||||
|
azure_ad_access_token,
|
||||||
index_name,
|
index_name,
|
||||||
embedding,
|
embedding,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
|
@ -13,6 +13,7 @@ model = os.getenv("OPENAI_EMBEDDINGS_ENGINE_DOC", "text-embedding-ada-002")
|
|||||||
# Vector store settings
|
# Vector store settings
|
||||||
vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "")
|
vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "")
|
||||||
vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "")
|
vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "")
|
||||||
|
access_token: str = os.getenv("AZURE_SEARCH_ACCESS_TOKEN", "")
|
||||||
index_name: str = "embeddings-vector-store-test"
|
index_name: str = "embeddings-vector-store-test"
|
||||||
|
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ def similarity_search_test() -> None:
|
|||||||
vector_store: AzureSearch = AzureSearch(
|
vector_store: AzureSearch = AzureSearch(
|
||||||
azure_search_endpoint=vector_store_address,
|
azure_search_endpoint=vector_store_address,
|
||||||
azure_search_key=vector_store_password,
|
azure_search_key=vector_store_password,
|
||||||
|
azure_ad_access_token=access_token,
|
||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
embedding_function=embeddings.embed_query,
|
embedding_function=embeddings.embed_query,
|
||||||
)
|
)
|
||||||
@ -68,6 +70,7 @@ def test_semantic_hybrid_search() -> None:
|
|||||||
vector_store: AzureSearch = AzureSearch(
|
vector_store: AzureSearch = AzureSearch(
|
||||||
azure_search_endpoint=vector_store_address,
|
azure_search_endpoint=vector_store_address,
|
||||||
azure_search_key=vector_store_password,
|
azure_search_key=vector_store_password,
|
||||||
|
azure_ad_access_token=access_token,
|
||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
embedding_function=embeddings.embed_query,
|
embedding_function=embeddings.embed_query,
|
||||||
semantic_configuration_name="default",
|
semantic_configuration_name="default",
|
||||||
|
@ -32,6 +32,7 @@ class FakeEmbeddingsWithDimension(FakeEmbeddings):
|
|||||||
DEFAULT_INDEX_NAME = "langchain-index"
|
DEFAULT_INDEX_NAME = "langchain-index"
|
||||||
DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net"
|
DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net"
|
||||||
DEFAULT_KEY = "mykey"
|
DEFAULT_KEY = "mykey"
|
||||||
|
DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
|
||||||
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
|
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
|
||||||
|
|
||||||
|
|
||||||
@ -127,6 +128,7 @@ def create_vector_store(
|
|||||||
return AzureSearch(
|
return AzureSearch(
|
||||||
azure_search_endpoint=DEFAULT_ENDPOINT,
|
azure_search_endpoint=DEFAULT_ENDPOINT,
|
||||||
azure_search_key=DEFAULT_KEY,
|
azure_search_key=DEFAULT_KEY,
|
||||||
|
azure_ad_access_token=DEFAULT_ACCESS_TOKEN,
|
||||||
index_name=DEFAULT_INDEX_NAME,
|
index_name=DEFAULT_INDEX_NAME,
|
||||||
embedding_function=DEFAULT_EMBEDDING_MODEL,
|
embedding_function=DEFAULT_EMBEDDING_MODEL,
|
||||||
additional_search_client_options=additional_search_client_options,
|
additional_search_client_options=additional_search_client_options,
|
||||||
|
Loading…
Reference in New Issue
Block a user