mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-10 15:33:11 +00:00
core[minor]: Add async aformat_document method (#20037)
This commit is contained in:
committed by
GitHub
parent
927793d088
commit
7e5c1905b1
@@ -24,7 +24,11 @@ from multiple components and prompt values. Prompt classes and functions make co
|
|||||||
SystemMessagePromptTemplate
|
SystemMessagePromptTemplate
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
from langchain_core.prompts.base import BasePromptTemplate, format_document
|
from langchain_core.prompts.base import (
|
||||||
|
BasePromptTemplate,
|
||||||
|
aformat_document,
|
||||||
|
format_document,
|
||||||
|
)
|
||||||
from langchain_core.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
BaseChatPromptTemplate,
|
BaseChatPromptTemplate,
|
||||||
@@ -67,6 +71,7 @@ __all__ = [
|
|||||||
"SystemMessagePromptTemplate",
|
"SystemMessagePromptTemplate",
|
||||||
"load_prompt",
|
"load_prompt",
|
||||||
"format_document",
|
"format_document",
|
||||||
|
"aformat_document",
|
||||||
"check_valid_template",
|
"check_valid_template",
|
||||||
"get_template_variables",
|
"get_template_variables",
|
||||||
"jinja2_formatter",
|
"jinja2_formatter",
|
||||||
|
@@ -251,6 +251,21 @@ class BasePromptTemplate(
|
|||||||
raise ValueError(f"{save_path} must be json or yaml")
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
|
||||||
|
base_info = {"page_content": doc.page_content, **doc.metadata}
|
||||||
|
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||||
|
if len(missing_metadata) > 0:
|
||||||
|
required_metadata = [
|
||||||
|
iv for iv in prompt.input_variables if iv != "page_content"
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
f"Document prompt requires documents to have metadata variables: "
|
||||||
|
f"{required_metadata}. Received document with missing metadata: "
|
||||||
|
f"{list(missing_metadata)}."
|
||||||
|
)
|
||||||
|
return {k: base_info[k] for k in prompt.input_variables}
|
||||||
|
|
||||||
|
|
||||||
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
||||||
"""Format a document into a string based on a prompt template.
|
"""Format a document into a string based on a prompt template.
|
||||||
|
|
||||||
@@ -285,16 +300,30 @@ def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
|||||||
format_document(doc, prompt)
|
format_document(doc, prompt)
|
||||||
>>> "Page 1: This is a joke"
|
>>> "Page 1: This is a joke"
|
||||||
"""
|
"""
|
||||||
base_info = {"page_content": doc.page_content, **doc.metadata}
|
return prompt.format(**_get_document_info(doc, prompt))
|
||||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
|
||||||
if len(missing_metadata) > 0:
|
|
||||||
required_metadata = [
|
async def aformat_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
||||||
iv for iv in prompt.input_variables if iv != "page_content"
|
"""Format a document into a string based on a prompt template.
|
||||||
]
|
|
||||||
raise ValueError(
|
First, this pulls information from the document from two sources:
|
||||||
f"Document prompt requires documents to have metadata variables: "
|
|
||||||
f"{required_metadata}. Received document with missing metadata: "
|
1. `page_content`:
|
||||||
f"{list(missing_metadata)}."
|
This takes the information from the `document.page_content`
|
||||||
)
|
and assigns it to a variable named `page_content`.
|
||||||
document_info = {k: base_info[k] for k in prompt.input_variables}
|
2. metadata:
|
||||||
return prompt.format(**document_info)
|
This takes information from `document.metadata` and assigns
|
||||||
|
it to variables of the same name.
|
||||||
|
|
||||||
|
Those variables are then passed into the `prompt` to produce a formatted string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc: Document, the page_content and metadata will be used to create
|
||||||
|
the final string.
|
||||||
|
prompt: BasePromptTemplate, will be used to format the page_content
|
||||||
|
and metadata into the final string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
string of the document formatted.
|
||||||
|
"""
|
||||||
|
return await prompt.aformat(**_get_document_info(doc, prompt))
|
||||||
|
@@ -10,6 +10,7 @@ EXPECTED_ALL = [
|
|||||||
"FewShotPromptWithTemplates",
|
"FewShotPromptWithTemplates",
|
||||||
"FewShotChatMessagePromptTemplate",
|
"FewShotChatMessagePromptTemplate",
|
||||||
"format_document",
|
"format_document",
|
||||||
|
"aformat_document",
|
||||||
"HumanMessagePromptTemplate",
|
"HumanMessagePromptTemplate",
|
||||||
"MessagesPlaceholder",
|
"MessagesPlaceholder",
|
||||||
"PipelinePromptTemplate",
|
"PipelinePromptTemplate",
|
||||||
|
@@ -4,7 +4,12 @@ from typing import Optional
|
|||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
|
from langchain_core.prompts import (
|
||||||
|
BasePromptTemplate,
|
||||||
|
PromptTemplate,
|
||||||
|
aformat_document,
|
||||||
|
format_document,
|
||||||
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
@@ -39,7 +44,7 @@ async def _aget_relevant_documents(
|
|||||||
) -> str:
|
) -> str:
|
||||||
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
||||||
return document_separator.join(
|
return document_separator.join(
|
||||||
format_document(doc, document_prompt) for doc in docs
|
[await aformat_document(doc, document_prompt) for doc in docs]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -4,7 +4,7 @@ from typing import Any, List
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.prompts import PromptTemplate, format_document
|
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
|
||||||
|
|
||||||
from langchain.chains.combine_documents.reduce import (
|
from langchain.chains.combine_documents.reduce import (
|
||||||
collapse_docs,
|
collapse_docs,
|
||||||
@@ -119,7 +119,7 @@ def test__collapse_docs_metadata() -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_format_doc_with_metadata() -> None:
|
async def test_format_doc_with_metadata() -> None:
|
||||||
"""Test format doc on a valid document."""
|
"""Test format doc on a valid document."""
|
||||||
doc = Document(page_content="foo", metadata={"bar": "baz"})
|
doc = Document(page_content="foo", metadata={"bar": "baz"})
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
@@ -128,9 +128,11 @@ def test_format_doc_with_metadata() -> None:
|
|||||||
expected_output = "foo, baz"
|
expected_output = "foo, baz"
|
||||||
output = format_document(doc, prompt)
|
output = format_document(doc, prompt)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
output = await aformat_document(doc, prompt)
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_format_doc_missing_metadata() -> None:
|
async def test_format_doc_missing_metadata() -> None:
|
||||||
"""Test format doc on a document with missing metadata."""
|
"""Test format doc on a document with missing metadata."""
|
||||||
doc = Document(page_content="foo")
|
doc = Document(page_content="foo")
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
@@ -138,3 +140,5 @@ def test_format_doc_missing_metadata() -> None:
|
|||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
format_document(doc, prompt)
|
format_document(doc, prompt)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await aformat_document(doc, prompt)
|
||||||
|
Reference in New Issue
Block a user