mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
community: add Needle retriever and document loader integration (#28157)
- [x] **PR title**: "community: add Needle retriever and document loader integration" - Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** This PR adds a new integration for Needle, which includes: - **NeedleRetriever**: A retriever for fetching documents from Needle collections. - **NeedleLoader**: A document loader for managing and loading documents into Needle collections. - Example notebooks demonstrating usage have been added in: - `docs/docs/integrations/retrievers/needle.ipynb` - `docs/docs/integrations/document_loaders/needle.ipynb`. - **Dependencies:** The `needle-python` package is required as an external dependency for accessing Needle's API. It has been added to the extended testing dependencies list. - **Twitter handle:** Feel free to mention me if this PR gets announced: [needlexai](https://x.com/NeedlexAI). - [x] **Add tests and docs**: If you're adding a new integration, please include 1. Unit tests have been added for both `NeedleRetriever` and `NeedleLoader` in `libs/community/tests/unit_tests`. These tests mock API calls to avoid relying on network access. 2. Example notebooks have been added to `docs/docs/integrations/`, showcasing both retriever and loader functionality. - [x] **Lint and test**: Run `make format`, `make lint`, and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ - `make format`: Passed - `make lint`: Passed - `make test`: Passed (requires `needle-python` to be installed locally; this package is not added to LangChain dependencies). Additional guidelines: - [x] Optional dependencies are imported only within functions. - [x] No dependencies have been added to pyproject.toml files except for those required for unit tests. - [x] The PR does not touch more than one package. - [x] Changes are fully backwards compatible. - [x] Community additions are not re-imported into LangChain core. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -46,6 +46,7 @@ motor>=3.3.1,<4
|
||||
msal>=1.25.0,<2
|
||||
mwparserfromhell>=0.6.4,<0.7
|
||||
mwxml>=0.3.3,<0.4
|
||||
needle-python>=0.4
|
||||
networkx>=3.2.1,<4
|
||||
newspaper3k>=0.2.8,<0.3
|
||||
numexpr>=2.8.6,<3
|
||||
|
@@ -299,6 +299,9 @@ if TYPE_CHECKING:
|
||||
from langchain_community.document_loaders.mongodb import (
|
||||
MongodbLoader,
|
||||
)
|
||||
from langchain_community.document_loaders.needle import (
|
||||
NeedleLoader,
|
||||
)
|
||||
from langchain_community.document_loaders.news import (
|
||||
NewsURLLoader,
|
||||
)
|
||||
@@ -631,6 +634,7 @@ _module_lookup = {
|
||||
"MergedDataLoader": "langchain_community.document_loaders.merge",
|
||||
"ModernTreasuryLoader": "langchain_community.document_loaders.modern_treasury",
|
||||
"MongodbLoader": "langchain_community.document_loaders.mongodb",
|
||||
"NeedleLoader": "langchain_community.document_loaders.needle",
|
||||
"NewsURLLoader": "langchain_community.document_loaders.news",
|
||||
"NotebookLoader": "langchain_community.document_loaders.notebook",
|
||||
"NotionDBLoader": "langchain_community.document_loaders.notiondb",
|
||||
@@ -837,6 +841,7 @@ __all__ = [
|
||||
"MergedDataLoader",
|
||||
"ModernTreasuryLoader",
|
||||
"MongodbLoader",
|
||||
"NeedleLoader",
|
||||
"NewsURLLoader",
|
||||
"NotebookLoader",
|
||||
"NotionDBLoader",
|
||||
|
164
libs/community/langchain_community/document_loaders/needle.py
Normal file
164
libs/community/langchain_community/document_loaders/needle.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from typing import Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
class NeedleLoader(BaseLoader):
|
||||
"""
|
||||
NeedleLoader is a document loader for managing documents stored in a collection.
|
||||
|
||||
Setup:
|
||||
Install the `needle-python` library and set your Needle API key.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install needle-python
|
||||
export NEEDLE_API_KEY="your-api-key"
|
||||
|
||||
Key init args:
|
||||
- `needle_api_key` (Optional[str]): API key for authenticating with Needle.
|
||||
- `collection_id` (str): Needle collection to load documents from.
|
||||
|
||||
Usage:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.document_loaders.needle import NeedleLoader
|
||||
|
||||
loader = NeedleLoader(
|
||||
needle_api_key="your-api-key",
|
||||
collection_id="your-collection-id"
|
||||
)
|
||||
|
||||
# Load documents
|
||||
documents = loader.load()
|
||||
for doc in documents:
|
||||
print(doc.metadata)
|
||||
|
||||
# Lazy load documents
|
||||
for doc in loader.lazy_load():
|
||||
print(doc.metadata)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
needle_api_key: Optional[str] = None,
|
||||
collection_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the NeedleLoader with API key and collection ID.
|
||||
|
||||
Args:
|
||||
needle_api_key (Optional[str]): API key for authenticating with Needle.
|
||||
collection_id (Optional[str]): Identifier for the Needle collection.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `needle-python` library is not installed.
|
||||
ValueError: If the collection ID is not provided.
|
||||
"""
|
||||
try:
|
||||
from needle.v1 import NeedleClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install with `pip install needle-python` to use NeedleLoader."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.needle_api_key = needle_api_key
|
||||
self.collection_id = collection_id
|
||||
self.client: Optional[NeedleClient] = None
|
||||
|
||||
if self.needle_api_key:
|
||||
self.client = NeedleClient(api_key=self.needle_api_key)
|
||||
|
||||
if not self.collection_id:
|
||||
raise ValueError("Collection ID must be provided.")
|
||||
|
||||
def _get_collection(self) -> None:
|
||||
"""
|
||||
Ensures the Needle collection is set and the client is initialized.
|
||||
|
||||
Raises:
|
||||
ValueError: If the Needle client is not initialized or
|
||||
if the collection ID is missing.
|
||||
"""
|
||||
if self.client is None:
|
||||
raise ValueError(
|
||||
"NeedleClient is not initialized. Provide a valid API key."
|
||||
)
|
||||
if not self.collection_id:
|
||||
raise ValueError("Collection ID must be provided.")
|
||||
|
||||
def add_files(self, files: Dict[str, str]) -> None:
|
||||
"""
|
||||
Adds files to the Needle collection.
|
||||
|
||||
Args:
|
||||
files (Dict[str, str]): Dictionary where keys are file names and values
|
||||
are file URLs.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `needle-python` library is not installed.
|
||||
ValueError: If the collection is not properly initialized.
|
||||
"""
|
||||
try:
|
||||
from needle.v1.models import FileToAdd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install with `pip install needle-python` to add files."
|
||||
)
|
||||
|
||||
self._get_collection()
|
||||
assert self.client is not None, "NeedleClient must be initialized."
|
||||
|
||||
files_to_add = [FileToAdd(name=name, url=url) for name, url in files.items()]
|
||||
|
||||
self.client.collections.files.add(
|
||||
collection_id=self.collection_id, files=files_to_add
|
||||
)
|
||||
|
||||
def _fetch_documents(self) -> List[Document]:
|
||||
"""
|
||||
Fetches metadata for documents from the Needle collection.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents with metadata. Content is excluded.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection is not properly initialized.
|
||||
"""
|
||||
self._get_collection()
|
||||
assert self.client is not None, "NeedleClient must be initialized."
|
||||
|
||||
files = self.client.collections.files.list(self.collection_id)
|
||||
docs = [
|
||||
Document(
|
||||
page_content="", # Needle doesn't provide file content fetching
|
||||
metadata={
|
||||
"source": file.url,
|
||||
"title": file.name,
|
||||
"size": getattr(file, "size", None),
|
||||
},
|
||||
)
|
||||
for file in files
|
||||
if file.status == "indexed"
|
||||
]
|
||||
return docs
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
Loads all documents from the Needle collection.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents from the collection.
|
||||
"""
|
||||
return self._fetch_documents()
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""
|
||||
Lazily loads documents from the Needle collection.
|
||||
|
||||
Yields:
|
||||
Iterator[Document]: An iterator over the documents.
|
||||
"""
|
||||
yield from self._fetch_documents()
|
@@ -93,6 +93,7 @@ if TYPE_CHECKING:
|
||||
MilvusRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.nanopq import NanoPQRetriever
|
||||
from langchain_community.retrievers.needle import NeedleRetriever
|
||||
from langchain_community.retrievers.outline import (
|
||||
OutlineRetriever,
|
||||
)
|
||||
@@ -173,6 +174,7 @@ _module_lookup = {
|
||||
"MetalRetriever": "langchain_community.retrievers.metal",
|
||||
"MilvusRetriever": "langchain_community.retrievers.milvus",
|
||||
"NanoPQRetriever": "langchain_community.retrievers.nanopq",
|
||||
"NeedleRetriever": "langchain_community.retrievers.needle",
|
||||
"OutlineRetriever": "langchain_community.retrievers.outline",
|
||||
"PineconeHybridSearchRetriever": "langchain_community.retrievers.pinecone_hybrid_search", # noqa: E501
|
||||
"PubMedRetriever": "langchain_community.retrievers.pubmed",
|
||||
@@ -229,6 +231,7 @@ __all__ = [
|
||||
"MetalRetriever",
|
||||
"MilvusRetriever",
|
||||
"NanoPQRetriever",
|
||||
"NeedleRetriever",
|
||||
"NeuralDBRetriever",
|
||||
"OutlineRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
|
96
libs/community/langchain_community/retrievers/needle.py
Normal file
96
libs/community/langchain_community/retrievers/needle.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Any, List, Optional # noqa: I001
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NeedleRetriever(BaseRetriever, BaseModel):
|
||||
"""
|
||||
NeedleRetriever retrieves relevant documents or context from a Needle collection
|
||||
based on a search query.
|
||||
|
||||
Setup:
|
||||
Install the `needle-python` library and set your Needle API key.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install needle-python
|
||||
export NEEDLE_API_KEY="your-api-key"
|
||||
|
||||
Key init args:
|
||||
- `needle_api_key` (Optional[str]): The API key for authenticating with Needle.
|
||||
- `collection_id` (str): The ID of the Needle collection to search in.
|
||||
- `client` (Optional[NeedleClient]): An optional instance of the NeedleClient.
|
||||
|
||||
Usage:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.retrievers.needle import NeedleRetriever
|
||||
|
||||
retriever = NeedleRetriever(
|
||||
needle_api_key="your-api-key",
|
||||
collection_id="your-collection-id"
|
||||
)
|
||||
|
||||
results = retriever.retrieve("example query")
|
||||
for doc in results:
|
||||
print(doc.page_content)
|
||||
"""
|
||||
|
||||
client: Optional[Any] = None
|
||||
"""Optional instance of NeedleClient."""
|
||||
needle_api_key: Optional[str] = Field(None, description="Needle API Key")
|
||||
collection_id: Optional[str] = Field(
|
||||
..., description="The ID of the Needle collection to search in"
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
"""
|
||||
Initialize the NeedleClient with the provided API key.
|
||||
|
||||
If a client instance is already provided, this method does nothing.
|
||||
"""
|
||||
try:
|
||||
from needle.v1 import NeedleClient
|
||||
except ImportError:
|
||||
raise ImportError("Please install with `pip install needle-python`.")
|
||||
|
||||
if not self.client:
|
||||
self.client = NeedleClient(api_key=self.needle_api_key)
|
||||
|
||||
def _search_collection(self, query: str) -> List[Document]:
|
||||
"""
|
||||
Search the Needle collection for relevant documents.
|
||||
|
||||
Args:
|
||||
query (str): The search query used to find relevant documents.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents matching the search query.
|
||||
"""
|
||||
self._initialize_client()
|
||||
if self.client is None:
|
||||
raise ValueError("NeedleClient is not initialized. Provide an API key.")
|
||||
|
||||
results = self.client.collections.search(
|
||||
collection_id=self.collection_id, text=query
|
||||
)
|
||||
docs = [Document(page_content=result.content) for result in results]
|
||||
return docs
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Retrieve relevant documents based on the query.
|
||||
|
||||
Args:
|
||||
query (str): The query string used to search the collection.
|
||||
Returns:
|
||||
List[Document]: A list of documents relevant to the query.
|
||||
"""
|
||||
# The `run_manager` parameter is included to match the superclass signature,
|
||||
# but it is not used in this implementation.
|
||||
return self._search_collection(query)
|
@@ -105,6 +105,7 @@ EXPECTED_ALL = [
|
||||
"MergedDataLoader",
|
||||
"ModernTreasuryLoader",
|
||||
"MongodbLoader",
|
||||
"NeedleLoader",
|
||||
"NewsURLLoader",
|
||||
"NotebookLoader",
|
||||
"NotionDBLoader",
|
||||
|
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
|
||||
def test_add_and_fetch_files(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test adding and fetching files using the NeedleLoader with a mock.
|
||||
"""
|
||||
from langchain_community.document_loaders.needle import NeedleLoader # noqa: I001
|
||||
from needle.v1.models import CollectionFile # noqa: I001
|
||||
|
||||
# Create mock instances using mocker
|
||||
# Create mock instances using mocker
|
||||
mock_files = mocker.Mock()
|
||||
mock_files.add.return_value = [
|
||||
CollectionFile(
|
||||
id="mock_id",
|
||||
name="tech-radar-30.pdf",
|
||||
url="https://example.com/",
|
||||
status="indexed",
|
||||
type="mock_type",
|
||||
user_id="mock_user_id",
|
||||
connector_id="mock_connector_id",
|
||||
size=1234,
|
||||
md5_hash="mock_md5_hash",
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
updated_at="2024-01-01T00:00:00Z",
|
||||
)
|
||||
]
|
||||
mock_files.list.return_value = [
|
||||
CollectionFile(
|
||||
id="mock_id",
|
||||
name="tech-radar-30.pdf",
|
||||
url="https://example.com/",
|
||||
status="indexed",
|
||||
type="mock_type",
|
||||
user_id="mock_user_id",
|
||||
connector_id="mock_connector_id",
|
||||
size=1234,
|
||||
md5_hash="mock_md5_hash",
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
updated_at="2024-01-01T00:00:00Z",
|
||||
)
|
||||
]
|
||||
|
||||
mock_collections = mocker.Mock()
|
||||
mock_collections.files = mock_files
|
||||
|
||||
mock_needle_client = mocker.Mock()
|
||||
mock_needle_client.collections = mock_collections
|
||||
|
||||
# Patch the NeedleClient to return the mock client
|
||||
mocker.patch("needle.v1.NeedleClient", return_value=mock_needle_client)
|
||||
|
||||
# Initialize NeedleLoader with mock API key and collection ID
|
||||
document_store = NeedleLoader(
|
||||
needle_api_key="fake_api_key",
|
||||
collection_id="fake_collection_id",
|
||||
)
|
||||
|
||||
# Define files to add
|
||||
files = {
|
||||
"tech-radar-30.pdf": "https://www.thoughtworks.com/content/dam/thoughtworks/documents/radar/2024/04/tr_technology_radar_vol_30_en.pdf"
|
||||
}
|
||||
|
||||
# Add files to the collection using the mock client
|
||||
document_store.add_files(files=files)
|
||||
|
||||
# Fetch the added files using the mock client
|
||||
added_files = document_store._fetch_documents()
|
||||
|
||||
# Assertions to verify that the file was added and fetched correctly
|
||||
assert isinstance(added_files[0].metadata["title"], str)
|
||||
assert isinstance(added_files[0].metadata["source"], str)
|
@@ -26,6 +26,7 @@ EXPECTED_ALL = [
|
||||
"MetalRetriever",
|
||||
"MilvusRetriever",
|
||||
"NanoPQRetriever",
|
||||
"NeedleRetriever",
|
||||
"OutlineRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
"PubMedRetriever",
|
||||
|
72
libs/community/tests/unit_tests/retrievers/test_needle.py
Normal file
72
libs/community/tests/unit_tests/retrievers/test_needle.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
# Mock class to simulate search results from Needle API
|
||||
class MockSearchResult:
|
||||
def __init__(self, content: str) -> None:
|
||||
self.content = content
|
||||
|
||||
|
||||
# Mock class to simulate NeedleClient and its collections behavior
|
||||
class MockNeedleClient:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.collections = self.MockCollections()
|
||||
|
||||
class MockCollections:
|
||||
def search(self, collection_id: str, text: str) -> list[MockSearchResult]:
|
||||
return [
|
||||
MockSearchResult(content=f"Result for query: {text}"),
|
||||
MockSearchResult(content=f"Another result for query: {text}"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
|
||||
def test_needle_retriever_initialization() -> None:
|
||||
"""
|
||||
Test that the NeedleRetriever is initialized correctly.
|
||||
"""
|
||||
from langchain_community.retrievers.needle import NeedleRetriever # noqa: I001
|
||||
|
||||
retriever = NeedleRetriever(
|
||||
needle_api_key="mock_api_key",
|
||||
collection_id="mock_collection_id",
|
||||
)
|
||||
|
||||
assert retriever.needle_api_key == "mock_api_key"
|
||||
assert retriever.collection_id == "mock_collection_id"
|
||||
|
||||
|
||||
@pytest.mark.requires("needle")
|
||||
def test_get_relevant_documents(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that the retriever correctly fetches documents.
|
||||
"""
|
||||
from langchain_community.retrievers.needle import NeedleRetriever # noqa: I001
|
||||
|
||||
# Patch the actual NeedleClient import path used in the NeedleRetriever
|
||||
mocker.patch("needle.v1.NeedleClient", new=MockNeedleClient)
|
||||
|
||||
# Initialize the retriever with mocked API key and collection ID
|
||||
retriever = NeedleRetriever(
|
||||
needle_api_key="mock_api_key",
|
||||
collection_id="mock_collection_id",
|
||||
)
|
||||
|
||||
mock_run_manager: Any = None
|
||||
|
||||
# Perform the search
|
||||
query = "What is RAG?"
|
||||
retrieved_documents = retriever._get_relevant_documents(
|
||||
query, run_manager=mock_run_manager
|
||||
)
|
||||
|
||||
# Validate the results
|
||||
assert len(retrieved_documents) == 2
|
||||
assert retrieved_documents[0].page_content == "Result for query: What is RAG?"
|
||||
assert (
|
||||
retrieved_documents[1].page_content == "Another result for query: What is RAG?"
|
||||
)
|
Reference in New Issue
Block a user