mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
parse output of combine docs
This commit is contained in:
parent
c59c5f5164
commit
275e58eab8
@ -1,12 +1,13 @@
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BaseOutputParser
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
|
||||
@ -42,6 +43,21 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string."""
|
||||
|
||||
@abstractmethod
|
||||
@property
|
||||
def output_parser(self) -> Optional[BaseOutputParser]:
|
||||
"""Output parser to use for results of combine_docs."""
|
||||
|
||||
def combine_and_parse(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Combine documents and parse the result."""
|
||||
result, _ = self.combine_docs(docs, **kwargs)
|
||||
if self.output_parser is not None:
|
||||
return self.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
|
@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra, root_validator
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BaseOutputParser
|
||||
|
||||
|
||||
def _split_list_of_docs(
|
||||
@ -113,6 +114,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
else:
|
||||
return self.combine_document_chain
|
||||
|
||||
@property
|
||||
def output_parser(self) -> Optional[BaseOutputParser]:
|
||||
"""Output parser to use for results of combine_docs."""
|
||||
return self.combine_document_chain.output_parser
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
|
@ -2,14 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
@ -74,6 +74,11 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def output_parser(self) -> Optional[BaseOutputParser]:
|
||||
"""Output parser to use for results of combine_docs."""
|
||||
return self.refine_llm_chain.prompt.output_parser
|
||||
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Combine by mapping first chain over all, then stuffing into final chain."""
|
||||
base_info = {"page_content": docs[0].page_content}
|
||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
@ -78,6 +78,11 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
@property
|
||||
def output_parser(self) -> Optional[BaseOutputParser]:
|
||||
"""Output parser to use for results of combine_docs."""
|
||||
return self.llm_chain.prompt.output_parser
|
||||
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user