mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
propagate callbacks to ConversationalRetrievalChain (#5572)
# Allow callbacks to monitor ConversationalRetrievalChain <!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> I ran into an issue where load_qa_chain was not passing the callbacks down to the child LLM chains, and so made sure that callbacks are propagated. There are probably more improvements to do here but this seemed like a good place to stop. Note that I saw a lot of references to callbacks_manager, which seems to be deprecated. I left that code alone for now. ## Before submitting <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that does not rely on network access. 2. an example notebook showing its use See contribution guidelines for more information on how to write tests, lint etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @agola11 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 -->
This commit is contained in:
parent
3294774148
commit
ec0dd6e34a
@ -12,6 +12,7 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
@ -204,6 +205,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
verbose: bool = False,
|
||||
condense_question_llm: Optional[BaseLanguageModel] = None,
|
||||
combine_docs_chain_kwargs: Optional[Dict] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseConversationalRetrievalChain:
|
||||
"""Load chain from LLM."""
|
||||
@ -212,17 +214,22 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
llm,
|
||||
chain_type=chain_type,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
**combine_docs_chain_kwargs,
|
||||
)
|
||||
|
||||
_llm = condense_question_llm or llm
|
||||
condense_question_chain = LLMChain(
|
||||
llm=_llm, prompt=condense_question_prompt, verbose=verbose
|
||||
llm=_llm,
|
||||
prompt=condense_question_prompt,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=condense_question_chain,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -264,6 +271,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
||||
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
|
||||
chain_type: str = "stuff",
|
||||
combine_docs_chain_kwargs: Optional[Dict] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseConversationalRetrievalChain:
|
||||
"""Load chain from LLM."""
|
||||
@ -271,12 +279,16 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
||||
doc_chain = load_qa_chain(
|
||||
llm,
|
||||
chain_type=chain_type,
|
||||
callbacks=callbacks,
|
||||
**combine_docs_chain_kwargs,
|
||||
)
|
||||
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
|
||||
condense_question_chain = LLMChain(
|
||||
llm=llm, prompt=condense_question_prompt, callbacks=callbacks
|
||||
)
|
||||
return cls(
|
||||
vectorstore=vectorstore,
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=condense_question_chain,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
@ -35,10 +36,15 @@ def _load_map_rerank_chain(
|
||||
rank_key: str = "score",
|
||||
answer_key: str = "answer",
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> MapRerankDocumentsChain:
|
||||
llm_chain = LLMChain(
|
||||
llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return MapRerankDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
@ -57,11 +63,16 @@ def _load_stuff_chain(
|
||||
document_variable_name: str = "context",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# TODO: document prompt
|
||||
return StuffDocumentsChain(
|
||||
@ -84,6 +95,7 @@ def _load_map_reduce_chain(
|
||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
_question_prompt = (
|
||||
@ -97,6 +109,7 @@ def _load_map_reduce_chain(
|
||||
prompt=_question_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
_reduce_llm = reduce_llm or llm
|
||||
reduce_chain = LLMChain(
|
||||
@ -104,6 +117,7 @@ def _load_map_reduce_chain(
|
||||
prompt=_combine_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# TODO: document prompt
|
||||
combine_document_chain = StuffDocumentsChain(
|
||||
@ -111,6 +125,7 @@ def _load_map_reduce_chain(
|
||||
document_variable_name=combine_document_variable_name,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
@ -127,6 +142,7 @@ def _load_map_reduce_chain(
|
||||
prompt=collapse_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
verbose=verbose,
|
||||
@ -139,6 +155,7 @@ def _load_map_reduce_chain(
|
||||
collapse_document_chain=collapse_chain,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -152,6 +169,7 @@ def _load_refine_chain(
|
||||
refine_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
_question_prompt = (
|
||||
@ -165,6 +183,7 @@ def _load_refine_chain(
|
||||
prompt=_question_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
_refine_llm = refine_llm or llm
|
||||
refine_chain = LLMChain(
|
||||
@ -172,6 +191,7 @@ def _load_refine_chain(
|
||||
prompt=_refine_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
|
Loading…
Reference in New Issue
Block a user