community[patch]: Fixing incorrect base URLs for Azure Cognitive Search Retriever (#19352)

This PR adds code to make sure that the correct base URL is being
created for the Azure Cognitive Search retriever. At the moment an
incorrect base URL is being generated. I think this is happening because
the original code was based on a depreciated API version. No
dependencies need to be added. I've also added more context to the test
doc strings.

I should also note that ACS is now Azure AI Search. I will open a
separate PR to make these changes as that would be a breaking change and
should potentially be discussed.

Twitter: @marlene_zw



- No new tests added, however the current ACS retriever tests are now
passing when I run them.
- Code was linted.

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Marlene 2024-03-26 00:04:59 +00:00 committed by GitHub
parent d667b1ea8f
commit f1313339ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 8 deletions

View File

@ -28,7 +28,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
api_key: str = "" api_key: str = ""
"""API Key. Both Admin and Query keys work, but for reading data it's """API Key. Both Admin and Query keys work, but for reading data it's
recommended to use a Query key.""" recommended to use a Query key."""
api_version: str = "2020-06-30" api_version: str = "2023-11-01"
"""API version""" """API version"""
aiosession: Optional[aiohttp.ClientSession] = None aiosession: Optional[aiohttp.ClientSession] = None
"""ClientSession, in case we want to reuse connection for better performance.""" """ClientSession, in case we want to reuse connection for better performance."""
@ -59,7 +59,14 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
url_suffix = get_from_env( url_suffix = get_from_env(
"", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX "", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX
) )
base_url = f"https://{self.service_name}.{url_suffix}/" if url_suffix in self.service_name and "https://" in self.service_name:
base_url = f"{self.service_name}/"
elif url_suffix in self.service_name and "https://" not in self.service_name:
base_url = f"https://{self.service_name}/"
elif url_suffix not in self.service_name and "https://" in self.service_name:
base_url = f"{self.service_name}.{url_suffix}/"
else:
base_url = self.service_name
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
top_param = f"&$top={self.top_k}" if self.top_k else "" top_param = f"&$top={self.top_k}" if self.top_k else ""
return base_url + endpoint_path + f"&search={query}" + top_param return base_url + endpoint_path + f"&search={query}" + top_param

View File

@ -7,22 +7,31 @@ from langchain_community.retrievers.azure_cognitive_search import (
def test_azure_cognitive_search_get_relevant_documents() -> None: def test_azure_cognitive_search_get_relevant_documents() -> None:
"""Test valid call to Azure Cognitive Search.""" """Test valid call to Azure Cognitive Search.
In order to run this test, you should provide a service name, azure search api key
and an index_name as arguments for the AzureCognitiveSearchRetriever in both tests.
"""
retriever = AzureCognitiveSearchRetriever() retriever = AzureCognitiveSearchRetriever()
documents = retriever.get_relevant_documents("what is langchain")
documents = retriever.get_relevant_documents("what is langchain?")
for doc in documents: for doc in documents:
assert isinstance(doc, Document) assert isinstance(doc, Document)
assert doc.page_content assert doc.page_content
retriever = AzureCognitiveSearchRetriever(top_k=1) retriever = AzureCognitiveSearchRetriever()
documents = retriever.get_relevant_documents("what is langchain") documents = retriever.get_relevant_documents("what is langchain?")
assert len(documents) <= 1 assert len(documents) <= 1
async def test_azure_cognitive_search_aget_relevant_documents() -> None: async def test_azure_cognitive_search_aget_relevant_documents() -> None:
"""Test valid async call to Azure Cognitive Search.""" """Test valid async call to Azure Cognitive Search.
In order to run this test, you should provide a service name, azure search api key
and an index_name as arguments for the AzureCognitiveSearchRetriever.
"""
retriever = AzureCognitiveSearchRetriever() retriever = AzureCognitiveSearchRetriever()
documents = await retriever.aget_relevant_documents("what is langchain") documents = await retriever.aget_relevant_documents("what is langchain?")
for doc in documents: for doc in documents:
assert isinstance(doc, Document) assert isinstance(doc, Document)
assert doc.page_content assert doc.page_content