mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 21:23:32 +00:00
update keys for chain (#5164)
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
aca8cb5fba
commit
1b58460fe3
@ -100,6 +100,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
extra_keys = [
|
||||
k for k in self.llm_chain.input_keys if k != self.document_variable_name
|
||||
]
|
||||
return super().input_keys + extra_keys
|
||||
|
||||
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
||||
"""Construct inputs from kwargs and docs.
|
||||
|
||||
|
@ -9,8 +9,10 @@ from langchain.chains.combine_documents.reduce import (
|
||||
_collapse_docs,
|
||||
_split_list_of_docs,
|
||||
)
|
||||
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import format_document
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def _fake_docs_len_func(docs: List[Document]) -> int:
|
||||
@ -21,6 +23,11 @@ def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str:
|
||||
return "".join([d.page_content for d in docs])
|
||||
|
||||
|
||||
def test_multiple_input_keys() -> None:
|
||||
chain = load_qa_with_sources_chain(FakeLLM(), chain_type="stuff")
|
||||
assert chain.input_keys == ["input_documents", "question"]
|
||||
|
||||
|
||||
def test__split_list_long_single_doc() -> None:
|
||||
"""Test splitting of a long single doc."""
|
||||
docs = [Document(page_content="foo" * 100)]
|
||||
|
Loading…
Reference in New Issue
Block a user