From 275e58eab83499fa055d4f7284c6c6517e9418e2 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Mon, 26 Dec 2022 18:41:39 -0500
Subject: [PATCH] parse output of combine docs

---
 langchain/chains/combine_documents/base.py    | 18 +++++++++++++++++-
 .../chains/combine_documents/map_reduce.py    |  6 ++++++
 langchain/chains/combine_documents/refine.py  |  9 +++++++--
 langchain/chains/combine_documents/stuff.py   |  7 ++++++-
 4 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py
index 944440e94a0..9ef86309f49 100644
--- a/langchain/chains/combine_documents/base.py
+++ b/langchain/chains/combine_documents/base.py
@@ -1,12 +1,13 @@
 """Base interface for chains combining documents."""

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 from pydantic import BaseModel

 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
+from langchain.prompts.base import BaseOutputParser


 class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
@@ -42,6 +43,21 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine documents into a single string."""

+    @property
+    @abstractmethod
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+
+    def combine_and_parse(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Union[str, List[str], Dict[str, str]]:
+        """Combine documents and parse the result."""
+        result, _ = self.combine_docs(docs, **kwargs)
+        if self.output_parser is not None:
+            return self.output_parser.parse(result)
+        else:
+            return result
+
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py
index dd192062f3d..6dd634ed573 100644
--- a/langchain/chains/combine_documents/map_reduce.py
+++ b/langchain/chains/combine_documents/map_reduce.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra, root_validator
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
+from langchain.prompts.base import BaseOutputParser


 def _split_list_of_docs(
@@ -113,6 +114,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         else:
             return self.combine_document_chain

+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+        return self.combine_document_chain.output_parser
+
     def combine_docs(
         self, docs: List[Document], token_max: int = 3000, **kwargs: Any
     ) -> Tuple[str, dict]:
diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py
index c91bf07089f..3d3de08fb57 100644
--- a/langchain/chains/combine_documents/refine.py
+++ b/langchain/chains/combine_documents/refine.py
@@ -2,14 +2,14 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from pydantic import BaseModel, Extra, Field, root_validator

 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate


@@ -74,6 +74,11 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
             )
         return values

+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+        return self.refine_llm_chain.prompt.output_parser
+
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine by mapping first chain over all, then stuffing into final chain."""
         base_info = {"page_content": docs[0].page_content}
diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py
index 67bdfa7512b..ceaaa464ffa 100644
--- a/langchain/chains/combine_documents/stuff.py
+++ b/langchain/chains/combine_documents/stuff.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate


@@ -78,6 +78,11 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         prompt = self.llm_chain.prompt.format(**inputs)
         return self.llm_chain.llm.get_num_tokens(prompt)

+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+        return self.llm_chain.prompt.output_parser
+
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Stuff all documents into one prompt and pass to LLM."""
         inputs = self._get_inputs(docs, **kwargs)
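
Usage note (not part of the patch): the snippet below is a minimal sketch of how the new combine_and_parse method could be exercised once this change is applied. The CommaSeparatedParser class, the prompt text, and the OpenAI settings are illustrative assumptions; StuffDocumentsChain, LLMChain, PromptTemplate with its output_parser field, and combine_and_parse itself are the pieces this diff touches or relies on. Running it requires an OPENAI_API_KEY in the environment.

from typing import List

from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.llms import OpenAI
from langchain.prompts.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate


class CommaSeparatedParser(BaseOutputParser):
    """Hypothetical parser: split the model's answer on commas."""

    def parse(self, text: str) -> List[str]:
        return [part.strip() for part in text.split(",")]


prompt = PromptTemplate(
    input_variables=["context"],
    template="List the topics covered in the text below, comma separated:\n\n{context}",
    output_parser=CommaSeparatedParser(),
)
llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
chain = StuffDocumentsChain(llm_chain=llm_chain, document_variable_name="context")

docs = [Document(page_content="LangChain ships stuff, map-reduce, and refine chains.")]
# combine_docs still returns the raw string; combine_and_parse additionally runs
# that string through chain.output_parser (here, the prompt's CommaSeparatedParser),
# so `topics` comes back as a list of strings rather than one blob of text.
topics = chain.combine_and_parse(docs)

Because StuffDocumentsChain.output_parser simply delegates to its prompt's output_parser, the prompt that formats the request also decides how the response is parsed; MapReduceDocumentsChain and RefineDocumentsChain delegate similarly, to the combine_document_chain and to the refine_llm_chain's prompt respectively.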