core[minor]: Add async aformat_document method (#20037)

This commit is contained in:
Christophe Bornet
2024-04-05 16:29:53 +02:00
committed by GitHub
parent 927793d088
commit 7e5c1905b1
5 changed files with 63 additions and 19 deletions

View File

@@ -24,7 +24,11 @@ from multiple components and prompt values. Prompt classes and functions make co
SystemMessagePromptTemplate
""" # noqa: E501
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.base import (
BasePromptTemplate,
aformat_document,
format_document,
)
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
@@ -67,6 +71,7 @@ __all__ = [
"SystemMessagePromptTemplate",
"load_prompt",
"format_document",
"aformat_document",
"check_valid_template",
"get_template_variables",
"jinja2_formatter",

View File

@@ -251,6 +251,21 @@ class BasePromptTemplate(
raise ValueError(f"{save_path} must be json or yaml")
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
return {k: base_info[k] for k in prompt.input_variables}
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template.
@@ -285,16 +300,30 @@ def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
return prompt.format(**_get_document_info(doc, prompt))
async def aformat_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
1. `page_content`:
This takes the information from the `document.page_content`
and assigns it to a variable named `page_content`.
2. metadata:
This takes information from `document.metadata` and assigns
it to variables of the same name.
Those variables are then passed into the `prompt` to produce a formatted string.
Args:
doc: Document, the page_content and metadata will be used to create
the final string.
prompt: BasePromptTemplate, will be used to format the page_content
and metadata into the final string.
Returns:
string of the document formatted.
"""
return await prompt.aformat(**_get_document_info(doc, prompt))

View File

@@ -10,6 +10,7 @@ EXPECTED_ALL = [
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
"format_document",
"aformat_document",
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"PipelinePromptTemplate",

View File

@@ -4,7 +4,12 @@ from typing import Optional
from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
@@ -39,7 +44,7 @@ async def _aget_relevant_documents(
) -> str:
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
[await aformat_document(doc, document_prompt) for doc in docs]
)

View File

@@ -4,7 +4,7 @@ from typing import Any, List
import pytest
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, format_document
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
from langchain.chains.combine_documents.reduce import (
collapse_docs,
@@ -119,7 +119,7 @@ def test__collapse_docs_metadata() -> None:
assert output == expected_output
def test_format_doc_with_metadata() -> None:
async def test_format_doc_with_metadata() -> None:
"""Test format doc on a valid document."""
doc = Document(page_content="foo", metadata={"bar": "baz"})
prompt = PromptTemplate(
@@ -128,9 +128,11 @@ def test_format_doc_with_metadata() -> None:
expected_output = "foo, baz"
output = format_document(doc, prompt)
assert output == expected_output
output = await aformat_document(doc, prompt)
assert output == expected_output
def test_format_doc_missing_metadata() -> None:
async def test_format_doc_missing_metadata() -> None:
"""Test format doc on a document with missing metadata."""
doc = Document(page_content="foo")
prompt = PromptTemplate(
@@ -138,3 +140,5 @@ def test_format_doc_missing_metadata() -> None:
)
with pytest.raises(ValueError):
format_document(doc, prompt)
with pytest.raises(ValueError):
await aformat_document(doc, prompt)