Patch Chat History Formatting (#3236)
While we work on solidifying the memory interfaces, handle common chat history formats. This may break linting for anyone who has been passing in `get_chat_history`. Partially addresses #3077; an alternative to #3078 that updates the typing.
Commit: daee0b2b97 (parent: 8f22949dc4)
```diff
@@ -15,16 +15,32 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseRetriever, Document
+from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore
 
+# Depending on the memory type and configuration, the chat history format may differ.
+# This needs to be consolidated.
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
+
+
+_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
+
 
-def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
+def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
     buffer = ""
-    for human_s, ai_s in chat_history:
-        human = "Human: " + human_s
-        ai = "Assistant: " + ai_s
-        buffer += "\n" + "\n".join([human, ai])
+    for dialogue_turn in chat_history:
+        if isinstance(dialogue_turn, BaseMessage):
+            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
+            buffer += f"\n{role_prefix}{dialogue_turn.content}"
+        elif isinstance(dialogue_turn, tuple):
+            human = "Human: " + dialogue_turn[0]
+            ai = "Assistant: " + dialogue_turn[1]
+            buffer += "\n" + "\n".join([human, ai])
+        else:
+            raise ValueError(
+                f"Unsupported chat history format: {type(dialogue_turn)}."
+                f" Full chat history: {chat_history} "
+            )
     return buffer
 
 
```
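For illustration, a minimal sketch of what the patched helper now accepts: both the legacy `(human, ai)` tuples and `BaseMessage` objects. Note that `_get_chat_history` is private, so importing it here is for demonstration only.

```python
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.schema import AIMessage, HumanMessage

# Legacy format: a list of (human, ai) string tuples.
tuple_history = [("What is LangChain?", "A framework for building LLM apps.")]

# Message format: BaseMessage objects, as produced by chat memory classes.
message_history = [
    HumanMessage(content="What is LangChain?"),
    AIMessage(content="A framework for building LLM apps."),
]

# Both now render to the same "Human: ... / Assistant: ..." transcript.
assert _get_chat_history(tuple_history) == _get_chat_history(message_history)
```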
```diff
@@ -35,7 +51,7 @@ class BaseConversationalRetrievalChain(Chain):
     question_generator: LLMChain
     output_key: str = "answer"
     return_source_documents: bool = False
-    get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
+    get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None
     """Return the source documents."""
 
     class Config:
```
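Because the accepted element type widened to `CHAT_TURN_TYPE`, a user-supplied `get_chat_history` callable should now handle both formats as well. A hedged sketch of such an override (the function name is illustrative, not part of the library):

```python
from typing import List, Tuple, Union

from langchain.schema import BaseMessage


def format_chat_history(
    chat_history: List[Union[Tuple[str, str], BaseMessage]]
) -> str:
    """Render either supported turn format into a plain transcript."""
    lines: List[str] = []
    for turn in chat_history:
        if isinstance(turn, BaseMessage):
            # Message objects carry their own role in `type`.
            lines.append(f"{turn.type}: {turn.content}")
        else:
            # Legacy tuples are (human, ai) pairs.
            human, ai = turn
            lines.append(f"human: {human}")
            lines.append(f"ai: {ai}")
    return "\n".join(lines)


# Passed to the chain in place of the default, e.g.:
# chain = ConversationalRetrievalChain.from_llm(
#     llm, retriever=retriever, get_chat_history=format_chat_history
# )
```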