mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 02:33:19 +00:00
Added support for chat_history (#7555)
#7469 Co-authored-by: Leonid Kuligin <kuligin@google.com>
This commit is contained in:
parent
406a9dc11f
commit
6674b33cf5
@ -1,6 +1,6 @@
|
|||||||
"""Wrapper around Google VertexAI chat-based models."""
|
"""Wrapper around Google VertexAI chat-based models."""
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
@ -22,55 +22,46 @@ from langchain.schema.messages import (
|
|||||||
)
|
)
|
||||||
from langchain.utilities.vertexai import raise_vertex_import_error
|
from langchain.utilities.vertexai import raise_vertex_import_error
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
@dataclass
|
from vertexai.language_models import ChatMessage
|
||||||
class _MessagePair:
|
|
||||||
"""InputOutputTextPair represents a pair of input and output texts."""
|
|
||||||
|
|
||||||
question: HumanMessage
|
|
||||||
answer: AIMessage
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _ChatHistory:
|
class _ChatHistory:
|
||||||
"""InputOutputTextPair represents a pair of input and output texts."""
|
"""Represents a context and a history of messages."""
|
||||||
|
|
||||||
history: List[_MessagePair] = field(default_factory=list)
|
history: List["ChatMessage"] = field(default_factory=list)
|
||||||
system_message: Optional[SystemMessage] = None
|
context: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
|
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
|
||||||
"""Parse a sequence of messages into history.
|
"""Parse a sequence of messages into history.
|
||||||
|
|
||||||
A sequence should be either (SystemMessage, HumanMessage, AIMessage,
|
|
||||||
HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
|
|
||||||
AIMessage, ...). CodeChat does not support SystemMessage.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
history: The list of messages to re-create the history of the chat.
|
history: The list of messages to re-create the history of the chat.
|
||||||
Returns:
|
Returns:
|
||||||
A parsed chat history.
|
A parsed chat history.
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If a sequence of message is odd, or a human message is not followed
|
ValueError: If a sequence of message has a SystemMessage not at the
|
||||||
by a message from AI (e.g., Human, Human, AI or AI, AI, Human).
|
first place.
|
||||||
"""
|
"""
|
||||||
if not history:
|
from vertexai.language_models import ChatMessage
|
||||||
return _ChatHistory()
|
|
||||||
first_message = history[0]
|
vertex_messages, context = [], None
|
||||||
system_message = first_message if isinstance(first_message, SystemMessage) else None
|
for i, message in enumerate(history):
|
||||||
chat_history = _ChatHistory(system_message=system_message)
|
if i == 0 and isinstance(message, SystemMessage):
|
||||||
messages_left = history[1:] if system_message else history
|
context = message.content
|
||||||
if len(messages_left) % 2 != 0:
|
elif isinstance(message, AIMessage):
|
||||||
raise ValueError(
|
vertex_message = ChatMessage(content=message.content, author="bot")
|
||||||
f"Amount of messages in history should be even, got {len(messages_left)}!"
|
vertex_messages.append(vertex_message)
|
||||||
)
|
elif isinstance(message, HumanMessage):
|
||||||
for question, answer in zip(messages_left[::2], messages_left[1::2]):
|
vertex_message = ChatMessage(content=message.content, author="user")
|
||||||
if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage):
|
vertex_messages.append(vertex_message)
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"A human message should follow a bot one, "
|
f"Unexpected message with type {type(message)} at the position {i}."
|
||||||
f"got {question.type}, {answer.type}."
|
|
||||||
)
|
)
|
||||||
chat_history.history.append(_MessagePair(question=question, answer=answer))
|
chat_history = _ChatHistory(context=context, history=vertex_messages)
|
||||||
return chat_history
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
@ -126,16 +117,15 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Last message in the list should be from human, got {question.type}."
|
f"Last message in the list should be from human, got {question.type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
history = _parse_chat_history(messages[:-1])
|
history = _parse_chat_history(messages[:-1])
|
||||||
context = history.system_message.content if history.system_message else None
|
context = history.context if history.context else None
|
||||||
params = {**self._default_params, **kwargs}
|
params = {**self._default_params, **kwargs}
|
||||||
if not self.is_codey_model:
|
if not self.is_codey_model:
|
||||||
chat = self.client.start_chat(context=context, **params)
|
chat = self.client.start_chat(
|
||||||
|
context=context, message_history=history.history, **params
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chat = self.client.start_chat(**params)
|
chat = self.client.start_chat(**params)
|
||||||
for pair in history.history:
|
|
||||||
chat._history.append((pair.question.content, pair.answer.content))
|
|
||||||
response = chat.send_message(question.content, **params)
|
response = chat.send_message(question.content, **params)
|
||||||
text = self._enforce_stop_words(response.text, stop)
|
text = self._enforce_stop_words(response.text, stop)
|
||||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||||
|
@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
|
|||||||
Raises:
|
Raises:
|
||||||
ImportError: an ImportError that mentions a required version of the SDK.
|
ImportError: an ImportError that mentions a required version of the SDK.
|
||||||
"""
|
"""
|
||||||
sdk = "'google-cloud-aiplatform>=1.26.0'"
|
sdk = "'google-cloud-aiplatform>=1.26.1'"
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import VertexAI. Please, install it with " f"pip install {sdk}"
|
"Could not import VertexAI. Please, install it with " f"pip install {sdk}"
|
||||||
)
|
)
|
||||||
|
@ -12,7 +12,7 @@ from unittest.mock import Mock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.chat_models import ChatVertexAI
|
from langchain.chat_models import ChatVertexAI
|
||||||
from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history
|
from langchain.chat_models.vertexai import _parse_chat_history
|
||||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
@ -43,6 +43,8 @@ def test_vertexai_single_call_with_context() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_parse_chat_history_correct() -> None:
|
def test_parse_chat_history_correct() -> None:
|
||||||
|
from vertexai.language_models import ChatMessage
|
||||||
|
|
||||||
text_context = (
|
text_context = (
|
||||||
"My name is Ned. You are my personal assistant. My "
|
"My name is Ned. You are my personal assistant. My "
|
||||||
"favorite movies are Lord of the Rings and Hobbit."
|
"favorite movies are Lord of the Rings and Hobbit."
|
||||||
@ -58,22 +60,14 @@ def test_parse_chat_history_correct() -> None:
|
|||||||
)
|
)
|
||||||
answer = AIMessage(content=text_answer)
|
answer = AIMessage(content=text_answer)
|
||||||
history = _parse_chat_history([context, question, answer, question, answer])
|
history = _parse_chat_history([context, question, answer, question, answer])
|
||||||
assert history.system_message == context
|
assert history.context == context.content
|
||||||
assert len(history.history) == 2
|
assert len(history.history) == 4
|
||||||
assert history.history[0] == _MessagePair(question=question, answer=answer)
|
assert history.history == [
|
||||||
|
ChatMessage(content=text_question, author="user"),
|
||||||
|
ChatMessage(content=text_answer, author="bot"),
|
||||||
def test_parse_chat_history_wrong_sequence() -> None:
|
ChatMessage(content=text_question, author="user"),
|
||||||
text_question = (
|
ChatMessage(content=text_answer, author="bot"),
|
||||||
"Hello, could you recommend a good movie for me to watch this evening, please?"
|
]
|
||||||
)
|
|
||||||
question = HumanMessage(content=text_question)
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
_ = _parse_chat_history([question, question])
|
|
||||||
assert (
|
|
||||||
str(exc_info.value)
|
|
||||||
== "A human message should follow a bot one, got human, human."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_vertexai_single_call_failes_no_message() -> None:
|
def test_vertexai_single_call_failes_no_message() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user