community: add top_k as param to Needle Retriever (#29821)

Thank you for contributing to LangChain!

- [X] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core, etc. is
being modified. Use "docs: ..." for purely docs changes, "infra: ..."
for CI changes.
  - Example: "community: add foobar LLM"


- [x] **PR message**: 
This PR adds top_k as a param to the Needle Retriever. By default we use
top 10.



- [X] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [X] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
This commit is contained in:
Jan Heimes 2025-02-16 14:30:52 +01:00 committed by GitHub
parent 8147679169
commit 60f58df5b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -23,6 +23,7 @@ class NeedleRetriever(BaseRetriever, BaseModel):
- `needle_api_key` (Optional[str]): The API key for authenticating with Needle.
- `collection_id` (str): The ID of the Needle collection to search in.
- `client` (Optional[NeedleClient]): An optional instance of the NeedleClient.
- `top_k` (Optional[int]): Maximum number of results to return.
Usage:
.. code-block:: python
@ -31,7 +32,8 @@ class NeedleRetriever(BaseRetriever, BaseModel):
retriever = NeedleRetriever(
needle_api_key="your-api-key",
collection_id="your-collection-id"
collection_id="your-collection-id",
top_k=10 # optional
)
results = retriever.retrieve("example query")
@ -45,6 +47,9 @@ class NeedleRetriever(BaseRetriever, BaseModel):
collection_id: Optional[str] = Field(
..., description="The ID of the Needle collection to search in"
)
top_k: Optional[int] = Field(
default=None, description="Maximum number of search results to return"
)
def _initialize_client(self) -> None:
"""
@ -75,7 +80,7 @@ class NeedleRetriever(BaseRetriever, BaseModel):
raise ValueError("NeedleClient is not initialized. Provide an API key.")
results = self.client.collections.search(
collection_id=self.collection_id, text=query
collection_id=self.collection_id, text=query, top_k=self.top_k
)
docs = [Document(page_content=result.content) for result in results]
return docs

View File

@ -17,11 +17,14 @@ class MockNeedleClient:
self.collections = self.MockCollections()
class MockCollections:
def search(self, collection_id: str, text: str) -> list[MockSearchResult]:
return [
def search(
self, collection_id: str, text: str, top_k: int = 10
) -> list[MockSearchResult]:
results = [
MockSearchResult(content=f"Result for query: {text}"),
MockSearchResult(content=f"Another result for query: {text}"),
]
return results[:top_k]
@pytest.mark.requires("needle")