mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 06:23:20 +00:00
Add New Retriever Interface with Callbacks (#5962)
Handle the new retriever events in a way that (I think) is entirely backwards compatible? Needs more testing for some of the chain changes and all. This creates an entire new run type, however. We could also just treat this as an event within a chain run presumably (same with memory) Adds a subclass initializer that upgrades old retriever implementations to the new schema, along with tests to ensure they work. First commit doesn't upgrade any of our retriever implementations (to show that we can pass the tests along with additional ones testing the upgrade logic). Second commit upgrades the known universe of retrievers in langchain. - [X] Add callback handling methods for retriever start/end/error (open to renaming to 'retrieval' if you want that) - [X] Update BaseRetriever schema to support callbacks - [X] Tests for upgrading old "v1" retrievers for backwards compatibility - [X] Update existing retriever implementations to implement the new interface - [X] Update calls within chains to .{a]get_relevant_documents to pass the child callback manager - [X] Update the notebooks/docs to reflect the new interface - [X] Test notebooks thoroughly Not handled: - Memory pass throughs: retrieval memory doesn't have a parent callback manager passed through the method --------- Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
@@ -71,11 +71,13 @@
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from langchain.schema import BaseRetriever\n",
|
||||
"from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun\n",
|
||||
"from langchain.utilities import GoogleSerperAPIWrapper\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.schema import Document"
|
||||
"from langchain.schema import Document\n",
|
||||
"from typing import Any"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -97,11 +99,16 @@
|
||||
" def __init__(self, search):\n",
|
||||
" self.search = search\n",
|
||||
"\n",
|
||||
" def get_relevant_documents(self, query: str):\n",
|
||||
" def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any) -> List[Document]:\n",
|
||||
" return [Document(page_content=self.search.run(query))]\n",
|
||||
"\n",
|
||||
" async def aget_relevant_documents(self, query: str):\n",
|
||||
" raise NotImplemented\n",
|
||||
" async def _aget_relevant_documents(self,\n",
|
||||
" query: str,\n",
|
||||
" *,\n",
|
||||
" run_manager: AsyncCallbackManagerForRetrieverRun,\n",
|
||||
" **kwargs: Any,\n",
|
||||
" ) -> List[Document]:\n",
|
||||
" raise NotImplementedError()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"retriever = SerperSearchRetriever(GoogleSerperAPIWrapper())"
|
||||
|
@@ -43,7 +43,6 @@
|
||||
"import openai\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.schema import BaseRetriever\n",
|
||||
"from langchain.vectorstores.azuresearch import AzureSearch"
|
||||
]
|
||||
},
|
||||
|
@@ -1,24 +1,40 @@
|
||||
The `BaseRetriever` class in LangChain is as follows:
|
||||
The public API of the `BaseRetriever` class in LangChain is as follows:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
from langchain.schema import Document
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Get texts relevant for a query.
|
||||
|
||||
...
|
||||
def get_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant texts for
|
||||
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
...
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
It's that simple! The `get_relevant_documents` method can be implemented however you see fit.
|
||||
It's that simple! You can call `get_relevant_documents` or the async `get_relevant_documents` methods to retrieve documents relevant to a query, where "relevance" is defined by
|
||||
the specific retriever object you are calling.
|
||||
|
||||
Of course, we also help construct what we think useful Retrievers are. The main type of Retriever that we focus on is a Vectorstore retriever. We will focus on that for the rest of this guide.
|
||||
|
||||
|
@@ -0,0 +1,162 @@
|
||||
# Implement a Custom Retriever
|
||||
|
||||
In this walkthrough, you will implement a simple custom retriever in LangChain using a simple dot product distance lookup.
|
||||
|
||||
All retrievers inherit from the `BaseRetriever` class and override the following abstract methods:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from langchain.schema import Document
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
```
|
||||
|
||||
|
||||
The `_get_relevant_documents` and async `_get_relevant_documents` methods can be implemented however you see fit. The `run_manager` is useful if your retriever calls other traceable LangChain primitives like LLMs, chains, or tools.
|
||||
|
||||
|
||||
Below, implement an example that fetches the most similar documents from a list of documents using a numpy array of embeddings.
|
||||
|
||||
|
||||
```python
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class NumpyRetriever(BaseRetriever):
|
||||
"""Retrieves documents from a numpy array."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
vectors: np.ndarray,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
num_to_return: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embeddings = embeddings or OpenAIEmbeddings()
|
||||
self.texts = texts
|
||||
self.vectors = vectors
|
||||
self.num_to_return = num_to_return
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
**kwargs: Any,
|
||||
) -> "NumpyRetriever":
|
||||
embeddings = embeddings or OpenAIEmbeddings()
|
||||
vectors = np.array(embeddings.embed_documents(texts))
|
||||
return cls(texts, vectors, embeddings)
|
||||
|
||||
def _get_relevant_documents_from_query_vector(
|
||||
self, vector_query: np.ndarray
|
||||
) -> List[Document]:
|
||||
dot_product = np.dot(self.vectors, vector_query)
|
||||
# Get the indices of the min 5 documents
|
||||
indices = np.argpartition(
|
||||
dot_product, -min(self.num_to_return, len(self.vectors))
|
||||
)[-self.num_to_return :]
|
||||
# Sort indices by distance
|
||||
indices = indices[np.argsort(dot_product[indices])]
|
||||
return [
|
||||
Document(
|
||||
page_content=self.texts[idx],
|
||||
metadata={"index": idx},
|
||||
)
|
||||
for idx in indices
|
||||
]
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
vector_query = np.array(self.embeddings.embed_query(query))
|
||||
return self._get_relevant_documents_from_query_vector(vector_query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
query_emb = await self.embeddings.aembed_query(query)
|
||||
return self._get_relevant_documents_from_query_vector(np.array(query_emb))
|
||||
```
|
||||
|
||||
The retriever can be instantiated through the class method `from_texts`. It embeds the texts and stores them in a numpy array. To look up documents, it embeds the query and finds the most similar documents using a simple dot product distance.
|
||||
Once the retriever is implemented, you can use it like any other retriever in LangChain.
|
||||
|
||||
|
||||
```python
|
||||
retriever = NumpyRetriever.from_texts(texts= ["hello world", "goodbye world"])
|
||||
```
|
||||
|
||||
You can then use the retriever to get relevant documents.
|
||||
|
||||
```python
|
||||
retriever.get_relevant_documents("Hi there!")
|
||||
|
||||
# [Document(page_content='hello world', metadata={'index': 0})]
|
||||
```
|
||||
|
||||
```python
|
||||
retriever.get_relevant_documents("Bye!")
|
||||
# [Document(page_content='goodbye world', metadata={'index': 1})]
|
||||
```
|
@@ -30,6 +30,7 @@ class BaseMetadataCallbackHandler:
|
||||
ignore_llm_ (bool): Whether to ignore llm callbacks.
|
||||
ignore_chain_ (bool): Whether to ignore chain callbacks.
|
||||
ignore_agent_ (bool): Whether to ignore agent callbacks.
|
||||
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
|
||||
always_verbose_ (bool): Whether to always be verbose.
|
||||
chain_starts (int): The number of times the chain start method has been called.
|
||||
chain_ends (int): The number of times the chain end method has been called.
|
||||
@@ -52,6 +53,7 @@ class BaseMetadataCallbackHandler:
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.ignore_retriever_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
@@ -86,6 +88,11 @@ class BaseMetadataCallbackHandler:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
def get_custom_callback_meta(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step": self.step,
|
||||
|
@@ -1,15 +1,34 @@
|
||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, BaseMessage, Document, LLMResult
|
||||
|
||||
|
||||
class RetrieverManagerMixin:
|
||||
"""Mixin for Retriever callbacks."""
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever errors."""
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever ends running."""
|
||||
|
||||
|
||||
class LLMManagerMixin:
|
||||
@@ -144,6 +163,16 @@ class CallbackManagerMixin:
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever starts running."""
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -187,6 +216,7 @@ class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
RunManagerMixin,
|
||||
):
|
||||
@@ -211,6 +241,11 @@ class BaseCallbackHandler(
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chat_model(self) -> bool:
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
@@ -371,6 +406,36 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever start."""
|
||||
|
||||
async def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever end."""
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever error."""
|
||||
|
||||
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
@@ -14,6 +14,7 @@ from typing import (
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
@@ -27,6 +28,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackManager,
|
||||
ChainManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
RunManagerMixin,
|
||||
ToolManagerMixin,
|
||||
)
|
||||
@@ -40,6 +42,7 @@ from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
Document,
|
||||
LLMResult,
|
||||
get_buffer_string,
|
||||
)
|
||||
@@ -899,6 +902,97 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
|
||||
"""Callback manager for retriever run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_end",
|
||||
"ignore_retriever",
|
||||
documents,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_error",
|
||||
"ignore_retriever",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForRetrieverRun(
|
||||
AsyncRunManager,
|
||||
RetrieverManagerMixin,
|
||||
):
|
||||
"""Async callback manager for retriever run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
async def on_retriever_end(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when retriever ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_end",
|
||||
"ignore_retriever",
|
||||
documents,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_error",
|
||||
"ignore_retriever",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
@@ -1077,6 +1171,36 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForRetrieverRun:
|
||||
"""Run when retriever starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_start",
|
||||
"ignore_retriever",
|
||||
query,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CallbackManagerForRetrieverRun(
|
||||
run_id=run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
@@ -1313,6 +1437,36 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForRetrieverRun:
|
||||
"""Run when retriever starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_start",
|
||||
"ignore_retriever",
|
||||
query,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return AsyncCallbackManagerForRetrieverRun(
|
||||
run_id=run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
|
@@ -4,12 +4,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import Document, LLMResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -265,6 +265,65 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self._end_trace(tool_run)
|
||||
self._on_tool_error(tool_run)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when Retriever starts running."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
retrieval_run = Run(
|
||||
id=run_id,
|
||||
name="Retriever",
|
||||
parent_run_id=parent_run_id,
|
||||
inputs={"query": query},
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
run_type=RunTypeEnum.retriever,
|
||||
)
|
||||
self._start_trace(retrieval_run)
|
||||
self._on_retriever_start(retrieval_run)
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when Retriever errors."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_error callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
|
||||
raise TracerException("No retriever Run found to be traced")
|
||||
|
||||
retrieval_run.error = repr(error)
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_error(retrieval_run)
|
||||
|
||||
def on_retriever_end(
|
||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when Retriever ends running."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_end callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
|
||||
raise TracerException("No retriever Run found to be traced")
|
||||
retrieval_run.outputs = {"documents": documents}
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_end(retrieval_run)
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||
"""Deepcopy the tracer."""
|
||||
return self
|
||||
@@ -302,3 +361,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> None:
|
||||
"""Process the Chat Model Run upon start."""
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon start."""
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> None:
|
||||
"""Process the Retriever Run."""
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon error."""
|
||||
|
@@ -180,6 +180,24 @@ class LangChainTracer(BaseTracer):
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon start."""
|
||||
self._futures.add(
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> None:
|
||||
"""Process the Retriever Run."""
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon error."""
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for the given futures to complete."""
|
||||
futures = list(self._futures)
|
||||
|
@@ -119,6 +119,7 @@ class BaseMetadataCallbackHandler:
|
||||
ignore_llm_ (bool): Whether to ignore llm callbacks.
|
||||
ignore_chain_ (bool): Whether to ignore chain callbacks.
|
||||
ignore_agent_ (bool): Whether to ignore agent callbacks.
|
||||
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
|
||||
always_verbose_ (bool): Whether to always be verbose.
|
||||
chain_starts (int): The number of times the chain start method has been called.
|
||||
chain_ends (int): The number of times the chain end method has been called.
|
||||
@@ -149,6 +150,7 @@ class BaseMetadataCallbackHandler:
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.ignore_retriever_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Chain for chatting with a vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -87,7 +88,13 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
return _output_keys
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
def _call(
|
||||
@@ -107,7 +114,13 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
docs = self._get_docs(new_question, inputs)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
|
||||
else:
|
||||
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@@ -122,7 +135,13 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
async def _acall(
|
||||
@@ -141,7 +160,14 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
docs = await self._aget_docs(new_question, inputs)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
|
||||
else:
|
||||
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
|
||||
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@@ -187,12 +213,30 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
docs = await self.retriever.aget_relevant_documents(question)
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
docs = await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@classmethod
|
||||
@@ -253,14 +297,28 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
vectordbkwargs = inputs.get("vectordbkwargs", {})
|
||||
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
|
||||
return self.vectorstore.similarity_search(
|
||||
question, k=self.top_k_docs_for_context, **full_kwargs
|
||||
)
|
||||
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
raise NotImplementedError("ChatVectorDBChain does not support async")
|
||||
|
||||
@classmethod
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -115,7 +116,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
return values
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
def _call(
|
||||
@@ -124,7 +130,14 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
docs = self._get_docs(inputs)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = self._get_docs(inputs, run_manager=_run_manager)
|
||||
else:
|
||||
docs = self._get_docs(inputs) # type: ignore[call-arg]
|
||||
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
@@ -141,7 +154,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
async def _acall(
|
||||
@@ -150,7 +168,13 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
docs = await self._aget_docs(inputs)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(inputs, run_manager=_run_manager)
|
||||
else:
|
||||
docs = await self._aget_docs(inputs) # type: ignore[call-arg]
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
@@ -180,10 +204,22 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
"""
|
||||
return [self.input_docs_key, self.question_key]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
@property
|
||||
|
@@ -1,4 +1,6 @@
|
||||
"""Load question answering with sources chains."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
@@ -13,7 +15,9 @@ from langchain.chains.qa_with_sources import (
|
||||
refine_prompts,
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.chains.question_answering import map_rerank_prompt
|
||||
from langchain.chains.question_answering.map_rerank_prompt import (
|
||||
PROMPT as MAP_RERANK_PROMPT,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
@@ -28,7 +32,7 @@ class LoadingCallable(Protocol):
|
||||
|
||||
def _load_map_rerank_chain(
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
|
||||
prompt: BasePromptTemplate = MAP_RERANK_PROMPT,
|
||||
verbose: bool = False,
|
||||
document_variable_name: str = "context",
|
||||
rank_key: str = "score",
|
||||
|
@@ -4,6 +4,10 @@ from typing import Any, Dict, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
@@ -40,12 +44,20 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = await self.retriever.aget_relevant_documents(question)
|
||||
docs = await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
@@ -5,6 +5,10 @@ from typing import Any, Dict, List
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
@@ -45,14 +49,18 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = self.vectorstore.similarity_search(
|
||||
question, k=self.k, **self.search_kwargs
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
|
||||
|
||||
@root_validator()
|
||||
|
@@ -12,10 +12,12 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import (
|
||||
map_reduce_prompt,
|
||||
map_rerank_prompt,
|
||||
refine_prompts,
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.chains.question_answering.map_rerank_prompt import (
|
||||
PROMPT as MAP_RERANK_PROMPT,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
@@ -30,7 +32,7 @@ class LoadingCallable(Protocol):
|
||||
|
||||
def _load_map_rerank_chain(
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
|
||||
prompt: BasePromptTemplate = MAP_RERANK_PROMPT,
|
||||
verbose: bool = False,
|
||||
document_variable_name: str = "context",
|
||||
rank_key: str = "score",
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Chain for question-answering against a vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -94,7 +95,12 @@ class BaseRetrievalQA(Chain):
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
def _call(
|
||||
@@ -115,8 +121,13 @@ class BaseRetrievalQA(Chain):
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
docs = self._get_docs(question)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = self._get_docs(question, run_manager=_run_manager)
|
||||
else:
|
||||
docs = self._get_docs(question) # type: ignore[call-arg]
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
@@ -127,7 +138,12 @@ class BaseRetrievalQA(Chain):
|
||||
return {self.output_key: answer}
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
async def _acall(
|
||||
@@ -148,8 +164,13 @@ class BaseRetrievalQA(Chain):
|
||||
"""
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
docs = await self._aget_docs(question)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(question, run_manager=_run_manager)
|
||||
else:
|
||||
docs = await self._aget_docs(question) # type: ignore[call-arg]
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
@@ -177,11 +198,27 @@ class RetrievalQA(BaseRetrievalQA):
|
||||
|
||||
retriever: BaseRetriever = Field(exclude=True)
|
||||
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(question)
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
return self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
return await self.retriever.aget_relevant_documents(question)
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
return await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@@ -218,7 +255,13 @@ class VectorDBQA(BaseRetrievalQA):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(
|
||||
question, k=self.k, **self.search_kwargs
|
||||
@@ -231,7 +274,13 @@ class VectorDBQA(BaseRetrievalQA):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
raise NotImplementedError("VectorDBQA does not support async")
|
||||
|
||||
@property
|
||||
|
@@ -697,7 +697,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@@ -16,7 +16,7 @@ from langchain.retrievers.metal import MetalRetriever
|
||||
from langchain.retrievers.milvus import MilvusRetriever
|
||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
|
||||
from langchain.retrievers.pupmed import PubMedRetriever
|
||||
from langchain.retrievers.pubmed import PubMedRetriever
|
||||
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
||||
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
||||
from langchain.retrievers.svm import SVMRetriever
|
||||
|
@@ -1,5 +1,9 @@
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
|
||||
@@ -11,8 +15,20 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
|
||||
It uses all ArxivAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return self.load(query=query)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,13 +1,18 @@
|
||||
"""Retriever wrapper for Azure Cognitive Search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -81,7 +86,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
||||
|
||||
return response_json["value"]
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
search_results = self._search(query)
|
||||
|
||||
return [
|
||||
@@ -89,7 +100,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
search_results = await self._asearch(query)
|
||||
|
||||
return [
|
||||
|
@@ -1,11 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -21,7 +25,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
url, json, headers = self._create_request(query)
|
||||
response = requests.post(url, json=json, headers=headers)
|
||||
results = response.json()["results"][0]["results"]
|
||||
@@ -34,7 +44,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
docs.append(Document(page_content=content, metadata=metadata))
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
url, json, headers = self._create_request(query)
|
||||
|
||||
if not self.aiosession:
|
||||
|
@@ -1,8 +1,13 @@
|
||||
"""Retriever that wraps a base retriever and filters the results."""
|
||||
from typing import List
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.retrievers.document_compressors.base import (
|
||||
BaseDocumentCompressor,
|
||||
)
|
||||
@@ -24,7 +29,13 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
@@ -33,14 +44,24 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
|
||||
Returns:
|
||||
Sequence of relevant documents
|
||||
"""
|
||||
docs = self.base_retriever.get_relevant_documents(query)
|
||||
docs = self.base_retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(), **kwargs
|
||||
)
|
||||
if docs:
|
||||
compressed_docs = self.base_compressor.compress_documents(docs, query)
|
||||
compressed_docs = self.base_compressor.compress_documents(
|
||||
docs, query, callbacks=run_manager.get_child()
|
||||
)
|
||||
return list(compressed_docs)
|
||||
else:
|
||||
return []
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
@@ -49,10 +70,12 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
docs = await self.base_retriever.aget_relevant_documents(query)
|
||||
docs = await self.base_retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(), **kwargs
|
||||
)
|
||||
if docs:
|
||||
compressed_docs = await self.base_compressor.acompress_documents(
|
||||
docs, query
|
||||
docs, query, callbacks=run_manager.get_child()
|
||||
)
|
||||
return list(compressed_docs)
|
||||
else:
|
||||
|
@@ -1,8 +1,12 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -23,7 +27,13 @@ class DataberryRetriever(BaseRetriever):
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
response = requests.post(
|
||||
self.datastore_url,
|
||||
json={
|
||||
@@ -48,7 +58,13 @@ class DataberryRetriever(BaseRetriever):
|
||||
for r in data["results"]
|
||||
]
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
"POST",
|
||||
|
@@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
@@ -49,7 +53,12 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
@@ -201,5 +210,10 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
|
||||
|
||||
return lc_doc
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,9 +1,11 @@
|
||||
"""Interface for retrieved document compressors."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence, Union
|
||||
from inspect import signature
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import BaseDocumentTransformer, Document
|
||||
|
||||
|
||||
@@ -12,13 +14,19 @@ class BaseDocumentCompressor(BaseModel, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def compress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress retrieved documents given the query context."""
|
||||
|
||||
@abstractmethod
|
||||
async def acompress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress retrieved documents given the query context."""
|
||||
|
||||
@@ -35,12 +43,26 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def compress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Transform a list of documents."""
|
||||
for _transformer in self.transformers:
|
||||
if isinstance(_transformer, BaseDocumentCompressor):
|
||||
documents = _transformer.compress_documents(documents, query)
|
||||
accepts_callbacks = (
|
||||
signature(_transformer.compress_documents).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
is not None
|
||||
)
|
||||
if accepts_callbacks:
|
||||
documents = _transformer.compress_documents(
|
||||
documents, query, callbacks=callbacks
|
||||
)
|
||||
else:
|
||||
documents = _transformer.compress_documents(documents, query)
|
||||
elif isinstance(_transformer, BaseDocumentTransformer):
|
||||
documents = _transformer.transform_documents(documents)
|
||||
else:
|
||||
@@ -48,12 +70,26 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
|
||||
return documents
|
||||
|
||||
async def acompress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress retrieved documents given the query context."""
|
||||
for _transformer in self.transformers:
|
||||
if isinstance(_transformer, BaseDocumentCompressor):
|
||||
documents = await _transformer.acompress_documents(documents, query)
|
||||
accepts_callbacks = (
|
||||
signature(_transformer.acompress_documents).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
is not None
|
||||
)
|
||||
if accepts_callbacks:
|
||||
documents = await _transformer.acompress_documents(
|
||||
documents, query, callbacks=callbacks
|
||||
)
|
||||
else:
|
||||
documents = await _transformer.acompress_documents(documents, query)
|
||||
elif isinstance(_transformer, BaseDocumentTransformer):
|
||||
documents = await _transformer.atransform_documents(documents)
|
||||
else:
|
||||
|
@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, Optional, Sequence
|
||||
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
from langchain.retrievers.document_compressors.chain_extract_prompt import (
|
||||
prompt_template,
|
||||
@@ -48,25 +49,33 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
"""Callable for constructing the chain input from the query and a Document."""
|
||||
|
||||
def compress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress page content of raw documents."""
|
||||
compressed_docs = []
|
||||
for doc in documents:
|
||||
_input = self.get_input(query, doc)
|
||||
output = self.llm_chain.predict_and_parse(**_input)
|
||||
output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
|
||||
if len(output) == 0:
|
||||
continue
|
||||
compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
|
||||
return compressed_docs
|
||||
|
||||
async def acompress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress page content of raw documents asynchronously."""
|
||||
outputs = await asyncio.gather(
|
||||
*[
|
||||
self.llm_chain.apredict_and_parse(**self.get_input(query, doc))
|
||||
self.llm_chain.apredict_and_parse(
|
||||
**self.get_input(query, doc), callbacks=callbacks
|
||||
)
|
||||
for doc in documents
|
||||
]
|
||||
)
|
||||
|
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, Sequence
|
||||
|
||||
from langchain import BasePromptTemplate, LLMChain, PromptTemplate
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
from langchain.retrievers.document_compressors.chain_filter_prompt import (
|
||||
@@ -35,19 +36,27 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
"""Callable for constructing the chain input from the query and a Document."""
|
||||
|
||||
def compress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Filter down documents based on their relevance to the query."""
|
||||
filtered_docs = []
|
||||
for doc in documents:
|
||||
_input = self.get_input(query, doc)
|
||||
include_doc = self.llm_chain.predict_and_parse(**_input)
|
||||
include_doc = self.llm_chain.predict_and_parse(
|
||||
**_input, callbacks=callbacks
|
||||
)
|
||||
if include_doc:
|
||||
filtered_docs.append(doc)
|
||||
return filtered_docs
|
||||
|
||||
async def acompress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Filter down documents."""
|
||||
raise NotImplementedError
|
||||
|
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Sequence
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Sequence
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
from langchain.schema import Document
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -48,7 +49,10 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
return values
|
||||
|
||||
def compress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
if len(documents) == 0: # to avoid empty api call
|
||||
return []
|
||||
@@ -65,6 +69,9 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
return final_results
|
||||
|
||||
async def acompress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Sequence
|
||||
import numpy as np
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.document_transformers import (
|
||||
_get_embeddings_from_stateful_docs,
|
||||
get_stateful_documents,
|
||||
@@ -44,7 +45,10 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
return values
|
||||
|
||||
def compress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Filter documents based on similarity of their embeddings to the query."""
|
||||
stateful_documents = get_stateful_documents(documents)
|
||||
@@ -64,7 +68,10 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
return [stateful_documents[i] for i in included_idxs]
|
||||
|
||||
async def acompress_documents(
|
||||
self, documents: Sequence[Document], query: str
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Filter down documents."""
|
||||
raise NotImplementedError
|
||||
|
@@ -1,9 +1,14 @@
|
||||
"""Wrapper around Elasticsearch vector database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Iterable, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
@@ -111,7 +116,13 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
self.client.indices.refresh(index=self.index_name)
|
||||
return ids
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
query_dict = {"query": {"match": {"content": query}}}
|
||||
res = self.client.search(index=self.index_name, body=query_dict)
|
||||
|
||||
@@ -120,5 +131,11 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
docs.append(Document(page_content=r["_source"]["content"]))
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -3,6 +3,10 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
@@ -257,7 +261,12 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
docs = r_result.get_top_k_docs(top_k)
|
||||
return docs
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Run search on Kendra index and get top k documents
|
||||
|
||||
Example:
|
||||
@@ -269,5 +278,10 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("Async version is not implemented for Kendra yet.")
|
||||
|
@@ -10,6 +10,10 @@ from typing import Any, List, Optional
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@@ -51,7 +55,13 @@ class KNNRetriever(BaseRetriever, BaseModel):
|
||||
index = create_index(texts, embeddings)
|
||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
query_embeds = np.array(self.embeddings.embed_query(query))
|
||||
# calc L2 norm
|
||||
index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
|
||||
@@ -73,5 +83,11 @@ class KNNRetriever(BaseRetriever, BaseModel):
|
||||
]
|
||||
return top_k_results
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -11,7 +15,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
||||
index: Any
|
||||
query_kwargs: Dict = Field(default_factory=dict)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query."""
|
||||
try:
|
||||
from llama_index.indices.base import BaseGPTIndex
|
||||
@@ -33,7 +43,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||
|
||||
|
||||
@@ -43,7 +59,13 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||
graph: Any
|
||||
query_configs: List[Dict] = Field(default_factory=list)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query."""
|
||||
try:
|
||||
from llama_index.composability.graph import (
|
||||
@@ -73,5 +95,11 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
||||
|
@@ -1,5 +1,9 @@
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -24,7 +28,12 @@ class MergerRetriever(BaseRetriever):
|
||||
|
||||
self.retrievers = retrievers
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Get the relevant documents for a given query.
|
||||
|
||||
@@ -36,11 +45,16 @@ class MergerRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
# Merge the results of the retrievers.
|
||||
merged_documents = self.merge_documents(query)
|
||||
merged_documents = self.merge_documents(query, run_manager)
|
||||
|
||||
return merged_documents
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously get the relevant documents for a given query.
|
||||
|
||||
@@ -52,11 +66,13 @@ class MergerRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
# Merge the results of the retrievers.
|
||||
merged_documents = await self.amerge_documents(query)
|
||||
merged_documents = await self.amerge_documents(query, run_manager)
|
||||
|
||||
return merged_documents
|
||||
|
||||
def merge_documents(self, query: str) -> List[Document]:
|
||||
def merge_documents(
|
||||
self, query: str, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Merge the results of the retrievers.
|
||||
|
||||
@@ -69,7 +85,10 @@ class MergerRetriever(BaseRetriever):
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
retriever.get_relevant_documents(query) for retriever in self.retrievers
|
||||
retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
|
||||
# Merge the results of the retrievers.
|
||||
@@ -82,7 +101,9 @@ class MergerRetriever(BaseRetriever):
|
||||
|
||||
return merged_documents
|
||||
|
||||
async def amerge_documents(self, query: str) -> List[Document]:
|
||||
async def amerge_documents(
|
||||
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously merge the results of the retrievers.
|
||||
|
||||
@@ -95,8 +116,10 @@ class MergerRetriever(BaseRetriever):
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
await retriever.aget_relevant_documents(query)
|
||||
for retriever in self.retrievers
|
||||
await retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
|
||||
# Merge the results of the retrievers.
|
||||
|
@@ -1,5 +1,9 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -17,7 +21,13 @@ class MetalRetriever(BaseRetriever):
|
||||
self.client: Metal = client
|
||||
self.params = params or {}
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
results = self.client.search({"text": query}, **self.params)
|
||||
final_results = []
|
||||
for r in results["data"]:
|
||||
@@ -25,5 +35,11 @@ class MetalRetriever(BaseRetriever):
|
||||
final_results.append(Document(page_content=r["text"], metadata=metadata))
|
||||
return final_results
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -2,6 +2,10 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.milvus import Milvus
|
||||
@@ -39,10 +43,24 @@ class MilvusRetriever(BaseRetriever):
|
||||
"""
|
||||
self.store.add_texts(texts, metadatas)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(query)
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(
|
||||
query, run_manager=run_manager.get_child(), **kwargs
|
||||
)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@@ -1,8 +1,12 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.output_parsers.pydantic import PydanticOutputParser
|
||||
@@ -91,7 +95,12 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
parser_key=parser_key,
|
||||
)
|
||||
|
||||
def get_relevant_documents(self, question: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get relevated documents given a user query.
|
||||
|
||||
Args:
|
||||
@@ -100,15 +109,22 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
Unique union of relevant documents from all generated queries
|
||||
"""
|
||||
queries = self.generate_queries(question)
|
||||
documents = self.retrieve_documents(queries)
|
||||
queries = self.generate_queries(query, run_manager)
|
||||
documents = self.retrieve_documents(queries, run_manager)
|
||||
unique_documents = self.unique_union(documents)
|
||||
return unique_documents
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_queries(self, question: str) -> List[str]:
|
||||
def generate_queries(
|
||||
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[str]:
|
||||
"""Generate queries based upon user input.
|
||||
|
||||
Args:
|
||||
@@ -117,13 +133,17 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List of LLM generated queries that are similar to the user input
|
||||
"""
|
||||
response = self.llm_chain({"question": question})
|
||||
response = self.llm_chain(
|
||||
{"question": question}, callbacks=run_manager.get_child()
|
||||
)
|
||||
lines = getattr(response["text"], self.parser_key, [])
|
||||
if self.verbose:
|
||||
logger.info(f"Generated queries: {lines}")
|
||||
return lines
|
||||
|
||||
def retrieve_documents(self, queries: List[str]) -> List[Document]:
|
||||
def retrieve_documents(
|
||||
self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Run all LLM generated queries.
|
||||
|
||||
Args:
|
||||
@@ -134,7 +154,9 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
"""
|
||||
documents = []
|
||||
for query in queries:
|
||||
docs = self.retriever.get_relevant_documents(query)
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child()
|
||||
)
|
||||
documents.extend(docs)
|
||||
return documents
|
||||
|
||||
|
@@ -1,9 +1,14 @@
|
||||
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@@ -137,7 +142,13 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from pinecone_text.hybrid import hybrid_convex_scale
|
||||
|
||||
sparse_vec = self.sparse_encoder.encode_queries(query)
|
||||
@@ -162,5 +173,11 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
||||
# return search results as json
|
||||
return final_result
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
35
langchain/retrievers/pubmed.py
Normal file
35
langchain/retrievers/pubmed.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""A retriever that uses PubMed API to retrieve documents."""
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.pupmed import PubMedAPIWrapper
|
||||
|
||||
|
||||
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for PubMedAPIWrapper.
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all PubMedAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return self.load_docs(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
@@ -1,18 +1,5 @@
|
||||
from typing import List
|
||||
from langchain.retrievers.pubmed import PubMedRetriever
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.pupmed import PubMedAPIWrapper
|
||||
|
||||
|
||||
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for PubMedAPIWrapper.
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all PubMedAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
return self.load_docs(query=query)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
__all__ = [
|
||||
"PubMedRetriever",
|
||||
]
|
||||
|
@@ -1,9 +1,13 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -15,7 +19,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
||||
page_content_key: str = "page_content"
|
||||
metadata_key: str = "metadata"
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
response = requests.post(
|
||||
self.url, json={self.input_key: query}, headers=self.headers
|
||||
)
|
||||
@@ -27,7 +37,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
||||
for r in result[self.response_key]
|
||||
]
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
"POST", self.url, headers=self.headers, json={self.input_key: query}
|
||||
|
@@ -1,11 +1,15 @@
|
||||
"""Retriever that generates and executes structured queries over its own data source."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
||||
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
@@ -79,8 +83,12 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, callbacks: Callbacks = None
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
@@ -93,7 +101,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
||||
structured_query = cast(
|
||||
StructuredQuery,
|
||||
self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs),
|
||||
self.llm_chain.predict_and_parse(
|
||||
callbacks=run_manager.get_child(), **inputs
|
||||
),
|
||||
)
|
||||
if self.verbose:
|
||||
print(structured_query)
|
||||
@@ -110,7 +120,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun],
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@@ -10,6 +10,10 @@ from typing import Any, List, Optional
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@@ -50,7 +54,13 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
||||
index = create_index(texts, embeddings)
|
||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from sklearn import svm
|
||||
|
||||
query_embeds = np.array(self.embeddings.embed_query(query))
|
||||
@@ -87,5 +97,11 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
||||
top_k_results.append(Document(page_content=self.texts[row - 1]))
|
||||
return top_k_results
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -2,12 +2,17 @@
|
||||
|
||||
Largely based on
|
||||
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@@ -58,7 +63,13 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
|
||||
)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
query_vec = self.vectorizer.transform(
|
||||
@@ -70,5 +81,11 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
return return_docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,10 +1,15 @@
|
||||
"""Retriever that combines embedding similarity with recency in retrieving values."""
|
||||
|
||||
import datetime
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
@@ -80,7 +85,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
results[buffer_idx] = (doc, relevance)
|
||||
return results
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return documents that are relevant to the query."""
|
||||
current_time = datetime.datetime.now()
|
||||
docs_and_scores = {
|
||||
@@ -103,7 +114,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
result.append(buffered_doc)
|
||||
return result
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return documents that are relevant to the query."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -1,9 +1,14 @@
|
||||
"""Wrapper for retrieving documents from Vespa."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -59,12 +64,24 @@ class VespaRetriever(BaseRetriever):
|
||||
docs.append(Document(page_content=page_content, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
body = self._query_body.copy()
|
||||
body["query"] = query
|
||||
return self._query(body)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_relevant_documents_with_filter(
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -6,6 +7,10 @@ from uuid import uuid4
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
@@ -82,8 +87,13 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
ids.append(_id)
|
||||
return ids
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, where_filter: Optional[Dict[str, object]] = None
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
where_filter: Optional[Dict[str, object]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Look up similar documents in Weaviate."""
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
@@ -101,7 +111,12 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, where_filter: Optional[Dict[str, object]] = None
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
where_filter: Optional[Dict[str, object]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,5 +1,9 @@
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
||||
|
||||
@@ -11,8 +15,20 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
|
||||
It uses all WikipediaAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return self.load(query=query)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -54,8 +58,13 @@ class ZepRetriever(BaseRetriever):
|
||||
if r.message
|
||||
]
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, metadata: Optional[Dict] = None
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
@@ -69,8 +78,13 @@ class ZepRetriever(BaseRetriever):
|
||||
|
||||
return self._search_result_to_doc(results)
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, metadata: Optional[Dict] = None
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
|
@@ -2,6 +2,10 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.zilliz import Zilliz
|
||||
@@ -39,10 +43,24 @@ class ZillizRetriever(BaseRetriever):
|
||||
"""
|
||||
self.store.add_texts(texts, metadatas)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(query)
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(
|
||||
query, run_manager=run_manager.get_child(), **kwargs
|
||||
)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@@ -1,9 +1,12 @@
|
||||
"""Common schema objects."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
@@ -20,6 +23,13 @@ from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
Callbacks,
|
||||
)
|
||||
|
||||
RUN_KEY = "__run"
|
||||
|
||||
|
||||
@@ -360,30 +370,151 @@ class Document(Serializable):
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
"""Base interface for retrievers."""
|
||||
"""Base interface for a retriever."""
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Version upgrade for old retrievers that implemented the public
|
||||
# methods directly.
|
||||
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||
" instead of `get_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
swap = cls.get_relevant_documents
|
||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
||||
BaseRetriever.get_relevant_documents
|
||||
)
|
||||
cls._get_relevant_documents = swap # type: ignore[assignment]
|
||||
if (
|
||||
hasattr(cls, "aget_relevant_documents")
|
||||
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
|
||||
):
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||
" instead of `aget_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
aswap = cls.aget_relevant_documents
|
||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
||||
BaseRetriever.aget_relevant_documents
|
||||
)
|
||||
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
||||
parameters = signature(cls._get_relevant_documents).parameters
|
||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||
# If a V1 retriever broke the interface and expects additional arguments
|
||||
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
|
||||
|
||||
@abstractmethod
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
query,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = self._get_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = self._get_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = self._get_relevant_documents(query) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
query,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = await self._aget_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, # type: ignore[call-arg]
|
||||
)
|
||||
except Exception as e:
|
||||
await run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
|
||||
|
@@ -20,6 +20,10 @@ from typing import (
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever
|
||||
@@ -490,7 +494,12 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.vector_search(query, k=self.k)
|
||||
elif self.search_type == "hybrid":
|
||||
@@ -501,7 +510,12 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError(
|
||||
"AzureSearchVectorStoreRetriever does not support async"
|
||||
)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Interface for vector stores."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -20,6 +21,10 @@ from typing import (
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever
|
||||
@@ -402,7 +407,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
@@ -420,7 +431,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Wrapper around Redis vector database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
@@ -21,6 +22,10 @@ from typing import (
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -614,7 +619,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
elif self.search_type == "similarity_limit":
|
||||
@@ -625,7 +636,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
|
@@ -1,21 +1,17 @@
|
||||
"""Wrapper around SingleStore DB."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
from typing import Any, ClassVar, Collection, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
|
||||
@@ -454,14 +450,26 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
|
||||
k: int = 4
|
||||
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError(
|
||||
"SingleStoreDBVectorStoreRetriever does not support async"
|
||||
)
|
||||
|
544
poetry.lock
generated
544
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
ignore_agent_: bool = False
|
||||
ignore_retriever_: bool = False
|
||||
ignore_chat_model_: bool = False
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
@@ -32,6 +33,9 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
agent_actions: int = 0
|
||||
agent_ends: int = 0
|
||||
chat_model_starts: int = 0
|
||||
retriever_starts: int = 0
|
||||
retriever_ends: int = 0
|
||||
retriever_errors: int = 0
|
||||
|
||||
|
||||
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
@@ -52,7 +56,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_chain_start_common(self) -> None:
|
||||
print("CHAIN START")
|
||||
("CHAIN START")
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
@@ -91,6 +95,18 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
def on_text_common(self) -> None:
|
||||
self.text += 1
|
||||
|
||||
def on_retriever_start_common(self) -> None:
|
||||
self.starts += 1
|
||||
self.retriever_starts += 1
|
||||
|
||||
def on_retriever_end_common(self) -> None:
|
||||
self.ends += 1
|
||||
self.retriever_ends += 1
|
||||
|
||||
def on_retriever_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
self.retriever_errors += 1
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake callback handler for testing."""
|
||||
@@ -110,6 +126,11 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -201,6 +222,27 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_text_common()
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retriever_start_common()
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retriever_end_common()
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retriever_error_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
return self
|
||||
|
||||
|
@@ -141,6 +141,22 @@ def test_ignore_agent() -> None:
|
||||
assert handler2.errors == 1
|
||||
|
||||
|
||||
def test_ignore_retriever() -> None:
|
||||
"""Test the ignore retriever param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_retriever_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
run_manager = manager.on_retriever_start("")
|
||||
run_manager.on_retriever_end([])
|
||||
run_manager.on_retriever_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
assert handler2.starts == 1
|
||||
assert handler2.ends == 1
|
||||
assert handler2.errors == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
|
220
tests/unit_tests/retrievers/test_base.py
Normal file
220
tests/unit_tests/retrievers/test_base.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Test Base Retriever logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v1() -> BaseRetriever:
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Retrievers must implement abstract "
|
||||
"`_get_relevant_documents` method instead of `get_relevant_documents`",
|
||||
):
|
||||
|
||||
class FakeRetrieverV1(BaseRetriever):
|
||||
def get_relevant_documents( # type: ignore[override]
|
||||
self,
|
||||
query: str,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(page_content=query, metadata={"uuid": "1234"}),
|
||||
]
|
||||
|
||||
async def aget_relevant_documents( # type: ignore[override]
|
||||
self,
|
||||
query: str,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Async query {query}", metadata={"uuid": "1234"}
|
||||
),
|
||||
]
|
||||
|
||||
return FakeRetrieverV1() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v1_upgrade(fake_retriever_v1: BaseRetriever) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1._new_arg_supported is False
|
||||
assert fake_retriever_v1._expects_other_args is False
|
||||
results: List[Document] = fake_retriever_v1.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
assert results[0].page_content == "Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v1_upgrade_async(
|
||||
fake_retriever_v1: BaseRetriever,
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1._new_arg_supported is False
|
||||
assert fake_retriever_v1._expects_other_args is False
|
||||
results: List[Document] = await fake_retriever_v1.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
assert results[0].page_content == "Async query Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v1_with_kwargs() -> BaseRetriever:
|
||||
# Test for things like the Weaviate V1 Retriever.
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Retrievers must implement abstract "
|
||||
"`_get_relevant_documents` method instead of `get_relevant_documents`",
|
||||
):
|
||||
|
||||
class FakeRetrieverV1(BaseRetriever):
|
||||
def get_relevant_documents( # type: ignore[override]
|
||||
self, query: str, where_filter: Optional[Dict[str, object]] = None
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(page_content=query, metadata=where_filter or {}),
|
||||
]
|
||||
|
||||
async def aget_relevant_documents( # type: ignore[override]
|
||||
self, query: str, where_filter: Optional[Dict[str, object]] = None
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Async query {query}", metadata=where_filter or {}
|
||||
),
|
||||
]
|
||||
|
||||
return FakeRetrieverV1() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v1_with_kwargs_upgrade(
|
||||
fake_retriever_v1_with_kwargs: BaseRetriever,
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1_with_kwargs._new_arg_supported is False
|
||||
assert fake_retriever_v1_with_kwargs._expects_other_args is True
|
||||
results: List[Document] = fake_retriever_v1_with_kwargs.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], where_filter={"foo": "bar"}
|
||||
)
|
||||
assert results[0].page_content == "Foo"
|
||||
assert results[0].metadata == {"foo": "bar"}
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v1_with_kwargs_upgrade_async(
|
||||
fake_retriever_v1_with_kwargs: BaseRetriever,
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1_with_kwargs._new_arg_supported is False
|
||||
assert fake_retriever_v1_with_kwargs._expects_other_args is True
|
||||
results: List[
|
||||
Document
|
||||
] = await fake_retriever_v1_with_kwargs.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], where_filter={"foo": "bar"}
|
||||
)
|
||||
assert results[0].page_content == "Async query Foo"
|
||||
assert results[0].metadata == {"foo": "bar"}
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v2() -> BaseRetriever:
|
||||
class FakeRetrieverV2(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV2)
|
||||
assert run_manager is not None
|
||||
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
|
||||
if "throw_error" in kwargs:
|
||||
raise ValueError("Test error")
|
||||
return [
|
||||
Document(page_content=query, metadata=kwargs),
|
||||
]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV2)
|
||||
assert run_manager is not None
|
||||
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
|
||||
if "throw_error" in kwargs:
|
||||
raise ValueError("Test error")
|
||||
return [
|
||||
Document(page_content=f"Async query {query}", metadata=kwargs),
|
||||
]
|
||||
|
||||
return FakeRetrieverV2() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v2._new_arg_supported is True
|
||||
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||
assert results[0].page_content == "Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
results2 = fake_retriever_v2.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], foo="bar"
|
||||
)
|
||||
assert results2[0].metadata == {"foo": "bar"}
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
fake_retriever_v2.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], throw_error=True
|
||||
)
|
||||
assert callbacks.retriever_errors == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v2._new_arg_supported is True
|
||||
results = await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
assert results[0].page_content == "Async query Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
results2 = await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], foo="bar"
|
||||
)
|
||||
assert results2[0].metadata == {"foo": "bar"}
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], throw_error=True
|
||||
)
|
Reference in New Issue
Block a user