mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 05:52:15 +00:00
Add New Retriever Interface with Callbacks (#5962)
Handle the new retriever events in a way that (I think) is entirely backwards compatible? Needs more testing for some of the chain changes and all. This creates an entire new run type, however. We could also just treat this as an event within a chain run presumably (same with memory) Adds a subclass initializer that upgrades old retriever implementations to the new schema, along with tests to ensure they work. First commit doesn't upgrade any of our retriever implementations (to show that we can pass the tests along with additional ones testing the upgrade logic). Second commit upgrades the known universe of retrievers in langchain. - [X] Add callback handling methods for retriever start/end/error (open to renaming to 'retrieval' if you want that) - [X] Update BaseRetriever schema to support callbacks - [X] Tests for upgrading old "v1" retrievers for backwards compatibility - [X] Update existing retriever implementations to implement the new interface - [X] Update calls within chains to .{a]get_relevant_documents to pass the child callback manager - [X] Update the notebooks/docs to reflect the new interface - [X] Test notebooks thoroughly Not handled: - Memory pass throughs: retrieval memory doesn't have a parent callback manager passed through the method --------- Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
@@ -1,24 +1,40 @@
|
||||
The `BaseRetriever` class in LangChain is as follows:
|
||||
The public API of the `BaseRetriever` class in LangChain is as follows:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
from langchain.schema import Document
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Get texts relevant for a query.
|
||||
|
||||
...
|
||||
def get_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant texts for
|
||||
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
...
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
It's that simple! The `get_relevant_documents` method can be implemented however you see fit.
|
||||
It's that simple! You can call `get_relevant_documents` or the async `get_relevant_documents` methods to retrieve documents relevant to a query, where "relevance" is defined by
|
||||
the specific retriever object you are calling.
|
||||
|
||||
Of course, we also help construct what we think useful Retrievers are. The main type of Retriever that we focus on is a Vectorstore retriever. We will focus on that for the rest of this guide.
|
||||
|
||||
|
@@ -0,0 +1,162 @@
|
||||
# Implement a Custom Retriever
|
||||
|
||||
In this walkthrough, you will implement a simple custom retriever in LangChain using a simple dot product distance lookup.
|
||||
|
||||
All retrievers inherit from the `BaseRetriever` class and override the following abstract methods:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from langchain.schema import Document
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
```
|
||||
|
||||
|
||||
The `_get_relevant_documents` and async `_get_relevant_documents` methods can be implemented however you see fit. The `run_manager` is useful if your retriever calls other traceable LangChain primitives like LLMs, chains, or tools.
|
||||
|
||||
|
||||
Below, implement an example that fetches the most similar documents from a list of documents using a numpy array of embeddings.
|
||||
|
||||
|
||||
```python
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class NumpyRetriever(BaseRetriever):
|
||||
"""Retrieves documents from a numpy array."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
vectors: np.ndarray,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
num_to_return: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embeddings = embeddings or OpenAIEmbeddings()
|
||||
self.texts = texts
|
||||
self.vectors = vectors
|
||||
self.num_to_return = num_to_return
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
**kwargs: Any,
|
||||
) -> "NumpyRetriever":
|
||||
embeddings = embeddings or OpenAIEmbeddings()
|
||||
vectors = np.array(embeddings.embed_documents(texts))
|
||||
return cls(texts, vectors, embeddings)
|
||||
|
||||
def _get_relevant_documents_from_query_vector(
|
||||
self, vector_query: np.ndarray
|
||||
) -> List[Document]:
|
||||
dot_product = np.dot(self.vectors, vector_query)
|
||||
# Get the indices of the min 5 documents
|
||||
indices = np.argpartition(
|
||||
dot_product, -min(self.num_to_return, len(self.vectors))
|
||||
)[-self.num_to_return :]
|
||||
# Sort indices by distance
|
||||
indices = indices[np.argsort(dot_product[indices])]
|
||||
return [
|
||||
Document(
|
||||
page_content=self.texts[idx],
|
||||
metadata={"index": idx},
|
||||
)
|
||||
for idx in indices
|
||||
]
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
vector_query = np.array(self.embeddings.embed_query(query))
|
||||
return self._get_relevant_documents_from_query_vector(vector_query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
query_emb = await self.embeddings.aembed_query(query)
|
||||
return self._get_relevant_documents_from_query_vector(np.array(query_emb))
|
||||
```
|
||||
|
||||
The retriever can be instantiated through the class method `from_texts`. It embeds the texts and stores them in a numpy array. To look up documents, it embeds the query and finds the most similar documents using a simple dot product distance.
|
||||
Once the retriever is implemented, you can use it like any other retriever in LangChain.
|
||||
|
||||
|
||||
```python
|
||||
retriever = NumpyRetriever.from_texts(texts= ["hello world", "goodbye world"])
|
||||
```
|
||||
|
||||
You can then use the retriever to get relevant documents.
|
||||
|
||||
```python
|
||||
retriever.get_relevant_documents("Hi there!")
|
||||
|
||||
# [Document(page_content='hello world', metadata={'index': 0})]
|
||||
```
|
||||
|
||||
```python
|
||||
retriever.get_relevant_documents("Bye!")
|
||||
# [Document(page_content='goodbye world', metadata={'index': 1})]
|
||||
```
|
Reference in New Issue
Block a user