community: add Needle retriever and document loader integration (#28157)

- [x] **PR title**: "community: add Needle retriever and document loader
integration"
- Where "package" is whichever of langchain, community, core, etc. is
being modified. Use "docs: ..." for purely docs changes, "infra: ..."
for CI changes.
  - Example: "community: add foobar LLM"

- [x] **PR message**: ***Delete this entire checklist*** and replace
with
- **Description:** This PR adds a new integration for Needle, which
includes:
- **NeedleRetriever**: A retriever for fetching documents from Needle
collections.
- **NeedleLoader**: A document loader for managing and loading documents
into Needle collections.
      - Example notebooks demonstrating usage have been added in:
        - `docs/docs/integrations/retrievers/needle.ipynb`
        - `docs/docs/integrations/document_loaders/needle.ipynb`.
- **Dependencies:** The `needle-python` package is required as an
external dependency for accessing Needle's API. It has been added to the
extended testing dependencies list.
- **Twitter handle:** Feel free to mention me if this PR gets announced:
[needlexai](https://x.com/NeedlexAI).

- [x] **Add tests and docs**: This PR adds a new integration, so it
includes:
1. Unit tests have been added for both `NeedleRetriever` and
`NeedleLoader` in `libs/community/tests/unit_tests`. These tests mock
API calls to avoid relying on network access.
2. Example notebooks have been added to `docs/docs/integrations/`,
showcasing both retriever and loader functionality.

- [x] **Lint and test**: Run `make format`, `make lint`, and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/
  - `make format`: Passed
  - `make lint`: Passed
- `make test`: Passed (requires `needle-python` to be installed locally;
this package is added only to the extended testing dependencies, not to
LangChain's core dependencies).

Additional guidelines:
- [x] Optional dependencies are imported only within functions.
- [x] No dependencies have been added to pyproject.toml files except for
those required for unit tests.
- [x] The PR does not touch more than one package.
- [x] Changes are fully backwards compatible.
- [x] Community additions are not re-imported into LangChain core.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Jan Heimes
2024-12-03 23:06:25 +01:00
committed by GitHub
parent b0a83071df
commit ef365543cb
11 changed files with 906 additions and 0 deletions

View File

@@ -46,6 +46,7 @@ motor>=3.3.1,<4
msal>=1.25.0,<2
mwparserfromhell>=0.6.4,<0.7
mwxml>=0.3.3,<0.4
needle-python>=0.4
networkx>=3.2.1,<4
newspaper3k>=0.2.8,<0.3
numexpr>=2.8.6,<3

View File

@@ -299,6 +299,9 @@ if TYPE_CHECKING:
from langchain_community.document_loaders.mongodb import (
MongodbLoader,
)
from langchain_community.document_loaders.needle import (
NeedleLoader,
)
from langchain_community.document_loaders.news import (
NewsURLLoader,
)
@@ -631,6 +634,7 @@ _module_lookup = {
"MergedDataLoader": "langchain_community.document_loaders.merge",
"ModernTreasuryLoader": "langchain_community.document_loaders.modern_treasury",
"MongodbLoader": "langchain_community.document_loaders.mongodb",
"NeedleLoader": "langchain_community.document_loaders.needle",
"NewsURLLoader": "langchain_community.document_loaders.news",
"NotebookLoader": "langchain_community.document_loaders.notebook",
"NotionDBLoader": "langchain_community.document_loaders.notiondb",
@@ -837,6 +841,7 @@ __all__ = [
"MergedDataLoader",
"ModernTreasuryLoader",
"MongodbLoader",
"NeedleLoader",
"NewsURLLoader",
"NotebookLoader",
"NotionDBLoader",

View File

@@ -0,0 +1,164 @@
from typing import Dict, Iterator, List, Optional
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
class NeedleLoader(BaseLoader):
    """
    NeedleLoader is a document loader for managing documents stored in a collection.

    The loader lists the files of a Needle collection and exposes them as
    ``Document`` objects that carry metadata only (``source``, ``title``,
    ``size``); ``page_content`` is always an empty string.

    Setup:
        Install the `needle-python` library and set your Needle API key.

        .. code-block:: bash

            pip install needle-python
            export NEEDLE_API_KEY="your-api-key"

    Key init args:
        - `needle_api_key` (Optional[str]): API key for authenticating with Needle.
          When omitted, the `NEEDLE_API_KEY` environment variable is used.
        - `collection_id` (str): Needle collection to load documents from.

    Usage:
        .. code-block:: python

            from langchain_community.document_loaders.needle import NeedleLoader

            loader = NeedleLoader(
                needle_api_key="your-api-key",
                collection_id="your-collection-id"
            )

            # Load documents
            documents = loader.load()
            for doc in documents:
                print(doc.metadata)

            # Lazy load documents
            for doc in loader.lazy_load():
                print(doc.metadata)
    """

    def __init__(
        self,
        needle_api_key: Optional[str] = None,
        collection_id: Optional[str] = None,
    ) -> None:
        """
        Initializes the NeedleLoader with API key and collection ID.

        Args:
            needle_api_key (Optional[str]): API key for authenticating with Needle.
                Falls back to the `NEEDLE_API_KEY` environment variable.
            collection_id (Optional[str]): Identifier for the Needle collection.

        Raises:
            ImportError: If the `needle-python` library is not installed.
            ValueError: If the collection ID is not provided.
        """
        import os

        try:
            from needle.v1 import NeedleClient
        except ImportError as e:
            raise ImportError(
                "Please install with `pip install needle-python` to use NeedleLoader."
            ) from e

        super().__init__()

        # The class docs instruct users to export NEEDLE_API_KEY, so honour the
        # environment variable when no key is passed explicitly.
        self.needle_api_key = needle_api_key or os.environ.get("NEEDLE_API_KEY")
        self.collection_id = collection_id
        self.client: Optional[NeedleClient] = None

        # Without a key the client stays None; _get_collection() reports that
        # the first time the loader is actually used.
        if self.needle_api_key:
            self.client = NeedleClient(api_key=self.needle_api_key)

        if not self.collection_id:
            raise ValueError("Collection ID must be provided.")

    def _get_collection(self) -> None:
        """
        Ensures the Needle collection is set and the client is initialized.

        Raises:
            ValueError: If the Needle client is not initialized or
                if the collection ID is missing.
        """
        if self.client is None:
            raise ValueError(
                "NeedleClient is not initialized. Provide a valid API key."
            )
        if not self.collection_id:
            raise ValueError("Collection ID must be provided.")

    def add_files(self, files: Dict[str, str]) -> None:
        """
        Adds files to the Needle collection.

        Args:
            files (Dict[str, str]): Dictionary where keys are file names and values
                are file URLs.

        Raises:
            ImportError: If the `needle-python` library is not installed.
            ValueError: If the collection is not properly initialized.
        """
        try:
            from needle.v1.models import FileToAdd
        except ImportError as e:
            raise ImportError(
                "Please install with `pip install needle-python` to add files."
            ) from e

        self._get_collection()
        # _get_collection() has already raised if the client is missing; the
        # assert only narrows the Optional type for static checkers.
        assert self.client is not None, "NeedleClient must be initialized."

        files_to_add = [FileToAdd(name=name, url=url) for name, url in files.items()]

        self.client.collections.files.add(
            collection_id=self.collection_id, files=files_to_add
        )

    def _fetch_documents(self) -> List[Document]:
        """
        Fetches metadata for documents from the Needle collection.

        Only files whose status is ``"indexed"`` are returned.

        Returns:
            List[Document]: A list of documents with metadata. Content is excluded.

        Raises:
            ValueError: If the collection is not properly initialized.
        """
        self._get_collection()
        # Type narrowing only; see _get_collection() for the real validation.
        assert self.client is not None, "NeedleClient must be initialized."

        files = self.client.collections.files.list(self.collection_id)
        docs = [
            Document(
                page_content="",  # Needle doesn't provide file content fetching
                metadata={
                    "source": file.url,
                    "title": file.name,
                    # `size` is read defensively in case the attribute is absent.
                    "size": getattr(file, "size", None),
                },
            )
            for file in files
            if file.status == "indexed"
        ]
        return docs

    def load(self) -> List[Document]:
        """
        Loads all documents from the Needle collection.

        Returns:
            List[Document]: A list of documents from the collection.
        """
        return self._fetch_documents()

    def lazy_load(self) -> Iterator[Document]:
        """
        Lazily loads documents from the Needle collection.

        The file listing is fetched in one call on first iteration and the
        resulting documents are then yielded one at a time.

        Yields:
            Document: The next document from the collection.
        """
        yield from self._fetch_documents()

View File

@@ -93,6 +93,7 @@ if TYPE_CHECKING:
MilvusRetriever,
)
from langchain_community.retrievers.nanopq import NanoPQRetriever
from langchain_community.retrievers.needle import NeedleRetriever
from langchain_community.retrievers.outline import (
OutlineRetriever,
)
@@ -173,6 +174,7 @@ _module_lookup = {
"MetalRetriever": "langchain_community.retrievers.metal",
"MilvusRetriever": "langchain_community.retrievers.milvus",
"NanoPQRetriever": "langchain_community.retrievers.nanopq",
"NeedleRetriever": "langchain_community.retrievers.needle",
"OutlineRetriever": "langchain_community.retrievers.outline",
"PineconeHybridSearchRetriever": "langchain_community.retrievers.pinecone_hybrid_search", # noqa: E501
"PubMedRetriever": "langchain_community.retrievers.pubmed",
@@ -229,6 +231,7 @@ __all__ = [
"MetalRetriever",
"MilvusRetriever",
"NanoPQRetriever",
"NeedleRetriever",
"NeuralDBRetriever",
"OutlineRetriever",
"PineconeHybridSearchRetriever",

View File

@@ -0,0 +1,96 @@
from typing import Any, List, Optional # noqa: I001
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, Field
class NeedleRetriever(BaseRetriever, BaseModel):
    """
    NeedleRetriever retrieves relevant documents or context from a Needle collection
    based on a search query.

    Setup:
        Install the `needle-python` library and set your Needle API key.

        .. code-block:: bash

            pip install needle-python
            export NEEDLE_API_KEY="your-api-key"

    Key init args:
        - `needle_api_key` (Optional[str]): The API key for authenticating with Needle.
          When omitted, the `NEEDLE_API_KEY` environment variable is used.
        - `collection_id` (str): The ID of the Needle collection to search in.
        - `client` (Optional[NeedleClient]): An optional instance of the NeedleClient.

    Usage:
        .. code-block:: python

            from langchain_community.retrievers.needle import NeedleRetriever

            retriever = NeedleRetriever(
                needle_api_key="your-api-key",
                collection_id="your-collection-id"
            )

            results = retriever.invoke("example query")
            for doc in results:
                print(doc.page_content)
    """

    client: Optional[Any] = None
    """Optional instance of NeedleClient."""
    needle_api_key: Optional[str] = Field(None, description="Needle API Key")
    collection_id: Optional[str] = Field(
        ..., description="The ID of the Needle collection to search in"
    )

    def _initialize_client(self) -> None:
        """
        Initialize the NeedleClient with the provided API key.

        If a client instance is already provided, this method does nothing.

        Raises:
            ImportError: If the `needle-python` library is not installed.
        """
        import os

        try:
            from needle.v1 import NeedleClient
        except ImportError as e:
            raise ImportError(
                "Please install with `pip install needle-python`."
            ) from e

        if not self.client:
            # Resolve the key here so the documented `export NEEDLE_API_KEY=...`
            # setup works even when no key was passed to the constructor.
            api_key = self.needle_api_key or os.environ.get("NEEDLE_API_KEY")
            self.client = NeedleClient(api_key=api_key)

    def _search_collection(self, query: str) -> List[Document]:
        """
        Search the Needle collection for relevant documents.

        Args:
            query (str): The search query used to find relevant documents.

        Returns:
            List[Document]: A list of documents matching the search query. Each
                document carries the result's ``content`` as ``page_content``.

        Raises:
            ValueError: If no client is available after initialization.
        """
        self._initialize_client()
        # Defensive guard kept from the original implementation.
        if self.client is None:
            raise ValueError("NeedleClient is not initialized. Provide an API key.")

        results = self.client.collections.search(
            collection_id=self.collection_id, text=query
        )
        docs = [Document(page_content=result.content) for result in results]
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Retrieve relevant documents based on the query.

        Args:
            query (str): The query string used to search the collection.
            run_manager (CallbackManagerForRetrieverRun): Accepted to match the
                superclass signature; not used by this implementation.

        Returns:
            List[Document]: A list of documents relevant to the query.
        """
        return self._search_collection(query)

View File

@@ -105,6 +105,7 @@ EXPECTED_ALL = [
"MergedDataLoader",
"ModernTreasuryLoader",
"MongodbLoader",
"NeedleLoader",
"NewsURLLoader",
"NotebookLoader",
"NotionDBLoader",

View File

@@ -0,0 +1,75 @@
import pytest
from pytest_mock import MockerFixture
@pytest.mark.requires("needle")
def test_add_and_fetch_files(mocker: MockerFixture) -> None:
    """
    Add a file through NeedleLoader and read it back, using a mocked NeedleClient.
    """
    from langchain_community.document_loaders.needle import NeedleLoader  # noqa: I001
    from needle.v1.models import CollectionFile  # noqa: I001

    def make_indexed_file() -> CollectionFile:
        # Both the `add` and the `list` stub answer with the same file description.
        return CollectionFile(
            id="mock_id",
            name="tech-radar-30.pdf",
            url="https://example.com/",
            status="indexed",
            type="mock_type",
            user_id="mock_user_id",
            connector_id="mock_connector_id",
            size=1234,
            md5_hash="mock_md5_hash",
            created_at="2024-01-01T00:00:00Z",
            updated_at="2024-01-01T00:00:00Z",
        )

    # Stub the `client.collections.files` API surface used by the loader.
    files_api = mocker.Mock()
    files_api.add.return_value = [make_indexed_file()]
    files_api.list.return_value = [make_indexed_file()]

    collections_api = mocker.Mock()
    collections_api.files = files_api

    fake_client = mocker.Mock()
    fake_client.collections = collections_api

    # The loader imports NeedleClient at call time, so patching the module
    # attribute makes the constructor hand back the fake client.
    mocker.patch("needle.v1.NeedleClient", return_value=fake_client)

    document_store = NeedleLoader(
        needle_api_key="fake_api_key",
        collection_id="fake_collection_id",
    )

    files = {
        "tech-radar-30.pdf": "https://www.thoughtworks.com/content/dam/thoughtworks/documents/radar/2024/04/tr_technology_radar_vol_30_en.pdf"
    }

    # Push the files, then list them back through the mocked client.
    document_store.add_files(files=files)
    added_files = document_store._fetch_documents()

    first_metadata = added_files[0].metadata
    assert isinstance(first_metadata["title"], str)
    assert isinstance(first_metadata["source"], str)

View File

@@ -26,6 +26,7 @@ EXPECTED_ALL = [
"MetalRetriever",
"MilvusRetriever",
"NanoPQRetriever",
"NeedleRetriever",
"OutlineRetriever",
"PineconeHybridSearchRetriever",
"PubMedRetriever",

View File

@@ -0,0 +1,72 @@
from typing import Any
import pytest
from pytest_mock import MockerFixture
class MockSearchResult:
    """Stand-in for a single search result returned by the Needle API."""

    def __init__(self, content: str) -> None:
        # The retriever only reads `.content` from each result.
        self.content = content
class MockNeedleClient:
    """Stand-in for NeedleClient exposing only `collections.search`."""

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.collections = self.MockCollections()

    class MockCollections:
        """Fake `collections` namespace that returns two canned results."""

        def search(self, collection_id: str, text: str) -> list[MockSearchResult]:
            first = MockSearchResult(content=f"Result for query: {text}")
            second = MockSearchResult(content=f"Another result for query: {text}")
            return [first, second]
@pytest.mark.requires("needle")
def test_needle_retriever_initialization() -> None:
    """
    Constructor arguments must be stored unchanged on the retriever.
    """
    from langchain_community.retrievers.needle import NeedleRetriever  # noqa: I001

    api_key = "mock_api_key"
    collection = "mock_collection_id"

    retriever = NeedleRetriever(needle_api_key=api_key, collection_id=collection)

    assert retriever.needle_api_key == "mock_api_key"
    assert retriever.collection_id == "mock_collection_id"
@pytest.mark.requires("needle")
def test_get_relevant_documents(mocker: MockerFixture) -> None:
    """
    The retriever should turn each search result into a Document, in order.
    """
    from langchain_community.retrievers.needle import NeedleRetriever  # noqa: I001

    # Swap the class the retriever imports at call time for the local fake.
    mocker.patch("needle.v1.NeedleClient", new=MockNeedleClient)

    retriever = NeedleRetriever(
        needle_api_key="mock_api_key",
        collection_id="mock_collection_id",
    )

    # The implementation ignores the run manager, so None is sufficient here.
    mock_run_manager: Any = None
    query = "What is RAG?"

    retrieved_documents = retriever._get_relevant_documents(
        query, run_manager=mock_run_manager
    )

    # One list comparison checks both the count and each document's content.
    page_contents = [doc.page_content for doc in retrieved_documents]
    assert page_contents == [
        "Result for query: What is RAG?",
        "Another result for query: What is RAG?",
    ]