mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 06:55:09 +00:00
add input type for convo retrieval chain (#11679)
This commit is contained in:
parent
d5e762d328
commit
9f39c23a13
@ -5,7 +5,7 @@ import inspect
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
@ -18,7 +18,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|||||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.question_answering import load_qa_chain
|
from langchain.chains.question_answering import load_qa_chain
|
||||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.messages import BaseMessage
|
from langchain.schema.messages import BaseMessage
|
||||||
@ -50,6 +50,11 @@ def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
|
|||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
class InputType(BaseModel):
|
||||||
|
question: str
|
||||||
|
chat_history: List[CHAT_TURN_TYPE]
|
||||||
|
|
||||||
|
|
||||||
class BaseConversationalRetrievalChain(Chain):
|
class BaseConversationalRetrievalChain(Chain):
|
||||||
"""Chain for chatting with an index."""
|
"""Chain for chatting with an index."""
|
||||||
|
|
||||||
@ -87,6 +92,10 @@ class BaseConversationalRetrievalChain(Chain):
|
|||||||
"""Input keys."""
|
"""Input keys."""
|
||||||
return ["question", "chat_history"]
|
return ["question", "chat_history"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Return the output keys.
|
"""Return the output keys.
|
||||||
|
Loading…
Reference in New Issue
Block a user