Add conv retrieval

This commit is contained in:
Nuno Campos
2023-12-20 13:37:35 -08:00
parent ad1ab2b566
commit 582461945f

View File

@@ -4,16 +4,19 @@ from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from operator import itemgetter
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain_core.beta.runnables.context import Context
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.base import RunnableMap
from langchain_core.vectorstores import VectorStore
from langchain.callbacks.manager import (
@@ -21,7 +24,7 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.base import RunnableChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
@@ -63,7 +66,7 @@ class InputType(BaseModel):
"""The chat history to use for retrieval."""
class BaseConversationalRetrievalChain(Chain):
class BaseConversationalRetrievalChain(RunnableChain):
"""Chain for chatting with an index."""
combine_docs_chain: BaseCombineDocumentsChain
@@ -289,6 +292,51 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""If set, enforces that the documents returned are less than this limit.
This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
def as_runnable(self) -> Runnable:
context = Context.create_scope("conversational_retrieval")
get_chat_history = self.get_chat_history or _get_chat_history
def get_new_question(inputs: Dict[str, Any]):
return (
self.question_generator
if inputs["chat_history"]
else inputs["question"]
)
def get_answer(inputs: Dict[str, Any]):
return (
self.response_if_no_docs_found
if self.response_if_no_docs_found is not None
and len(inputs["input_documents"]) == 0
else self.combine_docs_chain
)
return (
RunnableMap(
question=itemgetter("question") | context.setter("question"),
chat_history=itemgetter("chat_history")
| get_chat_history
| context.setter("chat_history"),
)
| get_new_question
| context.setter("new_question")
| self.retriever
| self._reduce_tokens_below_limit
| context.setter("input_documents")
| {
"input_documents": context.getter("input_documents"),
"chat_history": context.getter("chat_history"),
"question": context.getter(
"new_question" if self.rephrase_question else "question"
),
}
| {
self.output_key: get_answer,
"source_documents": context.getter("input_documents"),
"generated_question": context.getter("new_question"),
}
)
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
num_docs = len(docs)
@@ -400,6 +448,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
top_k_docs_for_context: int = 4
search_kwargs: dict = Field(default_factory=dict)
def as_runnable(self) -> Runnable:
return self
@property
def _chain_type(self) -> str:
return "chat-vector-db"