diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index a7558b87375..185900ca32c 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -68,11 +68,14 @@ class BaseRetrievalQA(Chain): llm: BaseLanguageModel, prompt: Optional[PromptTemplate] = None, callbacks: Callbacks = None, + llm_chain_kwargs: Optional[dict] = None, **kwargs: Any, ) -> BaseRetrievalQA: """Initialize from LLM.""" _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm) - llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks) + llm_chain = LLMChain( + llm=llm, prompt=_prompt, callbacks=callbacks, **(llm_chain_kwargs or {}) + ) document_prompt = PromptTemplate( input_variables=["page_content"], template="Context:\n{page_content}" )