core[minor]: Add async aformat_document method (#20037)

This commit is contained in:
Christophe Bornet
2024-04-05 16:29:53 +02:00
committed by GitHub
parent 927793d088
commit 7e5c1905b1
5 changed files with 63 additions and 19 deletions

View File

@@ -4,7 +4,12 @@ from typing import Optional
from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
@@ -39,7 +44,7 @@ async def _aget_relevant_documents(
) -> str:
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
[await aformat_document(doc, document_prompt) for doc in docs]
)

View File

@@ -4,7 +4,7 @@ from typing import Any, List
import pytest
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, format_document
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
from langchain.chains.combine_documents.reduce import (
collapse_docs,
@@ -119,7 +119,7 @@ def test__collapse_docs_metadata() -> None:
assert output == expected_output
def test_format_doc_with_metadata() -> None:
async def test_format_doc_with_metadata() -> None:
"""Test format doc on a valid document."""
doc = Document(page_content="foo", metadata={"bar": "baz"})
prompt = PromptTemplate(
@@ -128,9 +128,11 @@ def test_format_doc_with_metadata() -> None:
expected_output = "foo, baz"
output = format_document(doc, prompt)
assert output == expected_output
output = await aformat_document(doc, prompt)
assert output == expected_output
def test_format_doc_missing_metadata() -> None:
async def test_format_doc_missing_metadata() -> None:
"""Test format doc on a document with missing metadata."""
doc = Document(page_content="foo")
prompt = PromptTemplate(
@@ -138,3 +140,5 @@ def test_format_doc_missing_metadata() -> None:
)
with pytest.raises(ValueError):
format_document(doc, prompt)
with pytest.raises(ValueError):
await aformat_document(doc, prompt)