mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 22:04:37 +00:00
add return source docs (#1515)
This commit is contained in:
parent
064741db58
commit
8f21605d71
@ -32,6 +32,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
input_docs_key: str = "docs" #: :meta private:
|
input_docs_key: str = "docs" #: :meta private:
|
||||||
answer_key: str = "answer" #: :meta private:
|
answer_key: str = "answer" #: :meta private:
|
||||||
sources_answer_key: str = "sources" #: :meta private:
|
sources_answer_key: str = "sources" #: :meta private:
|
||||||
|
return_source_documents: bool = False
|
||||||
|
"""Return the source documents."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
@ -95,7 +97,10 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
return [self.answer_key, self.sources_answer_key]
|
_output_keys = [self.answer_key, self.sources_answer_key]
|
||||||
|
if self.return_source_documents:
|
||||||
|
_output_keys = _output_keys + ["source_documents"]
|
||||||
|
return _output_keys
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_naming(cls, values: Dict) -> Dict:
|
def validate_naming(cls, values: Dict) -> Dict:
|
||||||
@ -108,14 +113,20 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
"""Get docs to run questioning over."""
|
"""Get docs to run questioning over."""
|
||||||
|
|
||||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
docs = self._get_docs(inputs)
|
docs = self._get_docs(inputs)
|
||||||
answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
|
answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
|
||||||
if "SOURCES: " in answer:
|
if "SOURCES: " in answer:
|
||||||
answer, sources = answer.split("SOURCES: ")
|
answer, sources = answer.split("SOURCES: ")
|
||||||
else:
|
else:
|
||||||
sources = ""
|
sources = ""
|
||||||
return {self.answer_key: answer, self.sources_answer_key: sources}
|
result: Dict[str, Any] = {
|
||||||
|
self.answer_key: answer,
|
||||||
|
self.sources_answer_key: sources,
|
||||||
|
}
|
||||||
|
if self.return_source_documents:
|
||||||
|
result["source_documents"] = docs
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||||
|
Loading…
Reference in New Issue
Block a user