diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index b7fb299e869..900d8e7c82e 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -172,15 +172,16 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): llm: BaseLanguageModel, retriever: BaseRetriever, condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT, - qa_prompt: Optional[BasePromptTemplate] = None, chain_type: str = "stuff", + combine_docs_chain_kwargs: Optional[Dict] = None, **kwargs: Any, ) -> BaseConversationalRetrievalChain: """Load chain from LLM.""" + combine_docs_chain_kwargs = combine_docs_chain_kwargs or {} doc_chain = load_qa_chain( llm, chain_type=chain_type, - prompt=qa_prompt, + **combine_docs_chain_kwargs, ) condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt) return cls( @@ -226,15 +227,16 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): llm: BaseLanguageModel, vectorstore: VectorStore, condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT, - qa_prompt: Optional[BasePromptTemplate] = None, chain_type: str = "stuff", + combine_docs_chain_kwargs: Optional[Dict] = None, **kwargs: Any, ) -> BaseConversationalRetrievalChain: """Load chain from LLM.""" + combine_docs_chain_kwargs = combine_docs_chain_kwargs or {} doc_chain = load_qa_chain( llm, chain_type=chain_type, - prompt=qa_prompt, + **combine_docs_chain_kwargs, ) condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt) return cls(