diff --git a/libs/langchain/langchain/tools/retriever.py b/libs/langchain/langchain/tools/retriever.py index 1d2cc7bc272..b999c74b6cb 100644 --- a/libs/langchain/langchain/tools/retriever.py +++ b/libs/langchain/langchain/tools/retriever.py @@ -1,6 +1,9 @@ from functools import partial from typing import Optional +from langchain_core.callbacks.manager import ( + Callbacks, +) from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.retrievers import BaseRetriever @@ -19,8 +22,9 @@ def _get_relevant_documents( retriever: BaseRetriever, document_prompt: BasePromptTemplate, document_separator: str, + callbacks: Callbacks = None, ) -> str: - docs = retriever.get_relevant_documents(query) + docs = retriever.get_relevant_documents(query, callbacks=callbacks) return document_separator.join( format_document(doc, document_prompt) for doc in docs ) @@ -31,8 +35,9 @@ async def _aget_relevant_documents( retriever: BaseRetriever, document_prompt: BasePromptTemplate, document_separator: str, + callbacks: Callbacks = None, ) -> str: - docs = await retriever.aget_relevant_documents(query) + docs = await retriever.aget_relevant_documents(query, callbacks=callbacks) return document_separator.join( format_document(doc, document_prompt) for doc in docs )