parse output of combine docs

This commit is contained in:
Harrison Chase 2022-12-26 18:41:39 -05:00
parent c59c5f5164
commit 275e58eab8
4 changed files with 36 additions and 4 deletions

View File

@ -1,12 +1,13 @@
"""Base interface for chains combining documents.""" """Base interface for chains combining documents."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.prompts.base import BaseOutputParser
class BaseCombineDocumentsChain(Chain, BaseModel, ABC): class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
@ -42,6 +43,21 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine documents into a single string.""" """Combine documents into a single string."""
# NOTE: ``abstractmethod`` must be the innermost decorator. Stacking it on
# top of ``property`` fails when the class body runs, because property
# objects do not allow ``__isabstractmethod__`` to be assigned.
@property
@abstractmethod
def output_parser(self) -> Optional[BaseOutputParser]:
    """Output parser to use for results of combine_docs.

    Subclasses must provide this property. A value of ``None`` means
    ``combine_and_parse`` returns the combined result unparsed.
    """
def combine_and_parse(
    self, docs: List[Document], **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
    """Run combine_docs, then apply the output parser when one is set.

    The second element returned by ``combine_docs`` is discarded. If
    ``output_parser`` is ``None`` the combined result is returned as-is.
    """
    combined, _extra = self.combine_docs(docs, **kwargs)
    # Guard clause: without a parser there is nothing further to do.
    if self.output_parser is None:
        return combined
    return self.output_parser.parse(combined)
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
docs = inputs[self.input_key] docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction # Other keys are assumed to be needed for LLM prediction

View File

@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.prompts.base import BaseOutputParser
def _split_list_of_docs( def _split_list_of_docs(
@ -113,6 +114,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
else: else:
return self.combine_document_chain return self.combine_document_chain
@property
def output_parser(self) -> Optional[BaseOutputParser]:
    """Parser for combine_docs results, taken from combine_document_chain."""
    final_chain = self.combine_document_chain
    return final_chain.output_parser
def combine_docs( def combine_docs(
self, docs: List[Document], token_max: int = 3000, **kwargs: Any self, docs: List[Document], token_max: int = 3000, **kwargs: Any
) -> Tuple[str, dict]: ) -> Tuple[str, dict]:

View File

@ -2,14 +2,14 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
@ -74,6 +74,11 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
) )
return values return values
@property
def output_parser(self) -> Optional[BaseOutputParser]:
    """Parser for combine_docs results, taken from the refine chain's prompt."""
    refine_prompt = self.refine_llm_chain.prompt
    return refine_prompt.output_parser
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain.""" """Combine by mapping first chain over all, then stuffing into final chain."""
base_info = {"page_content": docs[0].page_content} base_info = {"page_content": docs[0].page_content}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
@ -78,6 +78,11 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
prompt = self.llm_chain.prompt.format(**inputs) prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt) return self.llm_chain.llm.get_num_tokens(prompt)
@property
def output_parser(self) -> Optional[BaseOutputParser]:
    """Parser for combine_docs results, taken from the LLM chain's prompt."""
    chain_prompt = self.llm_chain.prompt
    return chain_prompt.output_parser
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM.""" """Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs) inputs = self._get_inputs(docs, **kwargs)