mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 14:23:58 +00:00
changes to qa chain (#543)
This commit is contained in:
parent
5b4c972fc5
commit
0c2f7d8da1
@ -25,6 +25,7 @@ class LoadingCallable(Protocol):
|
|||||||
def _load_stuff_chain(
|
def _load_stuff_chain(
|
||||||
llm: BaseLLM,
|
llm: BaseLLM,
|
||||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||||
|
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
|
||||||
document_variable_name: str = "summaries",
|
document_variable_name: str = "summaries",
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -33,7 +34,7 @@ def _load_stuff_chain(
|
|||||||
return StuffDocumentsChain(
|
return StuffDocumentsChain(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
document_variable_name=document_variable_name,
|
document_variable_name=document_variable_name,
|
||||||
document_prompt=stuff_prompt.EXAMPLE_PROMPT,
|
document_prompt=document_prompt,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -81,18 +81,6 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
return [self.answer_key, self.sources_answer_key]
|
return [self.answer_key, self.sources_answer_key]
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def validate_question_chain(cls, values: Dict) -> Dict:
|
|
||||||
"""Validate question chain."""
|
|
||||||
llm_question_chain = values["combine_document_chain"].llm_chain
|
|
||||||
if len(llm_question_chain.input_keys) != 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"The llm_question_chain should have two inputs: a content key "
|
|
||||||
f"(the first one) and a question key (the second one). Got "
|
|
||||||
f"{llm_question_chain.input_keys}."
|
|
||||||
)
|
|
||||||
return values
|
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_combine_chain_can_be_constructed(cls, values: Dict) -> Dict:
|
def validate_combine_chain_can_be_constructed(cls, values: Dict) -> Dict:
|
||||||
"""Validate that the combine chain can be constructed."""
|
"""Validate that the combine chain can be constructed."""
|
||||||
@ -107,8 +95,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
docs = self._get_docs(inputs)
|
docs = self._get_docs(inputs)
|
||||||
answer, _ = self.combine_document_chain.combine_docs(docs, **inputs)
|
answer, _ = self.combine_document_chain.combine_docs(docs, **inputs)
|
||||||
if "\nSOURCES: " in answer:
|
if "SOURCES: " in answer:
|
||||||
answer, sources = answer.split("\nSOURCES: ")
|
answer, sources = answer.split("SOURCES: ")
|
||||||
else:
|
else:
|
||||||
sources = ""
|
sources = ""
|
||||||
return {self.answer_key: answer, self.sources_answer_key: sources}
|
return {self.answer_key: answer, self.sources_answer_key: sources}
|
||||||
|
Loading…
Reference in New Issue
Block a user