diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py
index 0b15f4e41c9..147610f9764 100644
--- a/libs/langchain/langchain/chains/conversational_retrieval/base.py
+++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py
@@ -4,16 +4,19 @@ from __future__ import annotations
 import inspect
 import warnings
 from abc import abstractmethod
+from operator import itemgetter
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
+from langchain_core.beta.runnables.context import Context
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain_core.retrievers import BaseRetriever
-from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.runnables.base import RunnableMap
 from langchain_core.vectorstores import VectorStore
 
 from langchain.callbacks.manager import (
@@ -21,7 +24,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import RunnableChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
@@ -63,7 +66,7 @@ class InputType(BaseModel):
     """The chat history to use for retrieval."""
 
 
-class BaseConversationalRetrievalChain(Chain):
+class BaseConversationalRetrievalChain(RunnableChain):
     """Chain for chatting with an index."""
 
     combine_docs_chain: BaseCombineDocumentsChain
@@ -289,6 +292,51 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     """If set, enforces that the documents returned are less than this limit.
     This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
 
+    def as_runnable(self) -> Runnable:
+        context = Context.create_scope("conversational_retrieval")
+        get_chat_history = self.get_chat_history or _get_chat_history
+
+        def get_new_question(inputs: Dict[str, Any]):
+            return (
+                self.question_generator
+                if inputs["chat_history"]
+                else inputs["question"]
+            )
+
+        def get_answer(inputs: Dict[str, Any]):
+            return (
+                self.response_if_no_docs_found
+                if self.response_if_no_docs_found is not None
+                and len(inputs["input_documents"]) == 0
+                else self.combine_docs_chain
+            )
+
+        return (
+            RunnableMap(
+                question=itemgetter("question") | context.setter("question"),
+                chat_history=itemgetter("chat_history")
+                | get_chat_history
+                | context.setter("chat_history"),
+            )
+            | get_new_question
+            | context.setter("new_question")
+            | self.retriever
+            | self._reduce_tokens_below_limit
+            | context.setter("input_documents")
+            | {
+                "input_documents": context.getter("input_documents"),
+                "chat_history": context.getter("chat_history"),
+                "question": context.getter(
+                    "new_question" if self.rephrase_question else "question"
+                ),
+            }
+            | {
+                self.output_key: get_answer,
+                "source_documents": context.getter("input_documents"),
+                "generated_question": context.getter("new_question"),
+            }
+        )
+
     def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
         num_docs = len(docs)
 
@@ -400,6 +448,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
     top_k_docs_for_context: int = 4
     search_kwargs: dict = Field(default_factory=dict)
 
+    def as_runnable(self) -> Runnable:
+        return self
+
     @property
     def _chain_type(self) -> str:
         return "chat-vector-db"
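A minimal usage sketch of the new `as_runnable()` entry point, for reviewers. It assumes `ChatOpenAI` from `langchain.chat_models` and a pre-built `my_retriever`; both are placeholders, not part of this patch. Per the final map in `as_runnable()`, the output dict carries the chain's `output_key` ("answer" by default) alongside "source_documents" and "generated_question".

    from langchain.chains import ConversationalRetrievalChain
    from langchain.chat_models import ChatOpenAI  # assumed LLM; any chat model should work

    # Construct the chain exactly as before; from_llm() is untouched by this patch.
    chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(),
        retriever=my_retriever,  # hypothetical retriever, not defined in this diff
    )

    # New in this patch: expose the chain as an LCEL pipeline.
    runnable = chain.as_runnable()
    result = runnable.invoke(
        {"question": "What changed in this release?", "chat_history": []}
    )
    # result contains chain.output_key plus "source_documents" and "generated_question"

For `ChatVectorDBChain`, the `as_runnable()` added at the bottom of the diff returns the chain object itself, so that older class keeps its existing call path rather than the LCEL composition above.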