mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
propagate RetrievalQA chain callbacks through its own LLMChain and StuffDocumentsChain (#7853)
This is another case, similar to #5572 and #7565 where the callbacks are getting dropped during construction of the chains. tagging @hwchase17 and @agola11 for callbacks propagation <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
47eea32f6a
commit
404d103c41
@@ -11,6 +11,7 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
+    Callbacks,
 )
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -65,11 +66,12 @@ class BaseRetrievalQA(Chain):
         cls,
         llm: BaseLanguageModel,
         prompt: Optional[PromptTemplate] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseRetrievalQA:
         """Initialize from LLM."""
         _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
-        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)
         document_prompt = PromptTemplate(
             input_variables=["page_content"], template="Context:\n{page_content}"
         )
@@ -77,9 +79,14 @@ class BaseRetrievalQA(Chain):
             llm_chain=llm_chain,
             document_variable_name="context",
             document_prompt=document_prompt,
+            callbacks=callbacks,
         )
-        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
+        return cls(
+            combine_documents_chain=combine_documents_chain,
+            callbacks=callbacks,
+            **kwargs,
+        )
 
     @classmethod
     def from_chain_type(
|
Loading…
Reference in New Issue
Block a user