mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 18:38:48 +00:00
langchain: default to Runnable in MultiQueryRetriever (#21770)
- `llm_chain` becomes `Union[LLMChain, Runnable]`. - `.from_llm` creates a Runnable. Tested by verifying that docs/how_to/MultiQueryRetriever.ipynb runs unchanged with sync/async invoke (and that it still runs if we explicitly instantiate with LLMChain).
This commit is contained in:
@@ -138,20 +138,10 @@
|
||||
"execution_count": 5,
|
||||
"id": "d9afb0ca",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/chestercurme/.pyenv/versions/3.10.4/envs/sandbox310/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `LLMChain` was deprecated in LangChain 0.1.17 and will be removed in 0.3.0. Use RunnableSequence, e.g., `prompt | llm` instead.\n",
|
||||
" warn_deprecated(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain_core.output_parsers import BaseOutputParser\n",
|
||||
"from langchain_core.prompts import PromptTemplate\n",
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
|
||||
@@ -180,7 +170,7 @@
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"\n",
|
||||
"# Chain\n",
|
||||
"llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT, output_parser=output_parser)\n",
|
||||
"llm_chain = QUERY_PROMPT | llm | output_parser\n",
|
||||
"\n",
|
||||
"# Other inputs\n",
|
||||
"question = \"What are the approaches to Task Decomposition?\""
|
||||
@@ -189,14 +179,14 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "2eca2d96-8057-4ed9-873d-fa1064c09acf",
|
||||
"id": "59c75c56-dbd7-4887-b9ba-0b5b21069f51",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:langchain.retrievers.multi_query:Generated queries: ['1. Can you provide insights on regression from the course material?', '2. How is regression discussed in the course content?', '3. What information does the course offer about regression analysis?', '4. What are the teachings of the course regarding regression?', '5. In what manner is regression covered in the course curriculum?']\n"
|
||||
"INFO:langchain.retrievers.multi_query:Generated queries: ['1. Can you provide insights on regression from the course material?', '2. How is regression discussed in the course content?', '3. What information does the course offer about regression?', '4. In what way is regression covered in the course?', '5. What are the teachings of the course regarding regression?']\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
@@ -49,7 +50,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
retriever: BaseRetriever
|
||||
llm_chain: LLMChain
|
||||
llm_chain: Runnable
|
||||
verbose: bool = True
|
||||
parser_key: str = "lines"
|
||||
"""DEPRECATED. parser_key is no longer used and should not be specified."""
|
||||
@@ -77,7 +78,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
MultiQueryRetriever
|
||||
"""
|
||||
output_parser = LineListOutputParser()
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
|
||||
llm_chain = prompt | llm | output_parser
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
llm_chain=llm_chain,
|
||||
@@ -115,10 +116,13 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List of LLM generated queries that are similar to the user input
|
||||
"""
|
||||
response = await self.llm_chain.acall(
|
||||
inputs={"question": question}, callbacks=run_manager.get_child()
|
||||
response = await self.llm_chain.ainvoke(
|
||||
{"question": question}, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
lines = response["text"]
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
lines = response["text"]
|
||||
else:
|
||||
lines = response
|
||||
if self.verbose:
|
||||
logger.info(f"Generated queries: {lines}")
|
||||
return lines
|
||||
@@ -175,10 +179,13 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List of LLM generated queries that are similar to the user input
|
||||
"""
|
||||
response = self.llm_chain(
|
||||
{"question": question}, callbacks=run_manager.get_child()
|
||||
response = self.llm_chain.invoke(
|
||||
{"question": question}, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
lines = response["text"]
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
lines = response["text"]
|
||||
else:
|
||||
lines = response
|
||||
if self.verbose:
|
||||
logger.info(f"Generated queries: {lines}")
|
||||
return lines
|
||||
|
Reference in New Issue
Block a user