From 0c2f7d8da18d20bb670eee6dfb935c161cf06e36 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 5 Jan 2023 09:33:59 -0800 Subject: [PATCH] changes to qa chain (#543) --- langchain/chains/qa_with_sources/__init__.py | 3 ++- langchain/chains/qa_with_sources/base.py | 16 ++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 56bb9bcd2e2..b6e7b51f3a4 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -25,6 +25,7 @@ class LoadingCallable(Protocol): def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, + document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT, document_variable_name: str = "summaries", verbose: Optional[bool] = None, **kwargs: Any, @@ -33,7 +34,7 @@ def _load_stuff_chain( return StuffDocumentsChain( llm_chain=llm_chain, document_variable_name=document_variable_name, - document_prompt=stuff_prompt.EXAMPLE_PROMPT, + document_prompt=document_prompt, verbose=verbose, **kwargs, ) diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 9eabb9f56e3..96b02d07201 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -81,18 +81,6 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): """ return [self.answer_key, self.sources_answer_key] - @root_validator(pre=True) - def validate_question_chain(cls, values: Dict) -> Dict: - """Validate question chain.""" - llm_question_chain = values["combine_document_chain"].llm_chain - if len(llm_question_chain.input_keys) != 2: - raise ValueError( - f"The llm_question_chain should have two inputs: a content key " - f"(the first one) and a question key (the second one). Got " - f"{llm_question_chain.input_keys}." - ) - return values - @root_validator() def validate_combine_chain_can_be_constructed(cls, values: Dict) -> Dict: """Validate that the combine chain can be constructed.""" @@ -107,8 +95,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: docs = self._get_docs(inputs) answer, _ = self.combine_document_chain.combine_docs(docs, **inputs) - if "\nSOURCES: " in answer: - answer, sources = answer.split("\nSOURCES: ") + if "SOURCES: " in answer: + answer, sources = answer.split("SOURCES: ") else: sources = "" return {self.answer_key: answer, self.sources_answer_key: sources}