This commit is contained in:
Erick Friis 2024-12-04 17:34:41 -08:00
parent 0ab8e5cfe0
commit df4e0e6d81
4 changed files with 438 additions and 22 deletions


@@ -0,0 +1,206 @@
---
pagination_prev: contributing/how_to/integrations/index
pagination_next: contributing/how_to/integrations/publish
---
# How to implement and test a retriever integration
In this guide, we'll implement and test a custom [retriever](/docs/concepts/retrievers) that you have integrated with LangChain.
For testing, we will rely on the `langchain-tests` dependency we added in the previous [package creation guide](/docs/contributing/how_to/integrations/package).
## Implementation
Let's say you're building a simple integration package that provides a `ToyRetriever`
retriever integration for LangChain. Here's a simple example of what your project
structure might look like:
```plaintext
langchain-parrot-link/
├── langchain_parrot_link/
│   ├── __init__.py
│   └── retrievers.py
├── tests/
│   └── integration_tests/
│       ├── __init__.py
│       └── test_retrievers.py
├── pyproject.toml
└── README.md
```
In this first step, we will implement the `retrievers.py` file.
import CustomRetrieverIntro from '/docs/how_to/_custom_retriever_intro.mdx';
<CustomRetrieverIntro />
<details>
<summary>retrievers.py</summary>
```python title="langchain_parrot_link/retrievers.py"
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class ParrotRetriever(BaseRetriever):
    parrot_name: str
    k: int = 3

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> list[Document]:
        k = kwargs.get("k", self.k)
        return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
```
</details>
:::tip
The `ParrotRetriever` in this guide is tested against the standard unit and integration tests in the LangChain GitHub repository, so you can always use
[that implementation](https://github.com/langchain-ai/langchain/blob/master/libs/standard-tests/tests/unit_tests/test_basic_retriever.py) as a starting point.
:::
## Testing
### 1. Create Your Retriever Class
```python
from typing import List

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class MyCustomRetriever(BaseRetriever):
    """Custom retriever implementation."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Core implementation of retrieving relevant documents."""
        # Your implementation here
        pass

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async implementation of retrieving relevant documents."""
        # Your async implementation here
        pass
```
### 2. Required Testing
All retrievers must include the following tests:
#### Basic Functionality Tests
```python
import pytest

from langchain_core.documents import Document


def test_get_relevant_documents():
    retriever = MyCustomRetriever()
    docs = retriever.get_relevant_documents("test query")
    assert isinstance(docs, list)
    assert all(isinstance(doc, Document) for doc in docs)


@pytest.mark.asyncio
async def test_aget_relevant_documents():
    retriever = MyCustomRetriever()
    docs = await retriever.aget_relevant_documents("test query")
    assert isinstance(docs, list)
    assert all(isinstance(doc, Document) for doc in docs)
```
#### Edge Cases
- Empty query handling
- Special character handling
- Long query handling
- Rate limiting (if applicable)
- Error handling
### 3. Documentation Requirements
Your retriever should include:
1. Class docstring with:
- General description
- Required dependencies
- Example usage
- Parameters explanation
2. Integration documentation file:
- Installation instructions
- Basic usage example
- Advanced configuration
- Common issues and solutions
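As a sketch of the class-docstring shape (the package, class, and parameter names here are purely illustrative):

```python
class ParrotLinkRetriever:
    """Retriever over a hypothetical parrot-link backend.

    Requires the (hypothetical) ``parrot-sdk`` package:
    ``pip install parrot-sdk``.

    Args:
        parrot_name: Name echoed back in every result.
        k: Number of documents to return. Defaults to 3.

    Example:
        .. code-block:: python

            retriever = ParrotLinkRetriever(parrot_name="Polly")
            docs = retriever.invoke("hello")
    """
```

A real class would subclass `BaseRetriever`; the point here is only the docstring layout, which covers the description, dependencies, parameters, and example usage listed above.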
### 4. Best Practices
1. **Error Handling**
- Implement proper error handling for API calls
- Provide meaningful error messages
- Handle rate limits gracefully
2. **Performance**
- Implement caching when appropriate
- Use batch operations where possible
- Consider implementing both sync and async methods
3. **Configuration**
- Use environment variables for sensitive data
- Provide sensible defaults
- Allow for customization of key parameters
4. **Type Hints**
- Use proper type hints throughout your code
- Document expected types in docstrings
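For the configuration point, one common pattern is to pull secrets from environment variables while exposing tunable parameters with sensible defaults; a sketch with hypothetical names:

```python
import os


class ParrotLinkConfig:
    """Configuration holder: secrets come from the environment, knobs have defaults."""

    def __init__(self, k: int = 3, timeout: float = 10.0):
        # Sensitive data is read from environment variables, never hard-coded
        self.api_key = os.environ.get("PARROT_API_KEY")
        # Sensible, overridable defaults for everything else
        self.k = k
        self.timeout = timeout


config = ParrotLinkConfig(k=5)  # override one knob, keep the other defaults
```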
## Example Implementation
Here's a minimal example of a custom retriever:
```python
from typing import List

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class SimpleKeywordRetriever(BaseRetriever):
    """A simple retriever that matches documents based on keywords."""

    documents: List[Document]  # Store your documents here

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return documents that contain the query string."""
        return [
            doc for doc in self.documents
            if query.lower() in doc.page_content.lower()
        ]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async version of `_get_relevant_documents`."""
        return [
            doc for doc in self.documents
            if query.lower() in doc.page_content.lower()
        ]
```
## Submission Checklist
- [ ] Implemented base retriever interface
- [ ] Added comprehensive tests
- [ ] Included proper documentation
- [ ] Added type hints
- [ ] Handled error cases
- [ ] Implemented both sync and async methods
- [ ] Added example usage
- [ ] Followed code style guidelines
- [ ] Updated dependency metadata (`pyproject.toml`)
## Getting Help
If you need help while implementing your retriever:
1. Check existing retriever implementations for reference
2. Open a discussion in the GitHub repository
3. Ask in the LangChain Discord community
Remember to follow the existing patterns in the codebase and maintain consistency with other retrievers.


@@ -0,0 +1,207 @@
# Standard Tests for LangChain Retrievers
This guide outlines the standard tests that should be implemented for all LangChain retrievers.
## Test Structure
### 1. Basic Functionality Tests
```python
import pytest

from langchain_core.documents import Document

from your_retriever import YourRetriever


def test_basic_retrieval():
    """Test basic document retrieval functionality."""
    retriever = YourRetriever()
    query = "test query"
    docs = retriever.get_relevant_documents(query)
    assert isinstance(docs, list)
    assert all(isinstance(doc, Document) for doc in docs)
    assert len(docs) > 0  # Adjust if your retriever might return empty results


@pytest.mark.asyncio
async def test_async_retrieval():
    """Test async document retrieval functionality."""
    retriever = YourRetriever()
    query = "test query"
    docs = await retriever.aget_relevant_documents(query)
    assert isinstance(docs, list)
    assert all(isinstance(doc, Document) for doc in docs)
```
### 2. Edge Cases
```python
def test_empty_query():
    """Test behavior with empty query."""
    retriever = YourRetriever()
    docs = retriever.get_relevant_documents("")
    assert isinstance(docs, list)


def test_special_characters():
    """Test handling of special characters."""
    retriever = YourRetriever()
    special_queries = [
        "test!@#$%^&*()",
        "múltiple áccents",
        "中文测试",
        "test\nwith\nnewlines",
    ]
    for query in special_queries:
        docs = retriever.get_relevant_documents(query)
        assert isinstance(docs, list)


def test_long_query():
    """Test handling of very long queries."""
    retriever = YourRetriever()
    long_query = "test " * 1000
    docs = retriever.get_relevant_documents(long_query)
    assert isinstance(docs, list)
```
### 3. Error Handling
```python
def test_invalid_configuration():
    """Test behavior with invalid configuration."""
    with pytest.raises(ValueError):
        YourRetriever(invalid_param="invalid")


def test_connection_error():
    """Test behavior when connection fails (if applicable)."""
    retriever = YourRetriever()
    # Mock a connection failure in your client here
    with pytest.raises(ConnectionError):
        retriever.get_relevant_documents("test")
```
### 4. Performance Tests (Optional)
```python
@pytest.mark.slow
def test_large_scale_retrieval():
    """Test retrieval with a large number of documents."""
    retriever = YourRetriever()
    # Test with a significant number of documents
    docs = retriever.get_relevant_documents("test")
    assert len(docs) <= YOUR_MAX_LIMIT  # If applicable


@pytest.mark.slow
def test_concurrent_requests():
    """Test handling of concurrent requests."""
    import asyncio

    async def run_concurrent_requests():
        retriever = YourRetriever()
        tasks = [
            retriever.aget_relevant_documents("test")
            for _ in range(5)
        ]
        results = await asyncio.gather(*tasks)
        return results

    results = asyncio.run(run_concurrent_requests())
    assert len(results) == 5
```
### 5. Integration Tests
```python
def test_chain_integration():
    """Test integration with LangChain chains."""
    from langchain.chains import RetrievalQA
    from langchain_community.llms.fake import FakeListLLM

    retriever = YourRetriever()
    llm = FakeListLLM(responses=["test answer"])
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
    )
    result = qa_chain.invoke({"query": "test query"})
    assert isinstance(result["result"], str)
```
## Test Configuration
```python
# conftest.py
import pytest

from langchain_core.documents import Document

from your_retriever import YourRetriever


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
    )


@pytest.fixture
def sample_documents():
    """Fixture providing sample documents for testing."""
    return [
        Document(page_content="test document 1", metadata={"source": "test1"}),
        Document(page_content="test document 2", metadata={"source": "test2"}),
    ]


@pytest.fixture
def mock_retriever(sample_documents):
    """Fixture providing a retriever with sample documents."""
    retriever = YourRetriever()
    # Set up the retriever with the sample documents here
    return retriever
```
## Running Tests
To run the tests:
```bash
# Run all tests
pytest tests/retrievers/test_your_retriever.py
# Run only fast tests
pytest tests/retrievers/test_your_retriever.py -m "not slow"
# Run with coverage
pytest tests/retrievers/test_your_retriever.py --cov=your_retriever
```
## Best Practices
1. **Isolation**: Each test should be independent and not rely on the state from other tests.
2. **Mocking**: Use mocks for external services to avoid actual API calls during testing:
```python
@pytest.fixture
def mock_api(mocker):
    return mocker.patch("your_retriever.api_client")
```
3. **Parametrization**: Use `pytest.mark.parametrize` for testing multiple scenarios:
```python
@pytest.mark.parametrize("query,expected_count", [
    ("test", 1),
    ("invalid", 0),
    ("multiple words", 2),
])
def test_retrieval_counts(query, expected_count):
    retriever = YourRetriever()
    docs = retriever.get_relevant_documents(query)
    assert len(docs) == expected_count
```
4. **Documentation**: Include docstrings in test functions explaining what they test.
5. **Coverage**: Aim for high test coverage, especially for core functionality.
## Common Pitfalls
1. Not testing error cases
2. Not testing async functionality
3. Not handling rate limits in tests
4. Missing edge cases
5. Relying on external services in unit tests
Remember to adapt these tests based on your retriever's specific functionality and requirements.
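Pitfall 5 deserves a sketch: instead of hitting a real service, substitute the client with a mock so unit tests stay deterministic and offline. The `SearchClient` wrapper and `retrieve` helper here are hypothetical stand-ins for whatever API layer your retriever calls:

```python
from unittest.mock import MagicMock


# Hypothetical wrapper around a remote search API; in production this
# would make a network call.
class SearchClient:
    def search(self, query: str) -> list:
        raise RuntimeError("network call attempted in a unit test")


def retrieve(client: SearchClient, query: str) -> list:
    return client.search(query)


# In a unit test, stand in a mock so no external service is touched
mock_client = MagicMock(spec=SearchClient)
mock_client.search.return_value = ["doc-1", "doc-2"]

results = retrieve(mock_client, "test query")
```

Using `spec=SearchClient` makes the mock reject attributes the real client does not have, so tests fail loudly if the client's interface drifts.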


@@ -0,0 +1,23 @@
To create your own retriever, you need to extend the `BaseRetriever` class and implement the following methods:
| Method | Description | Required/Optional |
|--------------------------------|--------------------------------------------------|-------------------|
| `_get_relevant_documents` | Get documents relevant to a query. | Required |
| `_aget_relevant_documents` | Implement to provide async native support. | Optional |
The logic inside of `_get_relevant_documents` can involve arbitrary calls to a database or to the web using requests.
:::tip
By inheriting from `BaseRetriever`, your retriever automatically becomes a LangChain [Runnable](/docs/concepts/runnables) and will gain the standard `Runnable` functionality out of the box!
:::
:::info
You can use a `RunnableLambda` or `RunnableGenerator` to implement a retriever.
The main benefit of implementing a retriever as a `BaseRetriever` vs. a `RunnableLambda` (a custom [runnable function](/docs/how_to/functions)) is that a `BaseRetriever` is a well-known
LangChain entity, so some tooling for monitoring may implement specialized behavior for retrievers. Another difference
is that a `BaseRetriever` will behave slightly differently from `RunnableLambda` in some APIs; e.g., the `start` event
in `astream_events` API will be `on_retriever_start` instead of `on_chain_start`.
:::


@@ -27,29 +27,9 @@
"\n",
"## Interface\n",
"\n",
"To create your own retriever, you need to extend the `BaseRetriever` class and implement the following methods:\n",
"import CustomRetrieverIntro from './_custom_retriever_intro.mdx';\n",
"\n",
"| Method | Description | Required/Optional |\n",
"|--------------------------------|--------------------------------------------------|-------------------|\n",
"| `_get_relevant_documents` | Get documents relevant to a query. | Required |\n",
"| `_aget_relevant_documents` | Implement to provide async native support. | Optional |\n",
"\n",
"\n",
"The logic inside of `_get_relevant_documents` can involve arbitrary calls to a database or to the web using requests.\n",
"\n",
":::tip\n",
"By inherting from `BaseRetriever`, your retriever automatically becomes a LangChain [Runnable](/docs/concepts/runnables) and will gain the standard `Runnable` functionality out of the box!\n",
":::\n",
"\n",
"\n",
":::info\n",
"You can use a `RunnableLambda` or `RunnableGenerator` to implement a retriever.\n",
"\n",
"The main benefit of implementing a retriever as a `BaseRetriever` vs. a `RunnableLambda` (a custom [runnable function](/docs/how_to/functions)) is that a `BaseRetriever` is a well\n",
"known LangChain entity so some tooling for monitoring may implement specialized behavior for retrievers. Another difference\n",
"is that a `BaseRetriever` will behave slightly differently from `RunnableLambda` in some APIs; e.g., the `start` event\n",
"in `astream_events` API will be `on_retriever_start` instead of `on_chain_start`.\n",
":::\n"
"<CustomRetrieverIntro />"
]
},
{