mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Add conv retrieval
This commit is contained in:
@@ -4,16 +4,19 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from langchain_core.beta.runnables.context import Context
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.base import RunnableMap
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@@ -21,7 +24,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.base import RunnableChain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
@@ -63,7 +66,7 @@ class InputType(BaseModel):
|
||||
"""The chat history to use for retrieval."""
|
||||
|
||||
|
||||
class BaseConversationalRetrievalChain(Chain):
|
||||
class BaseConversationalRetrievalChain(RunnableChain):
|
||||
"""Chain for chatting with an index."""
|
||||
|
||||
combine_docs_chain: BaseCombineDocumentsChain
|
||||
@@ -289,6 +292,51 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
"""If set, enforces that the documents returned are less than this limit.
|
||||
This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
|
||||
|
||||
def as_runnable(self) -> Runnable:
|
||||
context = Context.create_scope("conversational_retrieval")
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
|
||||
def get_new_question(inputs: Dict[str, Any]):
|
||||
return (
|
||||
self.question_generator
|
||||
if inputs["chat_history"]
|
||||
else inputs["question"]
|
||||
)
|
||||
|
||||
def get_answer(inputs: Dict[str, Any]):
|
||||
return (
|
||||
self.response_if_no_docs_found
|
||||
if self.response_if_no_docs_found is not None
|
||||
and len(inputs["input_documents"]) == 0
|
||||
else self.combine_docs_chain
|
||||
)
|
||||
|
||||
return (
|
||||
RunnableMap(
|
||||
question=itemgetter("question") | context.setter("question"),
|
||||
chat_history=itemgetter("chat_history")
|
||||
| get_chat_history
|
||||
| context.setter("chat_history"),
|
||||
)
|
||||
| get_new_question
|
||||
| context.setter("new_question")
|
||||
| self.retriever
|
||||
| self._reduce_tokens_below_limit
|
||||
| context.setter("input_documents")
|
||||
| {
|
||||
"input_documents": context.getter("input_documents"),
|
||||
"chat_history": context.getter("chat_history"),
|
||||
"question": context.getter(
|
||||
"new_question" if self.rephrase_question else "question"
|
||||
),
|
||||
}
|
||||
| {
|
||||
self.output_key: get_answer,
|
||||
"source_documents": context.getter("input_documents"),
|
||||
"generated_question": context.getter("new_question"),
|
||||
}
|
||||
)
|
||||
|
||||
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
|
||||
num_docs = len(docs)
|
||||
|
||||
@@ -400,6 +448,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
||||
top_k_docs_for_context: int = 4
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
def as_runnable(self) -> Runnable:
|
||||
return self
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "chat-vector-db"
|
||||
|
||||
Reference in New Issue
Block a user