This commit is contained in:
Eugene Yurtsev
2023-05-30 14:45:21 -04:00
parent fe89220aac
commit 69cc8d7f73

View File

@@ -1,23 +1,30 @@
"""Module contains supporting chains for research use case."""
import asyncio
from typing import List, Dict, Any, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
AsyncCallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
class ReadEntireDocChain(Chain):
"""Read entire document chain.
class DocReadingChain(Chain):
"""A chain that reads the document.
This chain implements a brute force approach to reading an entire document.
A brute force chain that reads an entire document (or the first N pages).
"""
chain: Chain
"""The chain to use to answer the question."""
text_splitter: TextSplitter
"""The text splitter to use to split the document into smaller chunks."""
max_num_docs: int = -1
"""The maximum number of documents to split the document into."""
@property
def input_keys(self) -> List[str]:
@@ -34,7 +41,7 @@ class ReadEntireDocChain(Chain):
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Process a long document."""
"""Process a long document synchronously."""
source_document = inputs["doc"]
if not isinstance(source_document, Document):
@@ -62,7 +69,7 @@ class ReadEntireDocChain(Chain):
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Process a long document."""
"""Process a long document asynchronously."""
doc = inputs["doc"]
question = inputs["question"]
sub_docs = self.text_splitter.split_documents([doc])
@@ -83,7 +90,7 @@ class ReadEntireDocChain(Chain):
class ParallelApply(Chain):
"""Apply a chain in parallel."""
"""Utility chain to apply a given chain in parallel across input documents."""
chain: Chain
max_concurrency: int