patch: deprecate (a)get_relevant_documents (#20477)

- `.get_relevant_documents(query)` -> `.invoke(query)`
- `.get_relevant_documents(query=query)` -> `.invoke(query)`
- `.get_relevant_documents(query, callbacks=callbacks)` ->
`.invoke(query, config={"callbacks": callbacks})`
- `.get_relevant_documents(query, **kwargs)` -> `.invoke(query,
**kwargs)`

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
ccurme
2024-04-22 11:14:53 -04:00
committed by GitHub
parent 939d113d10
commit c010ec8b71
171 changed files with 443 additions and 535 deletions

View File

@@ -314,8 +314,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
docs = self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
)
return self._reduce_tokens_below_limit(docs)
@@ -327,8 +327,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
docs = await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
)
return self._reduce_tokens_below_limit(docs)

View File

@@ -155,7 +155,7 @@ class FlareChain(Chain):
callbacks = _run_manager.get_child()
docs = []
for question in questions:
docs.extend(self.retriever.get_relevant_documents(question))
docs.extend(self.retriever.invoke(question))
context = "\n\n".join(d.page_content for d in docs)
result = self.response_chain.predict(
user_input=user_input,

View File

@@ -46,8 +46,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
question = inputs[self.question_key]
docs = self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
)
return self._reduce_tokens_below_limit(docs)
@@ -55,8 +55,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
question = inputs[self.question_key]
docs = await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
)
return self._reduce_tokens_below_limit(docs)

View File

@@ -218,8 +218,8 @@ class RetrievalQA(BaseRetrievalQA):
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
return self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
)
async def _aget_docs(
@@ -229,8 +229,8 @@ class RetrievalQA(BaseRetrievalQA):
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
return await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
)
@property

View File

@@ -55,7 +55,7 @@ class VectorStoreRetrieverMemory(BaseMemory):
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)
docs = self.retriever.invoke(query)
return self._documents_to_memory_variables(docs)
async def aload_memory_variables(
@@ -64,7 +64,7 @@ class VectorStoreRetrieverMemory(BaseMemory):
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = await self.retriever.aget_relevant_documents(query)
docs = await self.retriever.ainvoke(query)
return self._documents_to_memory_variables(docs)
def _form_documents(

View File

@@ -41,8 +41,8 @@ class ContextualCompressionRetriever(BaseRetriever):
Returns:
Sequence of relevant documents
"""
docs = self.base_retriever.get_relevant_documents(
query, callbacks=run_manager.get_child(), **kwargs
docs = self.base_retriever.invoke(
query, config={"callbacks": run_manager.get_child()}, **kwargs
)
if docs:
compressed_docs = self.base_compressor.compress_documents(
@@ -67,8 +67,8 @@ class ContextualCompressionRetriever(BaseRetriever):
Returns:
List of relevant documents
"""
docs = await self.base_retriever.aget_relevant_documents(
query, callbacks=run_manager.get_child(), **kwargs
docs = await self.base_retriever.ainvoke(
query, config={"callbacks": run_manager.get_child()}, **kwargs
)
if docs:
compressed_docs = await self.base_compressor.acompress_documents(

View File

@@ -72,8 +72,11 @@ class MergerRetriever(BaseRetriever):
# Get the results of all retrievers.
retriever_docs = [
retriever.get_relevant_documents(
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
retriever.invoke(
query,
config={
"callbacks": run_manager.get_child("retriever_{}".format(i + 1))
},
)
for i, retriever in enumerate(self.retrievers)
]
@@ -104,8 +107,11 @@ class MergerRetriever(BaseRetriever):
# Get the results of all retrievers.
retriever_docs = await asyncio.gather(
*(
retriever.aget_relevant_documents(
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
retriever.ainvoke(
query,
config={
"callbacks": run_manager.get_child("retriever_{}".format(i + 1))
},
)
for i, retriever in enumerate(self.retrievers)
)

View File

@@ -136,8 +136,8 @@ class MultiQueryRetriever(BaseRetriever):
"""
document_lists = await asyncio.gather(
*(
self.retriever.aget_relevant_documents(
query, callbacks=run_manager.get_child()
self.retriever.ainvoke(
query, config={"callbacks": run_manager.get_child()}
)
for query in queries
)
@@ -196,8 +196,8 @@ class MultiQueryRetriever(BaseRetriever):
"""
documents = []
for query in queries:
docs = self.retriever.get_relevant_documents(
query, callbacks=run_manager.get_child()
docs = self.retriever.invoke(
query, config={"callbacks": run_manager.get_child()}
)
documents.extend(docs)
return documents

View File

@@ -74,8 +74,8 @@ class RePhraseQueryRetriever(BaseRetriever):
response = self.llm_chain(query, callbacks=run_manager.get_child())
re_phrased_question = response["text"]
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.get_relevant_documents(
re_phrased_question, callbacks=run_manager.get_child()
docs = self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}
)
return docs

View File

@@ -21,6 +21,6 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
base_compressor=base_compressor, base_retriever=base_retriever
)
actual = retriever.get_relevant_documents("Tell me about the Celtics")
actual = retriever.invoke("Tell me about the Celtics")
assert len(actual) == 2
assert texts[-1] not in [d.page_content for d in actual]

View File

@@ -27,7 +27,7 @@ def test_merger_retriever_get_relevant_docs() -> None:
# The Lord of the Retrievers.
lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
actual = lotr.get_relevant_documents("Tell me about the Celtics")
actual = lotr.invoke("Tell me about the Celtics")
assert len(actual) == 2
assert texts_group_a[0] in [d.page_content for d in actual]
assert texts_group_b[1] in [d.page_content for d in actual]

View File

@@ -25,7 +25,7 @@ def test_long_context_reorder() -> None:
search_kwargs={"k": 10}
)
reordering = LongContextReorder()
docs = retriever.get_relevant_documents("Tell me about the Celtics")
docs = retriever.invoke("Tell me about the Celtics")
actual = reordering.transform_documents(docs)
# First 2 and Last 2 elements must contain the most relevant

View File

@@ -21,7 +21,7 @@ def test_ensemble_retriever_get_relevant_docs() -> None:
ensemble_retriever = EnsembleRetriever( # type: ignore[call-arg]
retrievers=[dummy_retriever, dummy_retriever]
)
docs = ensemble_retriever.get_relevant_documents("I like apples")
docs = ensemble_retriever.invoke("I like apples")
assert len(docs) == 1
@@ -75,5 +75,5 @@ def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None
retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
weights=[0.6, 0.3, 0.1],
)
docs = ensemble_retriever.get_relevant_documents("I like apples")
docs = ensemble_retriever.invoke("I like apples")
assert len(docs) == 3

View File

@@ -129,11 +129,11 @@ async def test_aget_salient_docs(
assert doc in want
def test_get_relevant_documents(
def test_invoke(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
query = "Test query"
relevant_documents = time_weighted_retriever.get_relevant_documents(query)
relevant_documents = time_weighted_retriever.invoke(query)
want = [(doc, 0.5) for doc in _get_example_memories()]
assert isinstance(relevant_documents, list)
assert len(relevant_documents) == len(want)
@@ -147,11 +147,11 @@ def test_get_relevant_documents(
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now
async def test_aget_relevant_documents(
async def test_ainvoke(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
query = "Test query"
relevant_documents = await time_weighted_retriever.aget_relevant_documents(query)
relevant_documents = await time_weighted_retriever.ainvoke(query)
want = [(doc, 0.5) for doc in _get_example_memories()]
assert isinstance(relevant_documents, list)
assert len(relevant_documents) == len(want)