RFC: BaseRetriever.return_str()

This commit is contained in:
Bagatur
2024-01-03 15:30:42 -05:00
parent 6e90b7a91b
commit 9448f169ce
2 changed files with 32 additions and 4 deletions

View File

@@ -210,7 +210,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
def format_document(doc: Document, prompt: Optional[BasePromptTemplate] = None) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
@@ -244,16 +244,17 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
_prompt = prompt or PromptTemplate.from_template("{page_content}")
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
iv for iv in _prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
document_info = {k: base_info[k] for k in _prompt.input_variables}
return _prompt.format(**document_info)

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.prompts import format_document
from langchain_core.runnables import (
Runnable,
RunnableConfig,
@@ -285,3 +286,29 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
**kwargs,
)
return result
def return_str(
    self, prompt: Optional[BasePromptTemplate] = None, separator: str = "\n\n"
) -> RunnableSerializable[RetrieverInput, str]:
    """Compose this retriever with a step that renders its documents as text.

    The result is ``self`` piped (``|``) into ``_transform_docs`` with
    ``prompt`` and ``separator`` pre-bound via ``functools.partial``.

    Args:
        prompt: Forwarded to ``_transform_docs`` as its ``prompt`` argument.
        separator: Forwarded to ``_transform_docs`` as its ``separator``
            argument.

    Returns:
        Whatever ``self | <partial>`` produces; annotated as a runnable from
        the retriever input to ``str``.

    Example:
        .. code-block:: python

            from langchain_community.retrievers import FAISS

            retriever = FAISS.from_texts(["hi", "bye"])
            chain = {"context": retriever.return_str()} | prompt | llm | StrOutputParser()

        NOTE(review): the FAISS import path and constructor call in this
        example look incomplete -- confirm before publishing.
    """
    docs_to_text = partial(_transform_docs, prompt=prompt, separator=separator)
    return self | docs_to_text
def _transform_docs(
doc_stream: Iterator[Document],
prompt: Optional[BasePromptTemplate],
separator: str,
) -> Iterator[str]:
doc = next(doc_stream)
yield format_document(doc, prompt=prompt)
for doc in doc_stream:
yield separator + format_document(doc, prompt=prompt)