mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 09:58:44 +00:00
Factor out doc formatting and add validation (#3026)
@cnhhoang850 slightly more generic fix for #2944, works for whatever the expected metadata keys are not just `source`
This commit is contained in:
parent
3453b7457c
commit
19c85aa990
@ -7,9 +7,28 @@ from pydantic import Field
|
|||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
|
|
||||||
|
|
||||||
|
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||||
|
"""Format a document into a string based on a prompt template."""
|
||||||
|
base_info = {"page_content": doc.page_content}
|
||||||
|
base_info.update(doc.metadata)
|
||||||
|
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||||
|
if len(missing_metadata) > 0:
|
||||||
|
required_metadata = [
|
||||||
|
iv for iv in prompt.input_variables if iv != "page_content"
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
f"Document prompt requires documents to have metadata variables: "
|
||||||
|
f"{required_metadata}. Received document with missing metadata: "
|
||||||
|
f"{list(missing_metadata)}."
|
||||||
|
)
|
||||||
|
document_info = {k: base_info[k] for k in prompt.input_variables}
|
||||||
|
return prompt.format(**document_info)
|
||||||
|
|
||||||
|
|
||||||
class BaseCombineDocumentsChain(Chain, ABC):
|
class BaseCombineDocumentsChain(Chain, ABC):
|
||||||
"""Base interface for chains combining documents."""
|
"""Base interface for chains combining documents."""
|
||||||
|
|
||||||
|
@ -6,7 +6,10 @@ from typing import Any, Dict, List, Tuple
|
|||||||
|
|
||||||
from pydantic import Extra, Field, root_validator
|
from pydantic import Extra, Field, root_validator
|
||||||
|
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import (
|
||||||
|
BaseCombineDocumentsChain,
|
||||||
|
format_document,
|
||||||
|
)
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
@ -116,14 +119,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
return res, extra_return_dict
|
return res, extra_return_dict
|
||||||
|
|
||||||
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
|
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
|
||||||
base_info = {"page_content": doc.page_content}
|
return {
|
||||||
base_info.update(doc.metadata)
|
self.document_variable_name: format_document(doc, self.document_prompt),
|
||||||
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
|
|
||||||
base_inputs = {
|
|
||||||
self.document_variable_name: self.document_prompt.format(**document_info),
|
|
||||||
self.initial_response_name: res,
|
self.initial_response_name: res,
|
||||||
}
|
}
|
||||||
return base_inputs
|
|
||||||
|
|
||||||
def _construct_initial_inputs(
|
def _construct_initial_inputs(
|
||||||
self, docs: List[Document], **kwargs: Any
|
self, docs: List[Document], **kwargs: Any
|
||||||
|
@ -4,7 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
from pydantic import Extra, Field, root_validator
|
from pydantic import Extra, Field, root_validator
|
||||||
|
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import (
|
||||||
|
BaseCombineDocumentsChain,
|
||||||
|
format_document,
|
||||||
|
)
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
@ -56,17 +59,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
||||||
# Get relevant information from each document.
|
|
||||||
doc_dicts = []
|
|
||||||
for doc in docs:
|
|
||||||
base_info = {"page_content": doc.page_content}
|
|
||||||
base_info.update(doc.metadata)
|
|
||||||
document_info = {
|
|
||||||
k: base_info[k] for k in self.document_prompt.input_variables
|
|
||||||
}
|
|
||||||
doc_dicts.append(document_info)
|
|
||||||
# Format each document according to the prompt
|
# Format each document according to the prompt
|
||||||
doc_strings = [self.document_prompt.format(**doc) for doc in doc_dicts]
|
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
|
||||||
# Join the documents together to put them in the prompt.
|
# Join the documents together to put them in the prompt.
|
||||||
inputs = {
|
inputs = {
|
||||||
k: v
|
k: v
|
||||||
|
@ -4,6 +4,8 @@ from typing import Any, List
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from langchain import PromptTemplate
|
||||||
|
from langchain.chains.combine_documents.base import format_document
|
||||||
from langchain.chains.combine_documents.map_reduce import (
|
from langchain.chains.combine_documents.map_reduce import (
|
||||||
_collapse_docs,
|
_collapse_docs,
|
||||||
_split_list_of_docs,
|
_split_list_of_docs,
|
||||||
@ -116,3 +118,24 @@ def test__collapse_docs_metadata() -> None:
|
|||||||
}
|
}
|
||||||
expected_output = Document(page_content="foobar", metadata=expected_metadata)
|
expected_output = Document(page_content="foobar", metadata=expected_metadata)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_doc_with_metadata() -> None:
|
||||||
|
"""Test format doc on a valid document."""
|
||||||
|
doc = Document(page_content="foo", metadata={"bar": "baz"})
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["page_content", "bar"], template="{page_content}, {bar}"
|
||||||
|
)
|
||||||
|
expected_output = "foo, baz"
|
||||||
|
output = format_document(doc, prompt)
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_doc_missing_metadata() -> None:
|
||||||
|
"""Test format doc on a document with missing metadata."""
|
||||||
|
doc = Document(page_content="foo")
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["page_content", "bar"], template="{page_content}, {bar}"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
format_document(doc, prompt)
|
||||||
|
Loading…
Reference in New Issue
Block a user