# Patch summary (commentary only; this text precedes the first "diff --git"
# header and belongs to no hunk).
#
# * langchain/chains/summarize/__init__.py: the loader implementation
#   (LoadingCallable, _load_stuff_chain, _load_map_reduce_chain,
#   _load_refine_chain, load_summarize_chain) is removed and replaced by a
#   re-export of LoadingCallable and load_summarize_chain from
#   langchain.chains.summarize.chain, both listed in __all__.
# * langchain/chains/summarize/chain.py: new module that receives the code
#   removed above. Its blob id (dcb7fb91e62) equals the pre-image blob id of
#   __init__.py, i.e. the content is moved unchanged.
#
# NOTE(review): the private _load_* helpers and the module docstring are no
# longer present in __init__.py after this patch; anything importing the
# helpers from the package itself would have to import them from the chain
# module instead -- confirm there are no such callers.
# NOTE(review): chain.py imports the prompt modules with
# "from langchain.chains.summarize import ..." while __init__.py imports
# chain.py; presumably this resolves as a submodule import -- TODO confirm.
diff --git a/libs/langchain/langchain/chains/summarize/__init__.py b/libs/langchain/langchain/chains/summarize/__init__.py
index dcb7fb91e62..f2e0d352fd5 100644
--- a/libs/langchain/langchain/chains/summarize/__init__.py
+++ b/libs/langchain/langchain/chains/summarize/__init__.py
@@ -1,166 +1,6 @@
-"""Load summarizing chains."""
-from typing import Any, Mapping, Optional, Protocol
+from langchain.chains.summarize.chain import (
+    LoadingCallable,
+    load_summarize_chain,
+)
 
