mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 15:03:21 +00:00
core[minor]: Add async aformat_document method (#20037)
This commit is contained in:
committed by
GitHub
parent
927793d088
commit
7e5c1905b1
@@ -24,7 +24,11 @@ from multiple components and prompt values. Prompt classes and functions make co
|
||||
SystemMessagePromptTemplate
|
||||
|
||||
""" # noqa: E501
|
||||
from langchain_core.prompts.base import BasePromptTemplate, format_document
|
||||
from langchain_core.prompts.base import (
|
||||
BasePromptTemplate,
|
||||
aformat_document,
|
||||
format_document,
|
||||
)
|
||||
from langchain_core.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseChatPromptTemplate,
|
||||
@@ -67,6 +71,7 @@ __all__ = [
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"format_document",
|
||||
"aformat_document",
|
||||
"check_valid_template",
|
||||
"get_template_variables",
|
||||
"jinja2_formatter",
|
||||
|
@@ -251,6 +251,21 @@ class BasePromptTemplate(
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
    """Collect the values a document prompt needs from a document.

    The document's ``page_content`` and every entry of its ``metadata`` are
    merged into one mapping, and the entries named by the prompt's
    ``input_variables`` are picked out of it. Because the metadata is merged
    after the content, a metadata key that is itself called ``page_content``
    takes precedence over ``doc.page_content``.

    Args:
        doc: Document whose page_content and metadata supply the values.
        prompt: BasePromptTemplate whose input_variables name the values
            that are required.

    Returns:
        Dict mapping each of the prompt's input variables to its value.

    Raises:
        ValueError: If an input variable of the prompt is neither
            ``page_content`` nor a key of the document's metadata.
    """
    base_info = {"page_content": doc.page_content, **doc.metadata}
    # Sorted so the error message is reproducible: the iteration order of a
    # set of strings can differ from one interpreter run to the next.
    missing_metadata = sorted(set(prompt.input_variables).difference(base_info))
    if missing_metadata:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{missing_metadata}."
        )
    # Only the variables the prompt asks for are returned; extra metadata
    # keys are dropped.
    return {k: base_info[k] for k in prompt.input_variables}
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
||||
"""Format a document into a string based on a prompt template.
|
||||
|
||||
@@ -285,16 +300,30 @@ def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
||||
format_document(doc, prompt)
|
||||
>>> "Page 1: This is a joke"
|
||||
"""
|
||||
base_info = {"page_content": doc.page_content, **doc.metadata}
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
required_metadata = [
|
||||
iv for iv in prompt.input_variables if iv != "page_content"
|
||||
]
|
||||
raise ValueError(
|
||||
f"Document prompt requires documents to have metadata variables: "
|
||||
f"{required_metadata}. Received document with missing metadata: "
|
||||
f"{list(missing_metadata)}."
|
||||
)
|
||||
document_info = {k: base_info[k] for k in prompt.input_variables}
|
||||
return prompt.format(**document_info)
|
||||
return prompt.format(**_get_document_info(doc, prompt))
|
||||
|
||||
|
||||
async def aformat_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
    """Asynchronously render a document as a string using a prompt template.

    The values handed to the prompt come from two places in the document:

        1. `page_content`:
            The text in `document.page_content`, exposed to the prompt
            under the variable name `page_content`.
        2. metadata:
            Each entry of `document.metadata`, exposed under a variable
            with the same name as its key.

    Only the variables listed in the prompt's `input_variables` are passed
    on; the prompt's `aformat` coroutine is then awaited to build the result.

    Args:
        doc: Document, the page_content and metadata will be used to create
            the final string.
        prompt: BasePromptTemplate, will be used to format the page_content
            and metadata into the final string.

    Returns:
        string of the document formatted.

    Raises:
        ValueError: If the prompt needs a variable that the document does
            not provide.
    """
    # Validation and variable selection are delegated to the module's
    # document-info helper.
    template_inputs = _get_document_info(doc, prompt)
    formatted = await prompt.aformat(**template_inputs)
    return formatted
|
||||
|
@@ -10,6 +10,7 @@ EXPECTED_ALL = [
|
||||
"FewShotPromptWithTemplates",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"format_document",
|
||||
"aformat_document",
|
||||
"HumanMessagePromptTemplate",
|
||||
"MessagesPlaceholder",
|
||||
"PipelinePromptTemplate",
|
||||
|
@@ -4,7 +4,12 @@ from typing import Optional
|
||||
from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
PromptTemplate,
|
||||
aformat_document,
|
||||
format_document,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
@@ -39,7 +44,7 @@ async def _aget_relevant_documents(
|
||||
) -> str:
|
||||
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
||||
return document_separator.join(
|
||||
format_document(doc, document_prompt) for doc in docs
|
||||
[await aformat_document(doc, document_prompt) for doc in docs]
|
||||
)
|
||||
|
||||
|
||||
|
@@ -4,7 +4,7 @@ from typing import Any, List
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import PromptTemplate, format_document
|
||||
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
|
||||
|
||||
from langchain.chains.combine_documents.reduce import (
|
||||
collapse_docs,
|
||||
@@ -119,7 +119,7 @@ def test__collapse_docs_metadata() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_format_doc_with_metadata() -> None:
|
||||
async def test_format_doc_with_metadata() -> None:
|
||||
"""Test format doc on a valid document."""
|
||||
doc = Document(page_content="foo", metadata={"bar": "baz"})
|
||||
prompt = PromptTemplate(
|
||||
@@ -128,9 +128,11 @@ def test_format_doc_with_metadata() -> None:
|
||||
expected_output = "foo, baz"
|
||||
output = format_document(doc, prompt)
|
||||
assert output == expected_output
|
||||
output = await aformat_document(doc, prompt)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_format_doc_missing_metadata() -> None:
|
||||
async def test_format_doc_missing_metadata() -> None:
|
||||
"""Test format doc on a document with missing metadata."""
|
||||
doc = Document(page_content="foo")
|
||||
prompt = PromptTemplate(
|
||||
@@ -138,3 +140,5 @@ def test_format_doc_missing_metadata() -> None:
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
format_document(doc, prompt)
|
||||
with pytest.raises(ValueError):
|
||||
await aformat_document(doc, prompt)
|
||||
|
Reference in New Issue
Block a user