mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
add return source docs (#1515)
This commit is contained in:
parent
064741db58
commit
8f21605d71
@ -32,6 +32,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
input_docs_key: str = "docs" #: :meta private:
|
||||
answer_key: str = "answer" #: :meta private:
|
||||
sources_answer_key: str = "sources" #: :meta private:
|
||||
return_source_documents: bool = False
|
||||
"""Return the source documents."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -95,7 +97,10 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.answer_key, self.sources_answer_key]
|
||||
_output_keys = [self.answer_key, self.sources_answer_key]
|
||||
if self.return_source_documents:
|
||||
_output_keys = _output_keys + ["source_documents"]
|
||||
return _output_keys
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_naming(cls, values: Dict) -> Dict:
|
||||
@ -108,14 +113,20 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
docs = self._get_docs(inputs)
|
||||
answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
|
||||
if "SOURCES: " in answer:
|
||||
answer, sources = answer.split("SOURCES: ")
|
||||
else:
|
||||
sources = ""
|
||||
return {self.answer_key: answer, self.sources_answer_key: sources}
|
||||
result: Dict[str, Any] = {
|
||||
self.answer_key: answer,
|
||||
self.sources_answer_key: sources,
|
||||
}
|
||||
if self.return_source_documents:
|
||||
result["source_documents"] = docs
|
||||
return result
|
||||
|
||||
|
||||
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||
|
Loading…
Reference in New Issue
Block a user