mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-22 15:38:06 +00:00
community(azuresearch): allow to use any valid credential (#28873)
Add option to use any valid credential type. Differentiates async cases needed by Azure Search. This could replace the use of a static token
This commit is contained in:
parent
4b4d09f82b
commit
8d9907088b
@ -42,6 +42,8 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from azure.core.credentials import TokenCredential
|
||||||
|
from azure.core.credentials_async import AsyncTokenCredential
|
||||||
from azure.search.documents import SearchClient, SearchItemPaged
|
from azure.search.documents import SearchClient, SearchItemPaged
|
||||||
from azure.search.documents.aio import (
|
from azure.search.documents.aio import (
|
||||||
AsyncSearchItemPaged,
|
AsyncSearchItemPaged,
|
||||||
@ -96,10 +98,13 @@ def _get_search_client(
|
|||||||
cors_options: Optional[CorsOptions] = None,
|
cors_options: Optional[CorsOptions] = None,
|
||||||
async_: bool = False,
|
async_: bool = False,
|
||||||
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||||
|
azure_credential: Optional[TokenCredential] = None,
|
||||||
|
azure_async_credential: Optional[AsyncTokenCredential] = None,
|
||||||
) -> Union[SearchClient, AsyncSearchClient]:
|
) -> Union[SearchClient, AsyncSearchClient]:
|
||||||
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
|
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
from azure.core.exceptions import ResourceNotFoundError
|
||||||
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
|
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
|
||||||
|
from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential
|
||||||
from azure.search.documents import SearchClient
|
from azure.search.documents import SearchClient
|
||||||
from azure.search.documents.aio import SearchClient as AsyncSearchClient
|
from azure.search.documents.aio import SearchClient as AsyncSearchClient
|
||||||
from azure.search.documents.indexes import SearchIndexClient
|
from azure.search.documents.indexes import SearchIndexClient
|
||||||
@ -143,12 +148,17 @@ def _get_search_client(
|
|||||||
if key.upper() == "INTERACTIVE":
|
if key.upper() == "INTERACTIVE":
|
||||||
credential = InteractiveBrowserCredential()
|
credential = InteractiveBrowserCredential()
|
||||||
credential.get_token("https://search.azure.com/.default")
|
credential.get_token("https://search.azure.com/.default")
|
||||||
|
async_credential = credential
|
||||||
else:
|
else:
|
||||||
credential = AzureKeyCredential(key)
|
credential = AzureKeyCredential(key)
|
||||||
|
async_credential = credential
|
||||||
elif azure_ad_access_token is not None:
|
elif azure_ad_access_token is not None:
|
||||||
credential = AzureBearerTokenCredential(azure_ad_access_token)
|
credential = AzureBearerTokenCredential(azure_ad_access_token)
|
||||||
|
async_credential = credential
|
||||||
else:
|
else:
|
||||||
credential = DefaultAzureCredential()
|
credential = azure_credential or DefaultAzureCredential()
|
||||||
|
async_credential = azure_async_credential or AsyncDefaultAzureCredential()
|
||||||
|
|
||||||
index_client: SearchIndexClient = SearchIndexClient(
|
index_client: SearchIndexClient = SearchIndexClient(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
credential=credential,
|
credential=credential,
|
||||||
@ -266,7 +276,7 @@ def _get_search_client(
|
|||||||
return AsyncSearchClient(
|
return AsyncSearchClient(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
credential=credential,
|
credential=async_credential,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
**additional_search_client_options,
|
**additional_search_client_options,
|
||||||
)
|
)
|
||||||
@ -278,7 +288,7 @@ class AzureSearch(VectorStore):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
azure_search_endpoint: str,
|
azure_search_endpoint: str,
|
||||||
azure_search_key: str,
|
azure_search_key: Optional[str],
|
||||||
index_name: str,
|
index_name: str,
|
||||||
embedding_function: Union[Callable, Embeddings],
|
embedding_function: Union[Callable, Embeddings],
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
@ -295,6 +305,8 @@ class AzureSearch(VectorStore):
|
|||||||
vector_search_dimensions: Optional[int] = None,
|
vector_search_dimensions: Optional[int] = None,
|
||||||
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||||
azure_ad_access_token: Optional[str] = None,
|
azure_ad_access_token: Optional[str] = None,
|
||||||
|
azure_credential: Optional[TokenCredential] = None,
|
||||||
|
azure_async_credential: Optional[AsyncTokenCredential] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@ -361,6 +373,7 @@ class AzureSearch(VectorStore):
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
cors_options=cors_options,
|
cors_options=cors_options,
|
||||||
additional_search_client_options=additional_search_client_options,
|
additional_search_client_options=additional_search_client_options,
|
||||||
|
azure_credential=azure_credential,
|
||||||
)
|
)
|
||||||
self.async_client = _get_search_client(
|
self.async_client = _get_search_client(
|
||||||
azure_search_endpoint,
|
azure_search_endpoint,
|
||||||
@ -377,6 +390,8 @@ class AzureSearch(VectorStore):
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
cors_options=cors_options,
|
cors_options=cors_options,
|
||||||
async_=True,
|
async_=True,
|
||||||
|
azure_credential=azure_credential,
|
||||||
|
azure_async_credential=azure_async_credential,
|
||||||
)
|
)
|
||||||
self.search_type = search_type
|
self.search_type = search_type
|
||||||
self.semantic_configuration_name = semantic_configuration_name
|
self.semantic_configuration_name = semantic_configuration_name
|
||||||
|
Loading…
Reference in New Issue
Block a user