From 3925071dd61363db0d483412c9e64bc5a7c60307 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 12 Feb 2024 22:52:07 -0800 Subject: [PATCH] =?UTF-8?q?langchain[patch],=20templates[patch]:=20fix=20m?= =?UTF-8?q?ulti=20query=20retriever,=20web=20re=E2=80=A6=20(#17434)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …search retriever Fixes #17352 --- .../langchain/retrievers/multi_query.py | 33 +++++++------------ .../langchain/retrievers/web_research.py | 19 +++-------- .../retrievers/test_web_research.py | 2 +- .../rag_ollama_multi_query/chain.py | 30 ++--------------- 4 files changed, 20 insertions(+), 64 deletions(-) diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 7d60b221514..ca7e731c51a 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -1,39 +1,28 @@ import asyncio import logging -from typing import List, Sequence +from typing import List, Optional, Sequence from langchain_core.callbacks import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.language_models import BaseLLM +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.retrievers import BaseRetriever from langchain.chains.llm import LLMChain -from langchain.output_parsers.pydantic import PydanticOutputParser logger = logging.getLogger(__name__) -class LineList(BaseModel): - """List of lines.""" - - lines: List[str] = Field(description="Lines of text") - """List of lines.""" - - -class LineListOutputParser(PydanticOutputParser): +class LineListOutputParser(BaseOutputParser[List[str]]): """Output parser for a list of lines.""" - def __init__(self) -> None: - super().__init__(pydantic_object=LineList) - - def parse(self, text: str) -> LineList: + def parse(self, text: str) -> List[str]: lines = text.strip().split("\n") - return LineList(lines=lines) + return lines # Default prompt @@ -63,6 +52,7 @@ class MultiQueryRetriever(BaseRetriever): llm_chain: LLMChain verbose: bool = True parser_key: str = "lines" + """DEPRECATED. parser_key is no longer used and should not be specified.""" include_original: bool = False """Whether to include the original query in the list of generated queries.""" @@ -70,9 +60,9 @@ class MultiQueryRetriever(BaseRetriever): def from_llm( cls, retriever: BaseRetriever, - llm: BaseLLM, + llm: BaseLanguageModel, prompt: PromptTemplate = DEFAULT_QUERY_PROMPT, - parser_key: str = "lines", + parser_key: Optional[str] = None, include_original: bool = False, ) -> "MultiQueryRetriever": """Initialize from llm using default template. @@ -91,7 +81,6 @@ class MultiQueryRetriever(BaseRetriever): return cls( retriever=retriever, llm_chain=llm_chain, - parser_key=parser_key, include_original=include_original, ) @@ -129,7 +118,7 @@ class MultiQueryRetriever(BaseRetriever): response = await self.llm_chain.acall( inputs={"question": question}, callbacks=run_manager.get_child() ) - lines = getattr(response["text"], self.parser_key, []) + lines = response["text"] if self.verbose: logger.info(f"Generated queries: {lines}") return lines @@ -189,7 +178,7 @@ class MultiQueryRetriever(BaseRetriever): response = self.llm_chain( {"question": question}, callbacks=run_manager.get_child() ) - lines = getattr(response["text"], self.parser_key, []) + lines = response["text"] if self.verbose: logger.info(f"Generated queries: {lines}") return lines diff --git a/libs/langchain/langchain/retrievers/web_research.py b/libs/langchain/langchain/retrievers/web_research.py index 149ef247af7..b992490f617 100644 --- a/libs/langchain/langchain/retrievers/web_research.py +++ b/libs/langchain/langchain/retrievers/web_research.py @@ -12,6 +12,7 @@ from langchain_core.callbacks import ( ) from langchain_core.documents import Document from langchain_core.language_models import BaseLLM +from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.retrievers import BaseRetriever @@ -19,7 +20,6 @@ from langchain_core.vectorstores import VectorStore from langchain.chains import LLMChain from langchain.chains.prompt_selector import ConditionalPromptSelector -from langchain.output_parsers.pydantic import PydanticOutputParser from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter logger = logging.getLogger(__name__) @@ -50,21 +50,12 @@ should have a question mark at the end: {question}""", ) -class LineList(BaseModel): - """List of questions.""" - - lines: List[str] = Field(description="Questions") - - -class QuestionListOutputParser(PydanticOutputParser): +class QuestionListOutputParser(BaseOutputParser[List[str]]): """Output parser for a list of numbered questions.""" - def __init__(self) -> None: - super().__init__(pydantic_object=LineList) - - def parse(self, text: str) -> LineList: + def parse(self, text: str) -> List[str]: lines = re.findall(r"\d+\..*?(?:\n|$)", text) - return LineList(lines=lines) + return lines class WebResearchRetriever(BaseRetriever): @@ -176,7 +167,7 @@ class WebResearchRetriever(BaseRetriever): logger.info("Generating questions for Google Search ...") result = self.llm_chain({"question": query}) logger.info(f"Questions for Google Search (raw): {result}") - questions = getattr(result["text"], "lines", []) + questions = result["text"] logger.info(f"Questions for Google Search: {questions}") # Get urls diff --git a/libs/langchain/tests/unit_tests/retrievers/test_web_research.py b/libs/langchain/tests/unit_tests/retrievers/test_web_research.py index a052e59b722..29878dd3b4b 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_web_research.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_web_research.py @@ -33,4 +33,4 @@ from langchain.retrievers.web_research import QuestionListOutputParser def test_list_output_parser(text: str, expected: List[str]) -> None: parser = QuestionListOutputParser() result = parser.parse(text) - assert result.lines == expected + assert result == expected diff --git a/templates/rag-ollama-multi-query/rag_ollama_multi_query/chain.py b/templates/rag-ollama-multi-query/rag_ollama_multi_query/chain.py index dea1ea5fe53..1bc81584532 100644 --- a/templates/rag-ollama-multi-query/rag_ollama_multi_query/chain.py +++ b/templates/rag-ollama-multi-query/rag_ollama_multi_query/chain.py @@ -1,7 +1,3 @@ -from typing import List - -from langchain.chains import LLMChain -from langchain.output_parsers import PydanticOutputParser from langchain.retrievers.multi_query import MultiQueryRetriever from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.chat_models import ChatOllama, ChatOpenAI @@ -10,7 +6,7 @@ from langchain_community.embeddings import OpenAIEmbeddings from langchain_community.vectorstores import Chroma from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import RunnableParallel, RunnablePassthrough # Load @@ -29,23 +25,6 @@ vectorstore = Chroma.from_documents( ) -# Output parser will split the LLM result into a list of queries -class LineList(BaseModel): - # "lines" is the key (attribute name) of the parsed output - lines: List[str] = Field(description="Lines of text") - - -class LineListOutputParser(PydanticOutputParser): - def __init__(self) -> None: - super().__init__(pydantic_object=LineList) - - def parse(self, text: str) -> LineList: - lines = text.strip().split("\n") - return LineList(lines=lines) - - -output_parser = LineListOutputParser() - QUERY_PROMPT = PromptTemplate( input_variables=["question"], template="""You are an AI language model assistant. Your task is to generate five @@ -60,12 +39,9 @@ QUERY_PROMPT = PromptTemplate( ollama_llm = "zephyr" llm = ChatOllama(model=ollama_llm) -# Chain -llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT, output_parser=output_parser) - # Run -retriever = MultiQueryRetriever( - retriever=vectorstore.as_retriever(), llm_chain=llm_chain, parser_key="lines" +retriever = MultiQueryRetriever.from_llm( + vectorstore.as_retriever(), llm, prompt=QUERY_PROMPT ) # "lines" is the key (attribute name) of the parsed output # RAG prompt