add input type for convo retrieval chain (#11679)

This commit is contained in:
Harrison Chase 2023-10-11 14:13:48 -07:00 committed by GitHub
parent d5e762d328
commit 9f39c23a13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,7 @@ import inspect
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -18,7 +18,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering import load_qa_chain
from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.schema import BasePromptTemplate, BaseRetriever, Document from langchain.schema import BasePromptTemplate, BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage from langchain.schema.messages import BaseMessage
@ -50,6 +50,11 @@ def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
return buffer return buffer
class InputType(BaseModel):
question: str
chat_history: List[CHAT_TURN_TYPE]
class BaseConversationalRetrievalChain(Chain): class BaseConversationalRetrievalChain(Chain):
"""Chain for chatting with an index.""" """Chain for chatting with an index."""
@ -87,6 +92,10 @@ class BaseConversationalRetrievalChain(Chain):
"""Input keys.""" """Input keys."""
return ["question", "chat_history"] return ["question", "chat_history"]
@property
def input_schema(self) -> Type[BaseModel]:
return InputType
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
"""Return the output keys. """Return the output keys.