mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 11:30:37 +00:00
patch: deprecate (a)get_relevant_documents (#20477)
- `.get_relevant_documents(query)` -> `.invoke(query)` - `.get_relevant_documents(query=query)` -> `.invoke(query)` - `.get_relevant_documents(query, callbacks=callbacks)` -> `.invoke(query, config={"callbacks": callbacks})` - `.get_relevant_documents(query, **kwargs)` -> `.invoke(query, **kwargs)` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -314,8 +314,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
docs = self.retriever.invoke(
|
||||
question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@@ -327,8 +327,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
docs = await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
docs = await self.retriever.ainvoke(
|
||||
question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
|
@@ -155,7 +155,7 @@ class FlareChain(Chain):
|
||||
callbacks = _run_manager.get_child()
|
||||
docs = []
|
||||
for question in questions:
|
||||
docs.extend(self.retriever.get_relevant_documents(question))
|
||||
docs.extend(self.retriever.invoke(question))
|
||||
context = "\n\n".join(d.page_content for d in docs)
|
||||
result = self.response_chain.predict(
|
||||
user_input=user_input,
|
||||
|
@@ -46,8 +46,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
docs = self.retriever.invoke(
|
||||
question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@@ -55,8 +55,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
docs = await self.retriever.ainvoke(
|
||||
question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
|
@@ -218,8 +218,8 @@ class RetrievalQA(BaseRetrievalQA):
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
return self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
return self.retriever.invoke(
|
||||
question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
|
||||
async def _aget_docs(
|
||||
@@ -229,8 +229,8 @@ class RetrievalQA(BaseRetrievalQA):
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
return await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
return await self.retriever.ainvoke(
|
||||
question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
|
||||
@property
|
||||
|
@@ -55,7 +55,7 @@ class VectorStoreRetrieverMemory(BaseMemory):
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = self.retriever.get_relevant_documents(query)
|
||||
docs = self.retriever.invoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
async def aload_memory_variables(
|
||||
@@ -64,7 +64,7 @@ class VectorStoreRetrieverMemory(BaseMemory):
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = await self.retriever.aget_relevant_documents(query)
|
||||
docs = await self.retriever.ainvoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
def _form_documents(
|
||||
|
@@ -41,8 +41,8 @@ class ContextualCompressionRetriever(BaseRetriever):
|
||||
Returns:
|
||||
Sequence of relevant documents
|
||||
"""
|
||||
docs = self.base_retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(), **kwargs
|
||||
docs = self.base_retriever.invoke(
|
||||
query, config={"callbacks": run_manager.get_child()}, **kwargs
|
||||
)
|
||||
if docs:
|
||||
compressed_docs = self.base_compressor.compress_documents(
|
||||
@@ -67,8 +67,8 @@ class ContextualCompressionRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
docs = await self.base_retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(), **kwargs
|
||||
docs = await self.base_retriever.ainvoke(
|
||||
query, config={"callbacks": run_manager.get_child()}, **kwargs
|
||||
)
|
||||
if docs:
|
||||
compressed_docs = await self.base_compressor.acompress_documents(
|
||||
|
@@ -72,8 +72,11 @@ class MergerRetriever(BaseRetriever):
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
|
||||
retriever.invoke(
|
||||
query,
|
||||
config={
|
||||
"callbacks": run_manager.get_child("retriever_{}".format(i + 1))
|
||||
},
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
@@ -104,8 +107,11 @@ class MergerRetriever(BaseRetriever):
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = await asyncio.gather(
|
||||
*(
|
||||
retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
|
||||
retriever.ainvoke(
|
||||
query,
|
||||
config={
|
||||
"callbacks": run_manager.get_child("retriever_{}".format(i + 1))
|
||||
},
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
)
|
||||
|
@@ -136,8 +136,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
"""
|
||||
document_lists = await asyncio.gather(
|
||||
*(
|
||||
self.retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child()
|
||||
self.retriever.ainvoke(
|
||||
query, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
for query in queries
|
||||
)
|
||||
@@ -196,8 +196,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
"""
|
||||
documents = []
|
||||
for query in queries:
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child()
|
||||
docs = self.retriever.invoke(
|
||||
query, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
documents.extend(docs)
|
||||
return documents
|
||||
|
@@ -74,8 +74,8 @@ class RePhraseQueryRetriever(BaseRetriever):
|
||||
response = self.llm_chain(query, callbacks=run_manager.get_child())
|
||||
re_phrased_question = response["text"]
|
||||
logger.info(f"Re-phrased question: {re_phrased_question}")
|
||||
docs = self.retriever.get_relevant_documents(
|
||||
re_phrased_question, callbacks=run_manager.get_child()
|
||||
docs = self.retriever.invoke(
|
||||
re_phrased_question, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
return docs
|
||||
|
||||
|
@@ -21,6 +21,6 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
|
||||
base_compressor=base_compressor, base_retriever=base_retriever
|
||||
)
|
||||
|
||||
actual = retriever.get_relevant_documents("Tell me about the Celtics")
|
||||
actual = retriever.invoke("Tell me about the Celtics")
|
||||
assert len(actual) == 2
|
||||
assert texts[-1] not in [d.page_content for d in actual]
|
||||
|
@@ -27,7 +27,7 @@ def test_merger_retriever_get_relevant_docs() -> None:
|
||||
# The Lord of the Retrievers.
|
||||
lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
|
||||
|
||||
actual = lotr.get_relevant_documents("Tell me about the Celtics")
|
||||
actual = lotr.invoke("Tell me about the Celtics")
|
||||
assert len(actual) == 2
|
||||
assert texts_group_a[0] in [d.page_content for d in actual]
|
||||
assert texts_group_b[1] in [d.page_content for d in actual]
|
||||
|
@@ -25,7 +25,7 @@ def test_long_context_reorder() -> None:
|
||||
search_kwargs={"k": 10}
|
||||
)
|
||||
reordering = LongContextReorder()
|
||||
docs = retriever.get_relevant_documents("Tell me about the Celtics")
|
||||
docs = retriever.invoke("Tell me about the Celtics")
|
||||
actual = reordering.transform_documents(docs)
|
||||
|
||||
# First 2 and Last 2 elements must contain the most relevant
|
||||
|
@@ -21,7 +21,7 @@ def test_ensemble_retriever_get_relevant_docs() -> None:
|
||||
ensemble_retriever = EnsembleRetriever( # type: ignore[call-arg]
|
||||
retrievers=[dummy_retriever, dummy_retriever]
|
||||
)
|
||||
docs = ensemble_retriever.get_relevant_documents("I like apples")
|
||||
docs = ensemble_retriever.invoke("I like apples")
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
@@ -75,5 +75,5 @@ def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None
|
||||
retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
|
||||
weights=[0.6, 0.3, 0.1],
|
||||
)
|
||||
docs = ensemble_retriever.get_relevant_documents("I like apples")
|
||||
docs = ensemble_retriever.invoke("I like apples")
|
||||
assert len(docs) == 3
|
||||
|
@@ -129,11 +129,11 @@ async def test_aget_salient_docs(
|
||||
assert doc in want
|
||||
|
||||
|
||||
def test_get_relevant_documents(
|
||||
def test_invoke(
|
||||
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||
) -> None:
|
||||
query = "Test query"
|
||||
relevant_documents = time_weighted_retriever.get_relevant_documents(query)
|
||||
relevant_documents = time_weighted_retriever.invoke(query)
|
||||
want = [(doc, 0.5) for doc in _get_example_memories()]
|
||||
assert isinstance(relevant_documents, list)
|
||||
assert len(relevant_documents) == len(want)
|
||||
@@ -147,11 +147,11 @@ def test_get_relevant_documents(
|
||||
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now
|
||||
|
||||
|
||||
async def test_aget_relevant_documents(
|
||||
async def test_ainvoke(
|
||||
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||
) -> None:
|
||||
query = "Test query"
|
||||
relevant_documents = await time_weighted_retriever.aget_relevant_documents(query)
|
||||
relevant_documents = await time_weighted_retriever.ainvoke(query)
|
||||
want = [(doc, 0.5) for doc in _get_example_memories()]
|
||||
assert isinstance(relevant_documents, list)
|
||||
assert len(relevant_documents) == len(want)
|
||||
|
Reference in New Issue
Block a user