From 1b58460fe3d618cacd53b6abb4a2a0145f78eaaf Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 11 Aug 2023 16:25:13 -0700 Subject: [PATCH] update keys for chain (#5164) Co-authored-by: Bagatur --- libs/langchain/langchain/chains/combine_documents/stuff.py | 7 +++++++ .../tests/unit_tests/chains/test_combine_documents.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 2b113c7ab46..996a5ef7138 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -100,6 +100,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): ) return values + @property + def input_keys(self) -> List[str]: + extra_keys = [ + k for k in self.llm_chain.input_keys if k != self.document_variable_name + ] + return super().input_keys + extra_keys + def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: """Construct inputs from kwargs and docs. diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index df2212588e6..a970c33cd4d 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -9,8 +9,10 @@ from langchain.chains.combine_documents.reduce import ( _collapse_docs, _split_list_of_docs, ) +from langchain.chains.qa_with_sources import load_qa_with_sources_chain from langchain.docstore.document import Document from langchain.schema import format_document +from tests.unit_tests.llms.fake_llm import FakeLLM def _fake_docs_len_func(docs: List[Document]) -> int: @@ -21,6 +23,11 @@ def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str: return "".join([d.page_content for d in docs]) +def test_multiple_input_keys() -> None: + chain = load_qa_with_sources_chain(FakeLLM(), chain_type="stuff") + assert chain.input_keys == ["input_documents", "question"] + + def test__split_list_long_single_doc() -> None: """Test splitting of a long single doc.""" docs = [Document(page_content="foo" * 100)]