mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 05:20:39 +00:00
Thank you for contributing to LangChain! - [X] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [X] **PR message**: This PR adds `top_k` as a parameter to the Needle retriever. By default, the top 10 results are returned. - [X] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [X] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
from typing import Any, Optional

import pytest
from pytest_mock import MockerFixture
|
|
|
|
|
|
# Mock class to simulate search results from Needle API
|
|
class MockSearchResult:
    """Fake single hit from a Needle search.

    Only the ``content`` attribute is modelled; it holds the text payload
    that the mocked ``search`` call hands back.
    """

    def __init__(self, content: str) -> None:
        self.content: str = content
|
|
|
|
|
|
# Mock class to simulate NeedleClient and its collections behavior
|
|
class MockNeedleClient:
    """Minimal in-memory stand-in for ``needle.v1.NeedleClient``.

    Only the surface used by the tests in this module is modelled: the
    ``api_key`` constructor argument and ``collections.search``.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.collections = self.MockCollections()

    class MockCollections:
        """Fake ``collections`` namespace that returns two canned results."""

        def search(
            self, collection_id: str, text: str, top_k: Optional[int] = 10
        ) -> list[MockSearchResult]:
            """Return canned results that echo the query text.

            Args:
                collection_id: Ignored; accepted only for signature
                    compatibility with the real client.
                text: Query text, interpolated into each result's content.
                top_k: Maximum number of results to return. ``None`` means
                    "no limit", so a caller that forwards an unset ``top_k``
                    still receives every canned result.

            Returns:
                The first ``top_k`` canned results (there are two in total).
            """
            results = [
                MockSearchResult(content=f"Result for query: {text}"),
                MockSearchResult(content=f"Another result for query: {text}"),
            ]
            # Slicing with ``None`` yields the whole list, which implements the
            # documented "no limit" behaviour without a special case.
            return results[:top_k]
|
|
|
|
|
|
@pytest.mark.requires("needle")
def test_needle_retriever_initialization() -> None:
    """Constructor arguments should be stored verbatim on the NeedleRetriever."""
    from langchain_community.retrievers.needle import NeedleRetriever  # noqa: I001

    api_key = "mock_api_key"
    collection_id = "mock_collection_id"

    retriever = NeedleRetriever(needle_api_key=api_key, collection_id=collection_id)

    assert (retriever.needle_api_key, retriever.collection_id) == (
        api_key,
        collection_id,
    )
|
|
|
|
|
|
@pytest.mark.requires("needle")
def test_get_relevant_documents(mocker: MockerFixture) -> None:
    """The retriever should turn each mocked Needle search hit into a Document."""
    from langchain_community.retrievers.needle import NeedleRetriever  # noqa: I001

    # Replace the SDK client at the import path the retriever uses, so that it
    # talks to the in-memory MockNeedleClient instead of the real SDK.
    mocker.patch("needle.v1.NeedleClient", new=MockNeedleClient)

    retriever = NeedleRetriever(
        needle_api_key="mock_api_key",
        collection_id="mock_collection_id",
    )

    # No callback manager is needed here; the Any annotation presumably keeps
    # type checkers from rejecting None for the run_manager parameter.
    no_run_manager: Any = None
    docs = retriever._get_relevant_documents(
        "What is RAG?", run_manager=no_run_manager
    )

    # One list comparison checks both the count and the order of the results.
    assert [doc.page_content for doc in docs] == [
        "Result for query: What is RAG?",
        "Another result for query: What is RAG?",
    ]
|