mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 22:04:37 +00:00
propagate callbacks through load_summarize_chain (#7565)
This lets you pass callbacks when you create the summarize chain: ``` summarize = load_summarize_chain(llm, chain_type="map_reduce", callbacks=[my_callbacks]) summary = summarize(documents) ``` See #5572 for a similar surgical fix. tagging @hwchase17 for callbacks work <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
404d103c41
commit
5d765408ce
@ -1,6 +1,7 @@
|
|||||||
"""Load summarizing chains."""
|
"""Load summarizing chains."""
|
||||||
from typing import Any, Mapping, Optional, Protocol
|
from typing import Any, Mapping, Optional, Protocol
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||||
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
||||||
@ -49,16 +50,22 @@ def _load_map_reduce_chain(
|
|||||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
token_max: int = 3000,
|
token_max: int = 3000,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MapReduceDocumentsChain:
|
) -> MapReduceDocumentsChain:
|
||||||
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
|
map_chain = LLMChain(
|
||||||
|
llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks
|
||||||
|
)
|
||||||
_reduce_llm = reduce_llm or llm
|
_reduce_llm = reduce_llm or llm
|
||||||
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
|
reduce_chain = LLMChain(
|
||||||
|
llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks
|
||||||
|
)
|
||||||
# TODO: document prompt
|
# TODO: document prompt
|
||||||
combine_documents_chain = StuffDocumentsChain(
|
combine_documents_chain = StuffDocumentsChain(
|
||||||
llm_chain=reduce_chain,
|
llm_chain=reduce_chain,
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
if collapse_prompt is None:
|
if collapse_prompt is None:
|
||||||
collapse_chain = None
|
collapse_chain = None
|
||||||
@ -74,6 +81,7 @@ def _load_map_reduce_chain(
|
|||||||
llm=_collapse_llm,
|
llm=_collapse_llm,
|
||||||
prompt=collapse_prompt,
|
prompt=collapse_prompt,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
callbacks=callbacks,
|
||||||
),
|
),
|
||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
)
|
)
|
||||||
@ -82,12 +90,14 @@ def _load_map_reduce_chain(
|
|||||||
collapse_documents_chain=collapse_chain,
|
collapse_documents_chain=collapse_chain,
|
||||||
token_max=token_max,
|
token_max=token_max,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
return MapReduceDocumentsChain(
|
return MapReduceDocumentsChain(
|
||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
reduce_documents_chain=reduce_documents_chain,
|
reduce_documents_chain=reduce_documents_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
callbacks=callbacks,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user