mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
RFC: BaseRetriever.return_str()
This commit is contained in:
@@ -210,7 +210,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
def format_document(doc: Document, prompt: Optional[BasePromptTemplate] = None) -> str:
|
||||
"""Format a document into a string based on a prompt template.
|
||||
|
||||
First, this pulls information from the document from two sources:
|
||||
@@ -244,16 +244,17 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
format_document(doc, prompt)
|
||||
>>> "Page 1: This is a joke"
|
||||
"""
|
||||
_prompt = prompt or PromptTemplate.from_template("{page_content}")
|
||||
base_info = {"page_content": doc.page_content, **doc.metadata}
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
required_metadata = [
|
||||
iv for iv in prompt.input_variables if iv != "page_content"
|
||||
iv for iv in _prompt.input_variables if iv != "page_content"
|
||||
]
|
||||
raise ValueError(
|
||||
f"Document prompt requires documents to have metadata variables: "
|
||||
f"{required_metadata}. Received document with missing metadata: "
|
||||
f"{list(missing_metadata)}."
|
||||
)
|
||||
document_info = {k: base_info[k] for k in prompt.input_variables}
|
||||
return prompt.format(**document_info)
|
||||
document_info = {k: base_info[k] for k in _prompt.input_variables}
|
||||
return _prompt.format(**document_info)
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.prompts import format_document
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
@@ -285,3 +286,29 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
def return_str(
|
||||
self, prompt: Optional[BasePromptTemplate] = None, separator: str = "\n\n"
|
||||
) -> RunnableSerializable[RetrieverInput, str]:
|
||||
"""
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.retrievers import FAISS
|
||||
|
||||
retriever = FAISS.from_texts(["hi", "bye"])
|
||||
chain = {"context": retriever.return_str()} | prompt | llm | StrOutputParser()
|
||||
"""
|
||||
return self | partial(_transform_docs, prompt=prompt, separator=separator)
|
||||
|
||||
|
||||
def _transform_docs(
|
||||
doc_stream: Iterator[Document],
|
||||
prompt: Optional[BasePromptTemplate],
|
||||
separator: str,
|
||||
) -> Iterator[str]:
|
||||
doc = next(doc_stream)
|
||||
yield format_document(doc, prompt=prompt)
|
||||
for doc in doc_stream:
|
||||
yield separator + format_document(doc, prompt=prompt)
|
||||
|
||||
Reference in New Issue
Block a user