mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
add input type for convo retrieval chain (#11679)
This commit is contained in:
parent
d5e762d328
commit
9f39c23a13
@ -5,7 +5,7 @@ import inspect
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@ -18,7 +18,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
@ -50,6 +50,11 @@ def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
|
||||
return buffer
|
||||
|
||||
|
||||
class InputType(BaseModel):
|
||||
question: str
|
||||
chat_history: List[CHAT_TURN_TYPE]
|
||||
|
||||
|
||||
class BaseConversationalRetrievalChain(Chain):
|
||||
"""Chain for chatting with an index."""
|
||||
|
||||
@ -87,6 +92,10 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
"""Input keys."""
|
||||
return ["question", "chat_history"]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return InputType
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys.
|
||||
|
Loading…
Reference in New Issue
Block a user