mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
Community: Updating Azure Retriever and Docs to be Azure AI Search instead of Azure Cognitive Search (#19925)
Last year Microsoft [changed the name](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search) of Azure Cognitive Search to Azure AI Search. This PR updates the Langchain Azure Retriever API and its associated docs to reflect this change. It may be confusing for users to see the name Cognitive here and AI in the Microsoft documentation, which is why this is needed. I've also added a more detailed example to the Azure retriever doc page. There are more places that need a similar update but I'm breaking it up so the PRs are not too big 😄 Fixing my errors from the previous PR. Twitter: @marlene_zw Two new tests added to test backward compatibility in `libs/community/tests/integration_tests/retrievers/test_azure_cognitive_search.py` --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -26,7 +26,8 @@ _module_lookup = {
|
||||
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers.bedrock",
|
||||
"ArceeRetriever": "langchain_community.retrievers.arcee",
|
||||
"ArxivRetriever": "langchain_community.retrievers.arxiv",
|
||||
"AzureCognitiveSearchRetriever": "langchain_community.retrievers.azure_cognitive_search", # noqa: E501
|
||||
"AzureAISearchRetriever": "langchain_community.retrievers.azure_ai_search", # noqa: E501
|
||||
"AzureCognitiveSearchRetriever": "langchain_community.retrievers.azure_ai_search", # noqa: E501
|
||||
"BM25Retriever": "langchain_community.retrievers.bm25",
|
||||
"BreebsRetriever": "langchain_community.retrievers.breebs",
|
||||
"ChaindeskRetriever": "langchain_community.retrievers.chaindesk",
|
||||
|
@@ -18,13 +18,13 @@ DEFAULT_URL_SUFFIX = "search.windows.net"
|
||||
"""Default URL Suffix for endpoint connection - commercial cloud"""
|
||||
|
||||
|
||||
class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||
"""`Azure Cognitive Search` service retriever."""
|
||||
class AzureAISearchRetriever(BaseRetriever):
|
||||
"""`Azure AI Search` service retriever."""
|
||||
|
||||
service_name: str = ""
|
||||
"""Name of Azure Cognitive Search service"""
|
||||
"""Name of Azure AI Search service"""
|
||||
index_name: str = ""
|
||||
"""Name of Index inside Azure Cognitive Search service"""
|
||||
"""Name of Index inside Azure AI Search service"""
|
||||
api_key: str = ""
|
||||
"""API Key. Both Admin and Query keys work, but for reading data it's
|
||||
recommended to use a Query key."""
|
||||
@@ -45,27 +45,30 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that service name, index name and api key exists in environment."""
|
||||
values["service_name"] = get_from_dict_or_env(
|
||||
values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME"
|
||||
values, "service_name", "AZURE_AI_SEARCH_SERVICE_NAME"
|
||||
)
|
||||
values["index_name"] = get_from_dict_or_env(
|
||||
values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME"
|
||||
values, "index_name", "AZURE_AI_SEARCH_INDEX_NAME"
|
||||
)
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY"
|
||||
values, "api_key", "AZURE_AI_SEARCH_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
def _build_search_url(self, query: str) -> str:
|
||||
url_suffix = get_from_env(
|
||||
"", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX
|
||||
)
|
||||
url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)
|
||||
if url_suffix in self.service_name and "https://" in self.service_name:
|
||||
base_url = f"{self.service_name}/"
|
||||
elif url_suffix in self.service_name and "https://" not in self.service_name:
|
||||
base_url = f"https://{self.service_name}/"
|
||||
elif url_suffix not in self.service_name and "https://" in self.service_name:
|
||||
base_url = f"{self.service_name}.{url_suffix}/"
|
||||
elif (
|
||||
url_suffix not in self.service_name and "https://" not in self.service_name
|
||||
):
|
||||
base_url = f"https://{self.service_name}.{url_suffix}/"
|
||||
else:
|
||||
# pass to Azure to throw a specific error
|
||||
base_url = self.service_name
|
||||
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
|
||||
top_param = f"&$top={self.top_k}" if self.top_k else ""
|
||||
@@ -119,3 +122,11 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||
Document(page_content=result.pop(self.content_key), metadata=result)
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
class AzureCognitiveSearchRetriever(AzureAISearchRetriever):
|
||||
"""`Azure Cognitive Search` service retriever.
|
||||
This version of the retriever will soon be
|
||||
deprecated. Please switch to AzureAISearchRetriever
|
||||
"""
|
@@ -0,0 +1,70 @@
|
||||
"""Test Azure AI Search wrapper."""
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.retrievers.azure_ai_search import (
|
||||
AzureAISearchRetriever,
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
|
||||
|
||||
def test_azure_ai_search_get_relevant_documents() -> None:
|
||||
"""Test valid call to Azure AI Search.
|
||||
|
||||
In order to run this test, you should provide
|
||||
a `service_name`, azure search `api_key` and an `index_name`
|
||||
as arguments for the AzureAISearchRetriever in both tests.
|
||||
api_version, aiosession and top_k are optional parameters.
|
||||
"""
|
||||
retriever = AzureAISearchRetriever()
|
||||
|
||||
documents = retriever.get_relevant_documents("what is langchain?")
|
||||
for doc in documents:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
||||
|
||||
retriever = AzureAISearchRetriever(top_k=1)
|
||||
documents = retriever.get_relevant_documents("what is langchain?")
|
||||
assert len(documents) <= 1
|
||||
|
||||
|
||||
async def test_azure_ai_search_aget_relevant_documents() -> None:
|
||||
"""Test valid async call to Azure AI Search.
|
||||
|
||||
In order to run this test, you should provide
|
||||
a `service_name`, azure search `api_key` and an `index_name`
|
||||
as arguments for the AzureAISearchRetriever.
|
||||
"""
|
||||
retriever = AzureAISearchRetriever()
|
||||
documents = await retriever.aget_relevant_documents("what is langchain?")
|
||||
for doc in documents:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
||||
|
||||
|
||||
def test_azure_cognitive_search_get_relevant_documents() -> None:
|
||||
"""Test valid call to Azure Cognitive Search.
|
||||
|
||||
This is to test backwards compatibility of the retriever
|
||||
"""
|
||||
retriever = AzureCognitiveSearchRetriever()
|
||||
|
||||
documents = retriever.get_relevant_documents("what is langchain?")
|
||||
for doc in documents:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
||||
|
||||
retriever = AzureCognitiveSearchRetriever(top_k=1)
|
||||
documents = retriever.get_relevant_documents("what is langchain?")
|
||||
assert len(documents) <= 1
|
||||
|
||||
|
||||
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
|
||||
"""Test valid async call to Azure Cognitive Search.
|
||||
|
||||
This is to test backwards compatibility of the retriever
|
||||
"""
|
||||
retriever = AzureCognitiveSearchRetriever()
|
||||
documents = await retriever.aget_relevant_documents("what is langchain?")
|
||||
for doc in documents:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
@@ -1,37 +0,0 @@
|
||||
"""Test Azure Cognitive Search wrapper."""
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.retrievers.azure_cognitive_search import (
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
|
||||
|
||||
def test_azure_cognitive_search_get_relevant_documents() -> None:
|
||||
"""Test valid call to Azure Cognitive Search.
|
||||
|
||||
In order to run this test, you should provide a service name, azure search api key
|
||||
and an index_name as arguments for the AzureCognitiveSearchRetriever in both tests.
|
||||
"""
|
||||
retriever = AzureCognitiveSearchRetriever()
|
||||
|
||||
documents = retriever.get_relevant_documents("what is langchain?")
|
||||
for doc in documents:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
||||
|
||||
retriever = AzureCognitiveSearchRetriever()
|
||||
documents = retriever.get_relevant_documents("what is langchain?")
|
||||
assert len(documents) <= 1
|
||||
|
||||
|
||||
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
|
||||
"""Test valid async call to Azure Cognitive Search.
|
||||
|
||||
In order to run this test, you should provide a service name, azure search api key
|
||||
and an index_name as arguments for the AzureCognitiveSearchRetriever.
|
||||
"""
|
||||
retriever = AzureCognitiveSearchRetriever()
|
||||
documents = await retriever.aget_relevant_documents("what is langchain?")
|
||||
for doc in documents:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
@@ -5,6 +5,7 @@ EXPECTED_ALL = [
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureAISearchRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"BreebsRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
|
@@ -60,6 +60,7 @@ __all__ = [
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureAISearchRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"ContextualCompressionRetriever",
|
||||
|
6
libs/langchain/langchain/retrievers/azure_ai_search.py
Normal file
6
libs/langchain/langchain/retrievers/azure_ai_search.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from langchain_community.retrievers.azure_ai_search import (
|
||||
AzureAISearchRetriever,
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
|
||||
__all__ = ["AzureAISearchRetriever", "AzureCognitiveSearchRetriever"]
|
@@ -1,5 +0,0 @@
|
||||
from langchain_community.retrievers.azure_cognitive_search import (
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
|
||||
__all__ = ["AzureCognitiveSearchRetriever"]
|
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureAISearchRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"ContextualCompressionRetriever",
|
||||
|
Reference in New Issue
Block a user