From 0f7b8adddf379ed2eca54b0f24b6a80996e8e0fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thin=20red=20line=20=E6=9C=AA=E6=9D=A5=E4=BA=A7=E5=93=81?= =?UTF-8?q?=E7=BB=8F=E7=90=86?= <66343787+jiru1997@users.noreply.github.com> Date: Mon, 19 Aug 2024 06:33:19 -0700 Subject: [PATCH] fix issue: cannot use document_variable_name to override context in create_stuff_documents_chain (#25531) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …he prompt in the create_stuff_documents_chain Thank you for contributing to LangChain! - [ ] **PR title**: "langchain:add document_variable_name in the function _validate_prompt in create_stuff_documents_chain" - [ ] **PR message**: - **Description:** add document_variable_name in the function _validate_prompt in create_stuff_documents_chain - **Issue:** according to the description of create_stuff_documents_chain function, the parameter document_variable_name can be used to override the "context" in the prompt, but in the function, _validate_prompt it still use DOCUMENTS_KEY to check if it is a valid prompt, the value of DOCUMENTS_KEY is always "context", so even through the user use document_variable_name to override it, the code still tries to check if "context" is in the prompt, and finally it reports error. so I use document_variable_name to replace DOCUMENTS_KEY, the default value of document_variable_name is "context" which is same as DOCUMENTS_KEY, but it can be override by users. - **Dependencies:** none - **Twitter handle:** https://x.com/xjr199703 - [ ] **Add tests and docs**: none - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Chester Curme --- libs/langchain/langchain/chains/combine_documents/base.py | 8 ++++---- .../langchain/langchain/chains/combine_documents/stuff.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 90e965996de..00b6002da0e 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -22,11 +22,11 @@ DOCUMENTS_KEY = "context" DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}") -def _validate_prompt(prompt: BasePromptTemplate) -> None: - if DOCUMENTS_KEY not in prompt.input_variables: +def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None: + if document_variable_name not in prompt.input_variables: raise ValueError( - f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt " - f"with input variables: {prompt.input_variables}" + f"Prompt must accept {document_variable_name} as an input variable. " + f"Received prompt with input variables: {prompt.input_variables}" ) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 5ffd86c9718..cdecec0f40b 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -76,7 +76,7 @@ def create_stuff_documents_chain( chain.invoke({"context": docs}) """ # noqa: E501 - _validate_prompt(prompt) + _validate_prompt(prompt, document_variable_name) _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT _output_parser = output_parser or StrOutputParser()