mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
[Small upgrade] Allow document limit in AzureCognitiveSearchRetriever (#7690)
Multiple people have asked in #5081 for a way to limit the documents returned from an AzureCognitiveSearchRetriever. This PR adds the `top_k` parameter to allow that. Twitter handle: [@UmerHAdil](https://twitter.com/umerHAdil)
This commit is contained in:
parent
af6d333147
commit
82f3e32d8d
@ -91,7 +91,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\")"
|
||||
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\", top_k=10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -111,6 +111,36 @@
|
||||
"source": [
|
||||
"retriever.get_relevant_documents(\"what is langchain\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "72eca08e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can change the number of results returned with the `top_k` parameter. The default value is `None`, which returns all results. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "097146c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6d9963f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dc120696",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -33,6 +33,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||
"""ClientSession, in case we want to reuse connection for better performance."""
|
||||
content_key: str = "content"
|
||||
"""Key in a retrieved result to set as the Document page_content."""
|
||||
top_k: Optional[int] = None
|
||||
"""Number of results to retrieve. Set to None to retrieve all results."""
|
||||
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
@ -55,7 +57,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||
def _build_search_url(self, query: str) -> str:
|
||||
base_url = f"https://{self.service_name}.search.windows.net/"
|
||||
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
|
||||
return base_url + endpoint_path + f"&search={query}"
|
||||
top_param = f"&$top={self.top_k}" if self.top_k else ""
|
||||
return base_url + endpoint_path + f"&search={query}" + top_param
|
||||
|
||||
@property
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
|
@ -13,6 +13,10 @@ def test_azure_cognitive_search_get_relevant_documents() -> None:
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.page_content
|
||||
|
||||
retriever = AzureCognitiveSearchRetriever(top_k=1)
|
||||
documents = retriever.get_relevant_documents("what is langchain")
|
||||
assert len(documents) <= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user