langchain: default to Runnable in MultiQueryRetriever (#21770)

- `llm_chain` becomes `Union[LLMChain, Runnable]`
- `.from_llm` creates a runnable

tested by verifying that docs/how_to/MultiQueryRetriever.ipynb runs
unchanged with sync/async invoke (and that it runs if we specifically
instantiate with LLMChain).
This commit is contained in:
ccurme
2024-05-21 10:01:05 -07:00
committed by GitHub
parent 8e1aeb8ad5
commit 0923136851
2 changed files with 19 additions and 22 deletions

View File

@@ -138,20 +138,10 @@
"execution_count": 5,
"id": "d9afb0ca",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/.pyenv/versions/3.10.4/envs/sandbox310/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `LLMChain` was deprecated in LangChain 0.1.17 and will be removed in 0.3.0. Use RunnableSequence, e.g., `prompt | llm` instead.\n",
" warn_deprecated(\n"
]
}
],
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"from langchain.chains import LLMChain\n",
"from langchain_core.output_parsers import BaseOutputParser\n",
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
@@ -180,7 +170,7 @@
"llm = ChatOpenAI(temperature=0)\n",
"\n",
"# Chain\n",
"llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT, output_parser=output_parser)\n",
"llm_chain = QUERY_PROMPT | llm | output_parser\n",
"\n",
"# Other inputs\n",
"question = \"What are the approaches to Task Decomposition?\""
@@ -189,14 +179,14 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "2eca2d96-8057-4ed9-873d-fa1064c09acf",
"id": "59c75c56-dbd7-4887-b9ba-0b5b21069f51",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:langchain.retrievers.multi_query:Generated queries: ['1. Can you provide insights on regression from the course material?', '2. How is regression discussed in the course content?', '3. What information does the course offer about regression analysis?', '4. What are the teachings of the course regarding regression?', '5. In what manner is regression covered in the course curriculum?']\n"
"INFO:langchain.retrievers.multi_query:Generated queries: ['1. Can you provide insights on regression from the course material?', '2. How is regression discussed in the course content?', '3. What information does the course offer about regression?', '4. In what way is regression covered in the course?', '5. What are the teachings of the course regarding regression?']\n"
]
},
{

View File

@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain.chains.llm import LLMChain
@@ -49,7 +50,7 @@ class MultiQueryRetriever(BaseRetriever):
"""
retriever: BaseRetriever
llm_chain: LLMChain
llm_chain: Runnable
verbose: bool = True
parser_key: str = "lines"
"""DEPRECATED. parser_key is no longer used and should not be specified."""
@@ -77,7 +78,7 @@ class MultiQueryRetriever(BaseRetriever):
MultiQueryRetriever
"""
output_parser = LineListOutputParser()
llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
llm_chain = prompt | llm | output_parser
return cls(
retriever=retriever,
llm_chain=llm_chain,
@@ -115,10 +116,13 @@ class MultiQueryRetriever(BaseRetriever):
Returns:
List of LLM generated queries that are similar to the user input
"""
response = await self.llm_chain.acall(
inputs={"question": question}, callbacks=run_manager.get_child()
response = await self.llm_chain.ainvoke(
{"question": question}, config={"callbacks": run_manager.get_child()}
)
lines = response["text"]
if isinstance(self.llm_chain, LLMChain):
lines = response["text"]
else:
lines = response
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines
@@ -175,10 +179,13 @@ class MultiQueryRetriever(BaseRetriever):
Returns:
List of LLM generated queries that are similar to the user input
"""
response = self.llm_chain(
{"question": question}, callbacks=run_manager.get_child()
response = self.llm_chain.invoke(
{"question": question}, config={"callbacks": run_manager.get_child()}
)
lines = response["text"]
if isinstance(self.llm_chain, LLMChain):
lines = response["text"]
else:
lines = response
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines