From 9448f169ced26265202ef89751aa5b34ed0eef61 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Wed, 3 Jan 2024 15:30:42 -0500
Subject: [PATCH] RFC: BaseRetriever.return_str()

---
Review notes (ignored by git-am; not part of the commit message):
 * format_document: the missing-metadata check now reads `_prompt`, so the new
   `prompt=None` default no longer dereferences None.
 * format_document: PromptTemplate is imported inside the function because no
   hunk adds a module-level import for it.
 * return_str: the formatter receives the retrieved list of documents and joins
   it in one step. The previous generator called next() on its input (a
   RuntimeError on empty input under PEP 479) and relied on `partial` and
   `Iterator`, neither of which is imported by retrievers.py in this patch.
 * Context-line indentation was reconstructed, and the post-image blob ids on
   the "index" lines predate these amendments.

 libs/core/langchain_core/prompts/base.py | 14 ++++++---
 libs/core/langchain_core/retrievers.py   | 37 ++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py
index 2e1878ed27e..fb74560a940 100644
--- a/libs/core/langchain_core/prompts/base.py
+++ b/libs/core/langchain_core/prompts/base.py
@@ -210,7 +210,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
             raise ValueError(f"{save_path} must be json or yaml")
 
 
-def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
+def format_document(doc: Document, prompt: Optional[BasePromptTemplate] = None) -> str:
     """Format a document into a string based on a prompt template.
 
     First, this pulls information from the document from two sources:
@@ -244,16 +244,20 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
             format_document(doc, prompt)
             >>> "Page 1: This is a joke"
     """
+    # Imported lazily: the module defining PromptTemplate imports this one.
+    from langchain_core.prompts.prompt import PromptTemplate
+
+    _prompt = prompt or PromptTemplate.from_template("{page_content}")
     base_info = {"page_content": doc.page_content, **doc.metadata}
-    missing_metadata = set(prompt.input_variables).difference(base_info)
+    missing_metadata = set(_prompt.input_variables).difference(base_info)
     if len(missing_metadata) > 0:
         required_metadata = [
-            iv for iv in prompt.input_variables if iv != "page_content"
+            iv for iv in _prompt.input_variables if iv != "page_content"
         ]
         raise ValueError(
             f"Document prompt requires documents to have metadata variables: "
             f"{required_metadata}. Received document with missing metadata: "
             f"{list(missing_metadata)}."
         )
-    document_info = {k: base_info[k] for k in prompt.input_variables}
-    return prompt.format(**document_info)
+    document_info = {k: base_info[k] for k in _prompt.input_variables}
+    return _prompt.format(**document_info)
diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py
index f2156366959..e475dc2b11b 100644
--- a/libs/core/langchain_core/retrievers.py
+++ b/libs/core/langchain_core/retrievers.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from langchain_core.documents import Document
 from langchain_core.load.dump import dumpd
+from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.runnables import (
     Runnable,
     RunnableConfig,
@@ -285,3 +286,39 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
                 **kwargs,
             )
             return result
+
+    def return_str(
+        self, prompt: Optional[BasePromptTemplate] = None, separator: str = "\n\n"
+    ) -> RunnableSerializable[RetrieverInput, str]:
+        """Return a runnable that retrieves documents and joins them into a string.
+
+        Each retrieved document is rendered with ``format_document`` and the
+        results are concatenated with ``separator``.
+
+        Args:
+            prompt: Template used to render each document. Forwarded unchanged to
+                ``format_document``, which substitutes its own default for None.
+            separator: String placed between consecutive rendered documents.
+
+        Returns:
+            A runnable that maps the retriever input to a single string. The
+            string is empty when no documents are retrieved.
+
+        Example:
+            .. code-block:: python
+
+                from langchain_community.vectorstores import FAISS
+
+                retriever = FAISS.from_texts(["hi", "bye"], embeddings).as_retriever()
+                chain = (
+                    {"context": retriever.return_str()}
+                    | prompt
+                    | llm
+                    | StrOutputParser()
+                )
+        """
+
+        def _join_docs(docs: List[Document]) -> str:
+            return separator.join(format_document(doc, prompt=prompt) for doc in docs)
+
+        return self | _join_docs