mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
community: add top_k as param to Needle Retriever (#29821)
Thank you for contributing to LangChain!

- [X] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"
- [x] **PR message**: This PR adds `top_k` as a parameter to the Needle Retriever. By default we use the top 10 results.
- [X] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.
- [X] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
This commit is contained in:
parent
8147679169
commit
60f58df5b3
@@ -23,6 +23,7 @@ class NeedleRetriever(BaseRetriever, BaseModel):
|
||||
- `needle_api_key` (Optional[str]): The API key for authenticating with Needle.
|
||||
- `collection_id` (str): The ID of the Needle collection to search in.
|
||||
- `client` (Optional[NeedleClient]): An optional instance of the NeedleClient.
|
||||
- `top_k` (Optional[int]): Maximum number of results to return.
|
||||
|
||||
Usage:
|
||||
.. code-block:: python
|
||||
@@ -31,7 +32,8 @@ class NeedleRetriever(BaseRetriever, BaseModel):
|
||||
|
||||
retriever = NeedleRetriever(
|
||||
needle_api_key="your-api-key",
|
||||
collection_id="your-collection-id"
|
||||
collection_id="your-collection-id",
|
||||
top_k=10 # optional
|
||||
)
|
||||
|
||||
results = retriever.retrieve("example query")
|
||||
@@ -45,6 +47,9 @@ class NeedleRetriever(BaseRetriever, BaseModel):
|
||||
collection_id: Optional[str] = Field(
|
||||
..., description="The ID of the Needle collection to search in"
|
||||
)
|
||||
top_k: Optional[int] = Field(
|
||||
default=None, description="Maximum number of search results to return"
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
"""
|
||||
@@ -75,7 +80,7 @@ class NeedleRetriever(BaseRetriever, BaseModel):
|
||||
raise ValueError("NeedleClient is not initialized. Provide an API key.")
|
||||
|
||||
results = self.client.collections.search(
|
||||
collection_id=self.collection_id, text=query
|
||||
collection_id=self.collection_id, text=query, top_k=self.top_k
|
||||
)
|
||||
docs = [Document(page_content=result.content) for result in results]
|
||||
return docs
|
||||
|
@@ -17,11 +17,14 @@ class MockNeedleClient:
|
||||
self.collections = self.MockCollections()
|
||||
|
||||
class MockCollections:
|
||||
def search(self, collection_id: str, text: str) -> list[MockSearchResult]:
|
||||
return [
|
||||
def search(
|
||||
self, collection_id: str, text: str, top_k: int = 10
|
||||
) -> list[MockSearchResult]:
|
||||
results = [
|
||||
MockSearchResult(content=f"Result for query: {text}"),
|
||||
MockSearchResult(content=f"Another result for query: {text}"),
|
||||
]
|
||||
return results[:top_k]
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
|
||||
|
Loading…
Reference in New Issue
Block a user