mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
community: add Needle retriever and document loader integration (#28157)
- [x] **PR title**: "community: add Needle retriever and document loader integration" - Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** This PR adds a new integration for Needle, which includes: - **NeedleRetriever**: A retriever for fetching documents from Needle collections. - **NeedleLoader**: A document loader for managing and loading documents into Needle collections. - Example notebooks demonstrating usage have been added in: - `docs/docs/integrations/retrievers/needle.ipynb` - `docs/docs/integrations/document_loaders/needle.ipynb`. - **Dependencies:** The `needle-python` package is required as an external dependency for accessing Needle's API. It has been added to the extended testing dependencies list. - **Twitter handle:** Feel free to mention me if this PR gets announced: [needlexai](https://x.com/NeedlexAI). - [x] **Add tests and docs**: If you're adding a new integration, please include 1. Unit tests have been added for both `NeedleRetriever` and `NeedleLoader` in `libs/community/tests/unit_tests`. These tests mock API calls to avoid relying on network access. 2. Example notebooks have been added to `docs/docs/integrations/`, showcasing both retriever and loader functionality. - [x] **Lint and test**: Run `make format`, `make lint`, and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ - `make format`: Passed - `make lint`: Passed - `make test`: Passed (requires `needle-python` to be installed locally; this package is not added to LangChain dependencies). Additional guidelines: - [x] Optional dependencies are imported only within functions. 
- [x] No dependencies have been added to pyproject.toml files except for those required for unit tests. - [x] The PR does not touch more than one package. - [x] Changes are fully backwards compatible. - [x] Community additions are not re-imported into LangChain core. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -105,6 +105,7 @@ EXPECTED_ALL = [
|
||||
"MergedDataLoader",
|
||||
"ModernTreasuryLoader",
|
||||
"MongodbLoader",
|
||||
"NeedleLoader",
|
||||
"NewsURLLoader",
|
||||
"NotebookLoader",
|
||||
"NotionDBLoader",
|
||||
|
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
def test_add_and_fetch_files(mocker: MockerFixture) -> None:
    """
    Test adding and fetching files using the NeedleLoader with a mocked client.

    ``needle.v1.NeedleClient`` is patched to return a mock whose
    ``collections.files.add`` and ``collections.files.list`` both report a
    single ``CollectionFile``. The test then checks that the first fetched
    document exposes string ``title`` and ``source`` metadata.
    """
    from langchain_community.document_loaders.needle import NeedleLoader  # noqa: I001
    from needle.v1.models import CollectionFile  # noqa: I001

    # Field values shared by every file the mocked API hands back.
    file_fields = {
        "id": "mock_id",
        "name": "tech-radar-30.pdf",
        "url": "https://example.com/",
        "status": "indexed",
        "type": "mock_type",
        "user_id": "mock_user_id",
        "connector_id": "mock_connector_id",
        "size": 1234,
        "md5_hash": "mock_md5_hash",
        "created_at": "2024-01-01T00:00:00Z",
        "updated_at": "2024-01-01T00:00:00Z",
    }

    # Each endpoint gets its own CollectionFile instance, as separate API
    # calls would produce.
    files_api = mocker.Mock()
    files_api.add.return_value = [CollectionFile(**file_fields)]
    files_api.list.return_value = [CollectionFile(**file_fields)]

    # Wire the mocks together: client.collections.files -> files_api
    collections_api = mocker.Mock()
    collections_api.files = files_api
    fake_client = mocker.Mock()
    fake_client.collections = collections_api

    # Any NeedleClient(...) construction now yields the fake client.
    mocker.patch("needle.v1.NeedleClient", return_value=fake_client)

    loader = NeedleLoader(
        needle_api_key="fake_api_key",
        collection_id="fake_collection_id",
    )

    # Mapping of file name -> URL to add to the collection.
    files = {
        "tech-radar-30.pdf": "https://www.thoughtworks.com/content/dam/thoughtworks/documents/radar/2024/04/tr_technology_radar_vol_30_en.pdf"
    }
    loader.add_files(files=files)

    fetched = loader._fetch_documents()
    first_metadata = fetched[0].metadata

    # The fetched document must expose string title and source metadata.
    assert isinstance(first_metadata["title"], str)
    assert isinstance(first_metadata["source"], str)
|
@@ -26,6 +26,7 @@ EXPECTED_ALL = [
|
||||
"MetalRetriever",
|
||||
"MilvusRetriever",
|
||||
"NanoPQRetriever",
|
||||
"NeedleRetriever",
|
||||
"OutlineRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
"PubMedRetriever",
|
||||
|
72
libs/community/tests/unit_tests/retrievers/test_needle.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
# Mock class to simulate search results from Needle API
|
||||
class MockSearchResult:
    """Minimal stand-in for a single search hit: it only carries ``content``."""

    def __init__(self, content: str) -> None:
        # Stored as given; read back through the ``content`` attribute.
        self.content = content
|
||||
|
||||
|
||||
# Mock class to simulate NeedleClient and its collections behavior
|
||||
class MockNeedleClient:
    """Stand-in client exposing ``api_key`` and a ``collections.search`` method."""

    class MockCollections:
        """Collections namespace whose ``search`` returns two canned results."""

        def search(self, collection_id: str, text: str) -> list[MockSearchResult]:
            # ``collection_id`` is accepted but unused; both results simply
            # echo the query text back.
            first_hit = MockSearchResult(content=f"Result for query: {text}")
            second_hit = MockSearchResult(content=f"Another result for query: {text}")
            return [first_hit, second_hit]

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.collections = self.MockCollections()
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
def test_needle_retriever_initialization() -> None:
    """
    Test that NeedleRetriever exposes the API key and collection ID it was
    constructed with.
    """
    from langchain_community.retrievers.needle import NeedleRetriever  # noqa: I001

    api_key = "mock_api_key"
    collection_id = "mock_collection_id"

    retriever = NeedleRetriever(needle_api_key=api_key, collection_id=collection_id)

    assert retriever.needle_api_key == api_key
    assert retriever.collection_id == collection_id
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
def test_get_relevant_documents(mocker: MockerFixture) -> None:
    """
    Test that the retriever turns the client's search results into documents
    whose ``page_content`` matches each result's content, in order.
    """
    from langchain_community.retrievers.needle import NeedleRetriever  # noqa: I001

    # Replace the NeedleClient class itself with the stand-in defined in this file.
    mocker.patch("needle.v1.NeedleClient", new=MockNeedleClient)

    retriever = NeedleRetriever(
        needle_api_key="mock_api_key",
        collection_id="mock_collection_id",
    )

    # No callback manager is supplied; typed as Any so None is accepted.
    no_run_manager: Any = None

    docs = retriever._get_relevant_documents(
        "What is RAG?", run_manager=no_run_manager
    )

    # Exactly two documents, in the order the mocked search produced them.
    page_contents = [doc.page_content for doc in docs]
    assert page_contents == [
        "Result for query: What is RAG?",
        "Another result for query: What is RAG?",
    ]
|
Reference in New Issue
Block a user