mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
propagate callbacks through load_summarize_chain (#7565)
This lets you pass callbacks when you create the summarize chain: ``` summarize = load_summarize_chain(llm, chain_type="map_reduce", callbacks=[my_callbacks]) summary = summarize(documents) ``` See #5572 for a similar surgical fix. tagging @hwchase17 for callbacks work <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
404d103c41
commit
5d765408ce
@ -1,6 +1,7 @@
|
||||
"""Load summarizing chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
||||
@ -49,16 +50,22 @@ def _load_map_reduce_chain(
|
||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
|
||||
map_chain = LLMChain(
|
||||
llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks
|
||||
)
|
||||
_reduce_llm = reduce_llm or llm
|
||||
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
|
||||
reduce_chain = LLMChain(
|
||||
llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks
|
||||
)
|
||||
# TODO: document prompt
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain,
|
||||
document_variable_name=combine_document_variable_name,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
@ -74,6 +81,7 @@ def _load_map_reduce_chain(
|
||||
llm=_collapse_llm,
|
||||
prompt=collapse_prompt,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
)
|
||||
@ -82,12 +90,14 @@ def _load_map_reduce_chain(
|
||||
collapse_documents_chain=collapse_chain,
|
||||
token_max=token_max,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=map_chain,
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user