From df4e0e6d8181a5c113829338b0559f371481c2d4 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Wed, 4 Dec 2024 17:34:41 -0800 Subject: [PATCH] x --- .../how_to/integrations/retriever_guide.md | 206 +++++++++++++++++ .../how_to/integrations/retriever_tests.md | 207 ++++++++++++++++++ docs/docs/how_to/_custom_retriever_intro.mdx | 23 ++ docs/docs/how_to/custom_retriever.ipynb | 24 +- 4 files changed, 438 insertions(+), 22 deletions(-) create mode 100644 docs/docs/contributing/how_to/integrations/retriever_guide.md create mode 100644 docs/docs/contributing/how_to/integrations/retriever_tests.md create mode 100644 docs/docs/how_to/_custom_retriever_intro.mdx diff --git a/docs/docs/contributing/how_to/integrations/retriever_guide.md b/docs/docs/contributing/how_to/integrations/retriever_guide.md new file mode 100644 index 00000000000..ab3313bdf01 --- /dev/null +++ b/docs/docs/contributing/how_to/integrations/retriever_guide.md @@ -0,0 +1,206 @@ +--- +pagination_prev: contributing/how_to/integrations/index +pagination_next: contributing/how_to/integrations/publish +--- +# How to implement and test a retriever integration + +In this guide, we'll implement and test a custom [retriever](/docs/concepts/retrievers) that you have integrated with LangChain. + +For testing, we will rely on the `langchain-tests` dependency we added in the previous [package creation guide](/docs/contributing/how_to/integrations/package). + +## Implementation + +Let's say you're building a simple integration package that provides a `ParrotRetriever` +retriever integration for LangChain. 
Here's a simple example of what your project +structure might look like: + +```plaintext +langchain-parrot-link/ +├── langchain_parrot_link/ +│ ├── __init__.py +│ └── retrievers.py +├── tests/ +│ └── integration_tests/ +│ ├── __init__.py +│ └── test_retrievers.py +├── pyproject.toml +└── README.md +``` + +In this first step, we will implement the `retrievers.py` file. + +import CustomRetrieverIntro from '/docs/how_to/_custom_retriever_intro.mdx'; + +<CustomRetrieverIntro /> + + +
+ retrievers.py +```python title="langchain_parrot_link/retrievers.py" +from typing import Any + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +class ParrotRetriever(BaseRetriever): + parrot_name: str + k: int = 3 + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + ) -> list[Document]: + k = kwargs.get("k", self.k) + return [Document(page_content=f"{self.parrot_name} says: {query}")] * k +``` +
+ +:::tip + +The `ParrotRetriever` from this guide is tested +against the standard unit and integration tests in the LangChain Github repository. +You can always use this as a starting point [here](https://github.com/langchain-ai/langchain/blob/master/libs/standard-tests/tests/unit_tests/test_basic_retriever.py). + +::: + +## Testing + + + +### 1. Create Your Retriever Class + +```python +from langchain.schema import BaseRetriever, Document +from langchain.callbacks.manager import CallbackManagerForRetrieverRun + +class MyCustomRetriever(BaseRetriever): + """Custom retriever implementation.""" + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Core implementation of retrieving relevant documents.""" + # Your implementation here + pass + + async def _aget_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Async implementation of retrieving relevant documents.""" + # Your async implementation here + pass +``` + +### 2. Required Testing + +All retrievers must include the following tests: + +#### Basic Functionality Tests +```python +def test_get_relevant_documents(): + retriever = MyCustomRetriever() + docs = retriever.get_relevant_documents("test query") + assert isinstance(docs, list) + assert all(isinstance(doc, Document) for doc in docs) + +@pytest.mark.asyncio +async def test_aget_relevant_documents(): + retriever = MyCustomRetriever() + docs = await retriever.aget_relevant_documents("test query") + assert isinstance(docs, list) + assert all(isinstance(doc, Document) for doc in docs) +``` + +#### Edge Cases +- Empty query handling +- Special character handling +- Long query handling +- Rate limiting (if applicable) +- Error handling + +### 3. Documentation Requirements + +Your retriever should include: + +1. 
Class docstring with: + - General description + - Required dependencies + - Example usage + - Parameters explanation + +2. Integration documentation file: + - Installation instructions + - Basic usage example + - Advanced configuration + - Common issues and solutions + +### 4. Best Practices + +1. **Error Handling** + - Implement proper error handling for API calls + - Provide meaningful error messages + - Handle rate limits gracefully + +2. **Performance** + - Implement caching when appropriate + - Use batch operations where possible + - Consider implementing both sync and async methods + +3. **Configuration** + - Use environment variables for sensitive data + - Provide sensible defaults + - Allow for customization of key parameters + +4. **Type Hints** + - Use proper type hints throughout your code + - Document expected types in docstrings + +## Example Implementation + +Here's a minimal example of a custom retriever: + +```python +from typing import List +from langchain.schema import BaseRetriever, Document +from langchain.callbacks.manager import CallbackManagerForRetrieverRun + +class SimpleKeywordRetriever(BaseRetriever): + """A simple retriever that matches documents based on keywords.""" + + documents: List[Document] # Store your documents here + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Return documents that contain the query string.""" + return [ + doc for doc in self.documents + if query.lower() in doc.page_content.lower() + ] + + async def _aget_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Async version of get_relevant_documents.""" + return self._get_relevant_documents(query, run_manager=run_manager) +``` + +## Submission Checklist + +- [ ] Implemented base retriever interface +- [ ] Added comprehensive tests +- [ ] Included proper documentation +- [ ] Added type hints +- [ ] Handled error cases +- [ 
] Implemented both sync and async methods +- [ ] Added example usage +- [ ] Followed code style guidelines +- [ ] Added requirements.txt or setup.py updates + +## Getting Help + +If you need help while implementing your retriever: +1. Check existing retriever implementations for reference +2. Open a discussion in the GitHub repository +3. Ask in the LangChain Discord community + +Remember to follow the existing patterns in the codebase and maintain consistency with other retrievers. diff --git a/docs/docs/contributing/how_to/integrations/retriever_tests.md b/docs/docs/contributing/how_to/integrations/retriever_tests.md new file mode 100644 index 00000000000..1c565735e9b --- /dev/null +++ b/docs/docs/contributing/how_to/integrations/retriever_tests.md @@ -0,0 +1,207 @@ +# Standard Tests for LangChain Retrievers + +This guide outlines the standard tests that should be implemented for all LangChain retrievers. + +## Test Structure + +### 1. Basic Functionality Tests + +```python +import pytest +from langchain.schema import Document +from your_retriever import YourRetriever + +def test_basic_retrieval(): + """Test basic document retrieval functionality.""" + retriever = YourRetriever() + query = "test query" + docs = retriever.get_relevant_documents(query) + + assert isinstance(docs, list) + assert all(isinstance(doc, Document) for doc in docs) + assert len(docs) > 0 # Adjust if your retriever might return empty results + +@pytest.mark.asyncio +async def test_async_retrieval(): + """Test async document retrieval functionality.""" + retriever = YourRetriever() + query = "test query" + docs = await retriever.aget_relevant_documents(query) + + assert isinstance(docs, list) + assert all(isinstance(doc, Document) for doc in docs) +``` + +### 2. 
Edge Cases + +```python +def test_empty_query(): + """Test behavior with empty query.""" + retriever = YourRetriever() + docs = retriever.get_relevant_documents("") + assert isinstance(docs, list) + +def test_special_characters(): + """Test handling of special characters.""" + retriever = YourRetriever() + special_queries = [ + "test!@#$%^&*()", + "múltiple áccents", + "中文测试", + "test\nwith\nnewlines", + ] + for query in special_queries: + docs = retriever.get_relevant_documents(query) + assert isinstance(docs, list) + +def test_long_query(): + """Test handling of very long queries.""" + retriever = YourRetriever() + long_query = "test " * 1000 + docs = retriever.get_relevant_documents(long_query) + assert isinstance(docs, list) +``` + +### 3. Error Handling + +```python +def test_invalid_configuration(): + """Test behavior with invalid configuration.""" + with pytest.raises(ValueError): + YourRetriever(invalid_param="invalid") + +def test_connection_error(): + """Test behavior when connection fails (if applicable).""" + retriever = YourRetriever() + # Mock connection failure + with pytest.raises(ConnectionError): + retriever.get_relevant_documents("test") +``` + +### 4. Performance Tests (Optional) + +```python +@pytest.mark.slow +def test_large_scale_retrieval(): + """Test retrieval with a large number of documents.""" + retriever = YourRetriever() + # Test with a significant number of documents + docs = retriever.get_relevant_documents("test") + assert len(docs) <= YOUR_MAX_LIMIT # If applicable + +@pytest.mark.slow +def test_concurrent_requests(): + """Test handling of concurrent requests.""" + import asyncio + + async def run_concurrent_requests(): + retriever = YourRetriever() + tasks = [ + retriever.aget_relevant_documents("test") + for _ in range(5) + ] + results = await asyncio.gather(*tasks) + return results + + results = asyncio.run(run_concurrent_requests()) + assert len(results) == 5 +``` + +### 5. 
Integration Tests + +```python +def test_chain_integration(): + """Test integration with LangChain chains.""" + from langchain.chains import RetrievalQA + from langchain.llms import FakeLLM + + retriever = YourRetriever() + llm = FakeLLM() + qa_chain = RetrievalQA.from_chain_type( + llm=llm, + retriever=retriever, + chain_type="stuff" + ) + result = qa_chain.run("test query") + assert isinstance(result, str) +``` + +## Test Configuration + +```python +# conftest.py +import pytest + +def pytest_configure(config): + config.addinivalue_line( + "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')" + ) + +@pytest.fixture +def sample_documents(): + """Fixture providing sample documents for testing.""" + return [ + Document(page_content="test document 1", metadata={"source": "test1"}), + Document(page_content="test document 2", metadata={"source": "test2"}), + ] + +@pytest.fixture +def mock_retriever(sample_documents): + """Fixture providing a retriever with sample documents.""" + retriever = YourRetriever() + # Set up retriever with sample documents + return retriever +``` + +## Running Tests + +To run the tests: + +```bash +# Run all tests +pytest tests/retrievers/test_your_retriever.py + +# Run only fast tests +pytest tests/retrievers/test_your_retriever.py -m "not slow" + +# Run with coverage +pytest tests/retrievers/test_your_retriever.py --cov=your_retriever +``` + +## Best Practices + +1. **Isolation**: Each test should be independent and not rely on the state from other tests. + +2. **Mocking**: Use mocks for external services to avoid actual API calls during testing: + ```python + @pytest.fixture + def mock_api(mocker): + return mocker.patch("your_retriever.api_client") + ``` + +3. 
**Parametrization**: Use pytest.mark.parametrize for testing multiple scenarios: + ```python + @pytest.mark.parametrize("query,expected_count", [ + ("test", 1), + ("invalid", 0), + ("multiple words", 2), + ]) + def test_retrieval_counts(query, expected_count): + retriever = YourRetriever() + docs = retriever.get_relevant_documents(query) + assert len(docs) == expected_count + ``` + +4. **Documentation**: Include docstrings in test functions explaining what they test. + +5. **Coverage**: Aim for high test coverage, especially for core functionality. + +## Common Pitfalls + +1. Not testing error cases +2. Not testing async functionality +3. Not handling rate limits in tests +4. Missing edge cases +5. Relying on external services in unit tests + +Remember to adapt these tests based on your retriever's specific functionality and requirements. diff --git a/docs/docs/how_to/_custom_retriever_intro.mdx b/docs/docs/how_to/_custom_retriever_intro.mdx new file mode 100644 index 00000000000..5e1bf4f3e3a --- /dev/null +++ b/docs/docs/how_to/_custom_retriever_intro.mdx @@ -0,0 +1,23 @@ +To create your own retriever, you need to extend the `BaseRetriever` class and implement the following methods: + +| Method | Description | Required/Optional | +|--------------------------------|--------------------------------------------------|-------------------| +| `_get_relevant_documents` | Get documents relevant to a query. | Required | +| `_aget_relevant_documents` | Implement to provide async native support. | Optional | + + +The logic inside of `_get_relevant_documents` can involve arbitrary calls to a database or to the web using requests. + +:::tip +By inheriting from `BaseRetriever`, your retriever automatically becomes a LangChain [Runnable](/docs/concepts/runnables) and will gain the standard `Runnable` functionality out of the box! +::: + + +:::info +You can use a `RunnableLambda` or `RunnableGenerator` to implement a retriever. 
+ +The main benefit of implementing a retriever as a `BaseRetriever` vs. a `RunnableLambda` (a custom [runnable function](/docs/how_to/functions)) is that a `BaseRetriever` is a well +known LangChain entity so some tooling for monitoring may implement specialized behavior for retrievers. Another difference +is that a `BaseRetriever` will behave slightly differently from `RunnableLambda` in some APIs; e.g., the `start` event +in `astream_events` API will be `on_retriever_start` instead of `on_chain_start`. +::: diff --git a/docs/docs/how_to/custom_retriever.ipynb b/docs/docs/how_to/custom_retriever.ipynb index 31b6fb90a1c..d7b1a379609 100644 --- a/docs/docs/how_to/custom_retriever.ipynb +++ b/docs/docs/how_to/custom_retriever.ipynb @@ -27,29 +27,9 @@ "\n", "## Interface\n", "\n", - "To create your own retriever, you need to extend the `BaseRetriever` class and implement the following methods:\n", + "import CustomRetrieverIntro from './_custom_retriever_intro.mdx';\n", "\n", - "| Method | Description | Required/Optional |\n", - "|--------------------------------|--------------------------------------------------|-------------------|\n", - "| `_get_relevant_documents` | Get documents relevant to a query. | Required |\n", - "| `_aget_relevant_documents` | Implement to provide async native support. | Optional |\n", - "\n", - "\n", - "The logic inside of `_get_relevant_documents` can involve arbitrary calls to a database or to the web using requests.\n", - "\n", - ":::tip\n", - "By inherting from `BaseRetriever`, your retriever automatically becomes a LangChain [Runnable](/docs/concepts/runnables) and will gain the standard `Runnable` functionality out of the box!\n", - ":::\n", - "\n", - "\n", - ":::info\n", - "You can use a `RunnableLambda` or `RunnableGenerator` to implement a retriever.\n", - "\n", - "The main benefit of implementing a retriever as a `BaseRetriever` vs. 
a `RunnableLambda` (a custom [runnable function](/docs/how_to/functions)) is that a `BaseRetriever` is a well\n", - "known LangChain entity so some tooling for monitoring may implement specialized behavior for retrievers. Another difference\n", - "is that a `BaseRetriever` will behave slightly differently from `RunnableLambda` in some APIs; e.g., the `start` event\n", - "in `astream_events` API will be `on_retriever_start` instead of `on_chain_start`.\n", - ":::\n" + "<CustomRetrieverIntro />" ] }, {