Community: Updating Azure Retriever and Docs to be Azure AI Search instead of Azure Cognitive Search (#19925)

Last year Microsoft [changed the
name](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search)
of Azure Cognitive Search to Azure AI Search. This PR updates the
Langchain Azure Retriever API and it's associated docs to reflect this
change. It may be confusing for users to see the name Cognitive here and
AI in the Microsoft documentation which is why this is needed. I've also
added a more detailed example to the Azure retriever doc page.

There are more places that need a similar update but I'm breaking it up
so the PRs are not too big 😄 Fixing my errors from the previous PR.

Twitter: @marlene_zw

Two new tests added to test backward compatibility in
`libs/community/tests/integration_tests/retrievers/test_azure_cognitive_search.py`

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Marlene
2024-04-08 16:12:41 +01:00
committed by GitHub
parent 820b713086
commit 2f03bc397e
14 changed files with 402 additions and 209 deletions

View File

@@ -26,7 +26,8 @@ _module_lookup = {
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers.bedrock",
"ArceeRetriever": "langchain_community.retrievers.arcee",
"ArxivRetriever": "langchain_community.retrievers.arxiv",
"AzureCognitiveSearchRetriever": "langchain_community.retrievers.azure_cognitive_search", # noqa: E501
"AzureAISearchRetriever": "langchain_community.retrievers.azure_ai_search", # noqa: E501
"AzureCognitiveSearchRetriever": "langchain_community.retrievers.azure_ai_search", # noqa: E501
"BM25Retriever": "langchain_community.retrievers.bm25",
"BreebsRetriever": "langchain_community.retrievers.breebs",
"ChaindeskRetriever": "langchain_community.retrievers.chaindesk",

View File

@@ -18,13 +18,13 @@ DEFAULT_URL_SUFFIX = "search.windows.net"
"""Default URL Suffix for endpoint connection - commercial cloud"""
class AzureCognitiveSearchRetriever(BaseRetriever):
"""`Azure Cognitive Search` service retriever."""
class AzureAISearchRetriever(BaseRetriever):
"""`Azure AI Search` service retriever."""
service_name: str = ""
"""Name of Azure Cognitive Search service"""
"""Name of Azure AI Search service"""
index_name: str = ""
"""Name of Index inside Azure Cognitive Search service"""
"""Name of Index inside Azure AI Search service"""
api_key: str = ""
"""API Key. Both Admin and Query keys work, but for reading data it's
recommended to use a Query key."""
@@ -45,27 +45,30 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that service name, index name and api key exists in environment."""
values["service_name"] = get_from_dict_or_env(
values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME"
values, "service_name", "AZURE_AI_SEARCH_SERVICE_NAME"
)
values["index_name"] = get_from_dict_or_env(
values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME"
values, "index_name", "AZURE_AI_SEARCH_INDEX_NAME"
)
values["api_key"] = get_from_dict_or_env(
values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY"
values, "api_key", "AZURE_AI_SEARCH_API_KEY"
)
return values
def _build_search_url(self, query: str) -> str:
url_suffix = get_from_env(
"", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX
)
url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)
if url_suffix in self.service_name and "https://" in self.service_name:
base_url = f"{self.service_name}/"
elif url_suffix in self.service_name and "https://" not in self.service_name:
base_url = f"https://{self.service_name}/"
elif url_suffix not in self.service_name and "https://" in self.service_name:
base_url = f"{self.service_name}.{url_suffix}/"
elif (
url_suffix not in self.service_name and "https://" not in self.service_name
):
base_url = f"https://{self.service_name}.{url_suffix}/"
else:
# pass to Azure to throw a specific error
base_url = self.service_name
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
top_param = f"&$top={self.top_k}" if self.top_k else ""
@@ -119,3 +122,11 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
Document(page_content=result.pop(self.content_key), metadata=result)
for result in search_results
]
# For backwards compatibility
class AzureCognitiveSearchRetriever(AzureAISearchRetriever):
"""`Azure Cognitive Search` service retriever.
This version of the retriever will soon be
depreciated. Please switch to AzureAISearchRetriever
"""

View File

@@ -0,0 +1,70 @@
"""Test Azure AI Search wrapper."""
from langchain_core.documents import Document
from langchain_community.retrievers.azure_ai_search import (
AzureAISearchRetriever,
AzureCognitiveSearchRetriever,
)
def test_azure_ai_search_get_relevant_documents() -> None:
"""Test valid call to Azure AI Search.
In order to run this test, you should provide
a `service_name`, azure search `api_key` and an `index_name`
as arguments for the AzureAISearchRetriever in both tests.
api_version, aiosession and topk_k are optional parameters.
"""
retriever = AzureAISearchRetriever()
documents = retriever.get_relevant_documents("what is langchain?")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
retriever = AzureAISearchRetriever(top_k=1)
documents = retriever.get_relevant_documents("what is langchain?")
assert len(documents) <= 1
async def test_azure_ai_search_aget_relevant_documents() -> None:
"""Test valid async call to Azure AI Search.
In order to run this test, you should provide
a `service_name`, azure search `api_key` and an `index_name`
as arguments for the AzureAISearchRetriever.
"""
retriever = AzureAISearchRetriever()
documents = await retriever.aget_relevant_documents("what is langchain?")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
def test_azure_cognitive_search_get_relevant_documents() -> None:
"""Test valid call to Azure Cognitive Search.
This is to test backwards compatibility of the retriever
"""
retriever = AzureCognitiveSearchRetriever()
documents = retriever.get_relevant_documents("what is langchain?")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
retriever = AzureCognitiveSearchRetriever(top_k=1)
documents = retriever.get_relevant_documents("what is langchain?")
assert len(documents) <= 1
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
"""Test valid async call to Azure Cognitive Search.
This is to test backwards compatibility of the retriever
"""
retriever = AzureCognitiveSearchRetriever()
documents = await retriever.aget_relevant_documents("what is langchain?")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content

View File

@@ -1,37 +0,0 @@
"""Test Azure Cognitive Search wrapper."""
from langchain_core.documents import Document
from langchain_community.retrievers.azure_cognitive_search import (
AzureCognitiveSearchRetriever,
)
def test_azure_cognitive_search_get_relevant_documents() -> None:
"""Test valid call to Azure Cognitive Search.
In order to run this test, you should provide a service name, azure search api key
and an index_name as arguments for the AzureCognitiveSearchRetriever in both tests.
"""
retriever = AzureCognitiveSearchRetriever()
documents = retriever.get_relevant_documents("what is langchain?")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
retriever = AzureCognitiveSearchRetriever()
documents = retriever.get_relevant_documents("what is langchain?")
assert len(documents) <= 1
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
"""Test valid async call to Azure Cognitive Search.
In order to run this test, you should provide a service name, azure search api key
and an index_name as arguments for the AzureCognitiveSearchRetriever.
"""
retriever = AzureCognitiveSearchRetriever()
documents = await retriever.aget_relevant_documents("what is langchain?")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content

View File

@@ -5,6 +5,7 @@ EXPECTED_ALL = [
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
"BreebsRetriever",
"ChatGPTPluginRetriever",

View File

@@ -60,6 +60,7 @@ __all__ = [
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
"ChatGPTPluginRetriever",
"ContextualCompressionRetriever",

View File

@@ -0,0 +1,6 @@
from langchain_community.retrievers.azure_ai_search import (
AzureAISearchRetriever,
AzureCognitiveSearchRetriever,
)
__all__ = ["AzureAISearchRetriever", "AzureCognitiveSearchRetriever"]

View File

@@ -1,5 +0,0 @@
from langchain_community.retrievers.azure_cognitive_search import (
AzureCognitiveSearchRetriever,
)
__all__ = ["AzureCognitiveSearchRetriever"]

View File

@@ -6,6 +6,7 @@ EXPECTED_ALL = [
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
"ChatGPTPluginRetriever",
"ContextualCompressionRetriever",