From 60f58df5b31e7959751438af8242bb178b8ce041 Mon Sep 17 00:00:00 2001 From: Jan Heimes <45521680+JANHMS@users.noreply.github.com> Date: Sun, 16 Feb 2025 14:30:52 +0100 Subject: [PATCH] community: add top_k as param to Needle Retriever (#29821) Thank you for contributing to LangChain! - [X] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: This PR adds top_k as a param to the Needle Retriever. By default we use top 10. - [X] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [X] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --- libs/community/langchain_community/retrievers/needle.py | 9 +++++++-- .../community/tests/unit_tests/retrievers/test_needle.py | 7 +++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/retrievers/needle.py b/libs/community/langchain_community/retrievers/needle.py index 201d617fda4..52a5245108d 100644 --- a/libs/community/langchain_community/retrievers/needle.py +++ b/libs/community/langchain_community/retrievers/needle.py @@ -23,6 +23,7 @@ class NeedleRetriever(BaseRetriever, BaseModel): - `needle_api_key` (Optional[str]): The API key for authenticating with Needle. - `collection_id` (str): The ID of the Needle collection to search in. - `client` (Optional[NeedleClient]): An optional instance of the NeedleClient. + - `top_k` (Optional[int]): Maximum number of results to return. Usage: .. code-block:: python @@ -31,7 +32,8 @@ class NeedleRetriever(BaseRetriever, BaseModel): retriever = NeedleRetriever( needle_api_key="your-api-key", - collection_id="your-collection-id" + collection_id="your-collection-id", + top_k=10 # optional ) results = retriever.retrieve("example query") @@ -45,6 +47,9 @@ class NeedleRetriever(BaseRetriever, BaseModel): collection_id: Optional[str] = Field( ..., description="The ID of the Needle collection to search in" ) + top_k: Optional[int] = Field( + default=None, description="Maximum number of search results to return" + ) def _initialize_client(self) -> None: """ @@ -75,7 +80,7 @@ class NeedleRetriever(BaseRetriever, BaseModel): raise ValueError("NeedleClient is not initialized. Provide an API key.") results = self.client.collections.search( - collection_id=self.collection_id, text=query + collection_id=self.collection_id, text=query, top_k=self.top_k ) docs = [Document(page_content=result.content) for result in results] return docs diff --git a/libs/community/tests/unit_tests/retrievers/test_needle.py b/libs/community/tests/unit_tests/retrievers/test_needle.py index 853250d409f..1bfb9243af7 100644 --- a/libs/community/tests/unit_tests/retrievers/test_needle.py +++ b/libs/community/tests/unit_tests/retrievers/test_needle.py @@ -17,11 +17,14 @@ class MockNeedleClient: self.collections = self.MockCollections() class MockCollections: - def search(self, collection_id: str, text: str) -> list[MockSearchResult]: - return [ + def search( + self, collection_id: str, text: str, top_k: int = 10 + ) -> list[MockSearchResult]: + results = [ MockSearchResult(content=f"Result for query: {text}"), MockSearchResult(content=f"Another result for query: {text}"), ] + return results[:top_k] @pytest.mark.requires("needle")