-from langchain_core.callbacks import Callbacks
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import BasePromptTemplate
-
-from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
-from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
-from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
-from langchain.chains.combine_documents.refine import RefineDocumentsChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chains.llm import LLMChain
-from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
-
-
-class LoadingCallable(Protocol):
-    """Interface for loading the combine documents chain."""
-
-    def __call__(
-        self, llm: BaseLanguageModel, **kwargs: Any
-    ) -> BaseCombineDocumentsChain:
-        """Callable to load the combine documents chain."""
-
-
-def _load_stuff_chain(
-    llm: BaseLanguageModel,
-    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
-    document_variable_name: str = "text",
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
-    # TODO: document prompt
-    return StuffDocumentsChain(
-        llm_chain=llm_chain,
-        document_variable_name=document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
-        **kwargs,
-    )
-
-
-def _load_map_reduce_chain(
-    llm: BaseLanguageModel,
-    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
-    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
-    combine_document_variable_name: str = "text",
-    map_reduce_document_variable_name: str = "text",
-    collapse_prompt: Optional[BasePromptTemplate] = None,
-    reduce_llm: Optional[BaseLanguageModel] = None,
-    collapse_llm: Optional[BaseLanguageModel] = None,
-    verbose: Optional[bool] = None,
-    token_max: int = 3000,
-    callbacks: Callbacks = None,
-    *,
-    collapse_max_retries: Optional[int] = None,
-    **kwargs: Any,
-) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(
-        llm=llm,
-        prompt=map_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
-        callbacks=callbacks,  # type: ignore[arg-type]
-    )
-    _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(
-        llm=_reduce_llm,
-        prompt=combine_prompt,
-        verbose=verbose,  # type: ignore[arg-type]
-        callbacks=callbacks,  # type: ignore[arg-type]
-    )
-    # TODO: document prompt
-    combine_documents_chain = StuffDocumentsChain(
-        llm_chain=reduce_chain,
-        document_variable_name=combine_document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
-        callbacks=callbacks,
-    )
-    if collapse_prompt is None:
-        collapse_chain = None
-        if collapse_llm is not None:
-            raise ValueError(
-                "collapse_llm provided, but collapse_prompt was not: please "
-                "provide one or stop providing collapse_llm."
-            )
-    else:
-        _collapse_llm = collapse_llm or llm
-        collapse_chain = StuffDocumentsChain(
-            llm_chain=LLMChain(
-                llm=_collapse_llm,
-                prompt=collapse_prompt,
-                verbose=verbose,  # type: ignore[arg-type]
-                callbacks=callbacks,
-            ),
-            document_variable_name=combine_document_variable_name,
-        )
-    reduce_documents_chain = ReduceDocumentsChain(
-        combine_documents_chain=combine_documents_chain,
-        collapse_documents_chain=collapse_chain,
-        token_max=token_max,
-        verbose=verbose,  # type: ignore[arg-type]
-        callbacks=callbacks,
-        collapse_max_retries=collapse_max_retries,
-    )
-    return MapReduceDocumentsChain(
-        llm_chain=map_chain,
-        reduce_documents_chain=reduce_documents_chain,
-        document_variable_name=map_reduce_document_variable_name,
-        verbose=verbose,  # type: ignore[arg-type]
-        callbacks=callbacks,
-        **kwargs,
-    )
-
-
-def _load_refine_chain(
-    llm: BaseLanguageModel,
-    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
-    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
-    document_variable_name: str = "text",
-    initial_response_name: str = "existing_answer",
-    refine_llm: Optional[BaseLanguageModel] = None,
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
-    _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
-    return RefineDocumentsChain(
-        initial_llm_chain=initial_chain,
-        refine_llm_chain=refine_chain,
-        document_variable_name=document_variable_name,
-        initial_response_name=initial_response_name,
-        verbose=verbose,  # type: ignore[arg-type]
-        **kwargs,
-    )
-
-
-def load_summarize_chain(
-    llm: BaseLanguageModel,
-    chain_type: str = "stuff",
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> BaseCombineDocumentsChain:
-    """Load summarizing chain.
-
-    Args:
-        llm: Language Model to use in the chain.
-        chain_type: Type of document combining chain to use. Should be one of "stuff",
-            "map_reduce", and "refine".
-        verbose: Whether chains should be run in verbose mode or not.  Note that this
-            applies to all chains that make up the final chain.
-
-    Returns:
-        A chain to use for summarizing.
-    """
-    loader_mapping: Mapping[str, LoadingCallable] = {
-        "stuff": _load_stuff_chain,
-        "map_reduce": _load_map_reduce_chain,
-        "refine": _load_refine_chain,
-    }
-    if chain_type not in loader_mapping:
-        raise ValueError(
-            f"Got unsupported chain type: {chain_type}. "
-            f"Should be one of {loader_mapping.keys()}"
-        )
-    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
+__all__ = ["LoadingCallable", "load_summarize_chain"]
diff --git a/libs/langchain/langchain/chains/summarize/chain.py b/libs/langchain/langchain/chains/summarize/chain.py
new file mode 100644
index 00000000000..dcb7fb91e62
--- /dev/null
+++ b/libs/langchain/langchain/chains/summarize/chain.py
@@ -0,0 +1,166 @@
+"""Load summarizing chains."""
+from typing import Any, Mapping, Optional, Protocol
+
+from langchain_core.callbacks import Callbacks
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import BasePromptTemplate
+
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
+from langchain.chains.combine_documents.refine import RefineDocumentsChain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
+
+
+class LoadingCallable(Protocol):
+    """Interface for loading the combine documents chain."""
+
+    def __call__(
+        self, llm: BaseLanguageModel, **kwargs: Any
+    ) -> BaseCombineDocumentsChain:
+        """Callable to load the combine documents chain."""
+
+
+def _load_stuff_chain(
+    llm: BaseLanguageModel,
+    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
+    document_variable_name: str = "text",
+    verbose: Optional[bool] = None,
+    **kwargs: Any,
+) -> StuffDocumentsChain:
+    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
+    # TODO: document prompt
+    return StuffDocumentsChain(
+        llm_chain=llm_chain,
+        document_variable_name=document_variable_name,
+        verbose=verbose,  # type: ignore[arg-type]
+        **kwargs,
+    )
+
+
+def _load_map_reduce_chain(
+    llm: BaseLanguageModel,
+    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
+    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
+    combine_document_variable_name: str = "text",
+    map_reduce_document_variable_name: str = "text",
+    collapse_prompt: Optional[BasePromptTemplate] = None,
+    reduce_llm: Optional[BaseLanguageModel] = None,
+    collapse_llm: Optional[BaseLanguageModel] = None,
+    verbose: Optional[bool] = None,
+    token_max: int = 3000,
+    callbacks: Callbacks = None,
+    *,
+    collapse_max_retries: Optional[int] = None,
+    **kwargs: Any,
+) -> MapReduceDocumentsChain:
+    map_chain = LLMChain(
+        llm=llm,
+        prompt=map_prompt,
+        verbose=verbose,  # type: ignore[arg-type]
+        callbacks=callbacks,  # type: ignore[arg-type]
+    )
+    _reduce_llm = reduce_llm or llm
+    reduce_chain = LLMChain(
+        llm=_reduce_llm,
+        prompt=combine_prompt,
+        verbose=verbose,  # type: ignore[arg-type]
+        callbacks=callbacks,  # type: ignore[arg-type]
+    )
+    # TODO: document prompt
+    combine_documents_chain = StuffDocumentsChain(
+        llm_chain=reduce_chain,
+        document_variable_name=combine_document_variable_name,
+        verbose=verbose,  # type: ignore[arg-type]
+        callbacks=callbacks,
+    )
+    if collapse_prompt is None:
+        collapse_chain = None
+        if collapse_llm is not None:
+            raise ValueError(
+                "collapse_llm provided, but collapse_prompt was not: please "
+                "provide one or stop providing collapse_llm."
+            )
+    else:
+        _collapse_llm = collapse_llm or llm
+        collapse_chain = StuffDocumentsChain(
+            llm_chain=LLMChain(
+                llm=_collapse_llm,
+                prompt=collapse_prompt,
+                verbose=verbose,  # type: ignore[arg-type]
+                callbacks=callbacks,
+            ),
+            document_variable_name=combine_document_variable_name,
+        )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,  # type: ignore[arg-type]
+        callbacks=callbacks,
+        collapse_max_retries=collapse_max_retries,
+    )
+    return MapReduceDocumentsChain(
+        llm_chain=map_chain,
+        reduce_documents_chain=reduce_documents_chain,
+        document_variable_name=map_reduce_document_variable_name,
+        verbose=verbose,  # type: ignore[arg-type]
+        callbacks=callbacks,
+        **kwargs,
+    )
+
+
+def _load_refine_chain(
+    llm: BaseLanguageModel,
+    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
+    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
+    document_variable_name: str = "text",
+    initial_response_name: str = "existing_answer",
+    refine_llm: Optional[BaseLanguageModel] = None,
+    verbose: Optional[bool] = None,
+    **kwargs: Any,
+) -> RefineDocumentsChain:
+    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
+    _refine_llm = refine_llm or llm
+    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
+    return RefineDocumentsChain(
+        initial_llm_chain=initial_chain,
+        refine_llm_chain=refine_chain,
+        document_variable_name=document_variable_name,
+        initial_response_name=initial_response_name,
+        verbose=verbose,  # type: ignore[arg-type]
+        **kwargs,
+    )
+
+
+def load_summarize_chain(
+    llm: BaseLanguageModel,
+    chain_type: str = "stuff",
+    verbose: Optional[bool] = None,
+    **kwargs: Any,
+) -> BaseCombineDocumentsChain:
+    """Load summarizing chain.
+
+    Args:
+        llm: Language Model to use in the chain.
+        chain_type: Type of document combining chain to use. Should be one of "stuff",
+            "map_reduce", and "refine".
+        verbose: Whether chains should be run in verbose mode or not.  Note that this
+            applies to all chains that make up the final chain.
+
+    Returns:
+        A chain to use for summarizing.
+    """
+    loader_mapping: Mapping[str, LoadingCallable] = {
+        "stuff": _load_stuff_chain,
+        "map_reduce": _load_map_reduce_chain,
+        "refine": _load_refine_chain,
+    }
+    if chain_type not in loader_mapping:
+        raise ValueError(
+            f"Got unsupported chain type: {chain_type}. "
+            f"Should be one of {loader_mapping.keys()}"
+        )
+    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)