Added support for chat_history (#7555)

#7469

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Leonid Kuligin 2023-07-11 21:27:26 +02:00 committed by GitHub
parent 406a9dc11f
commit 6674b33cf5
3 changed files with 39 additions and 55 deletions

langchain/chat_models/vertexai.py

@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from pydantic import root_validator
@@ -22,55 +22,46 @@ from langchain.schema.messages import (
 )
 from langchain.utilities.vertexai import raise_vertex_import_error
 
-
-@dataclass
-class _MessagePair:
-    """InputOutputTextPair represents a pair of input and output texts."""
-
-    question: HumanMessage
-    answer: AIMessage
+if TYPE_CHECKING:
+    from vertexai.language_models import ChatMessage
 
 
 @dataclass
 class _ChatHistory:
-    """InputOutputTextPair represents a pair of input and output texts."""
+    """Represents a context and a history of messages."""
 
-    history: List[_MessagePair] = field(default_factory=list)
-    system_message: Optional[SystemMessage] = None
+    history: List["ChatMessage"] = field(default_factory=list)
+    context: Optional[str] = None
 
 
 def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
     """Parse a sequence of messages into history.
 
-    A sequence should be either (SystemMessage, HumanMessage, AIMessage,
-    HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
-    AIMessage, ...). CodeChat does not support SystemMessage.
-
     Args:
         history: The list of messages to re-create the history of the chat.
     Returns:
         A parsed chat history.
     Raises:
-        ValueError: If a sequence of message is odd, or a human message is not followed
-            by a message from AI (e.g., Human, Human, AI or AI, AI, Human).
+        ValueError: If a sequence of message has a SystemMessage not at the
+        first place.
     """
-    if not history:
-        return _ChatHistory()
-    first_message = history[0]
-    system_message = first_message if isinstance(first_message, SystemMessage) else None
-    chat_history = _ChatHistory(system_message=system_message)
-    messages_left = history[1:] if system_message else history
-    if len(messages_left) % 2 != 0:
-        raise ValueError(
-            f"Amount of messages in history should be even, got {len(messages_left)}!"
-        )
-    for question, answer in zip(messages_left[::2], messages_left[1::2]):
-        if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage):
-            raise ValueError(
-                "A human message should follow a bot one, "
-                f"got {question.type}, {answer.type}."
-            )
-        chat_history.history.append(_MessagePair(question=question, answer=answer))
+    from vertexai.language_models import ChatMessage
+
+    vertex_messages, context = [], None
+    for i, message in enumerate(history):
+        if i == 0 and isinstance(message, SystemMessage):
+            context = message.content
+        elif isinstance(message, AIMessage):
+            vertex_message = ChatMessage(content=message.content, author="bot")
+            vertex_messages.append(vertex_message)
+        elif isinstance(message, HumanMessage):
+            vertex_message = ChatMessage(content=message.content, author="user")
+            vertex_messages.append(vertex_message)
+        else:
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+    chat_history = _ChatHistory(context=context, history=vertex_messages)
     return chat_history
@@ -126,16 +117,15 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             raise ValueError(
                 f"Last message in the list should be from human, got {question.type}."
             )
         history = _parse_chat_history(messages[:-1])
-        context = history.system_message.content if history.system_message else None
+        context = history.context if history.context else None
         params = {**self._default_params, **kwargs}
         if not self.is_codey_model:
-            chat = self.client.start_chat(context=context, **params)
+            chat = self.client.start_chat(
+                context=context, message_history=history.history, **params
+            )
         else:
             chat = self.client.start_chat(**params)
-        for pair in history.history:
-            chat._history.append((pair.question.content, pair.answer.content))
         response = chat.send_message(question.content, **params)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
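To illustrate the new parsing behavior, here is a minimal sketch (not part of the diff) of what _parse_chat_history produces once this change is applied; it assumes the vertexai SDK is installed and the conversation below is purely hypothetical:

from langchain.chat_models.vertexai import _parse_chat_history
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage

# Hypothetical conversation: a leading SystemMessage plus one user/bot turn.
messages = [
    SystemMessage(content="You are my personal assistant."),
    HumanMessage(content="Recommend a movie for tonight."),
    AIMessage(content="How about Lord of the Rings?"),
]
history = _parse_chat_history(messages)
# The leading SystemMessage becomes the plain-text context...
assert history.context == "You are my personal assistant."
# ...and the remaining turns become vertexai ChatMessage objects.
assert [(m.author, m.content) for m in history.history] == [
    ("user", "Recommend a movie for tonight."),
    ("bot", "How about Lord of the Rings?"),
]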

langchain/utilities/vertexai.py

@@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
     Raises:
         ImportError: an ImportError that mentions a required version of the SDK.
     """
-    sdk = "'google-cloud-aiplatform>=1.26.0'"
+    sdk = "'google-cloud-aiplatform>=1.26.1'"
    raise ImportError(
        "Could not import VertexAI. Please, install it with " f"pip install {sdk}"
    )
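For reference, this helper is typically used to guard the optional SDK import; a rough sketch of that pattern (illustrative, not part of this diff):

from langchain.utilities.vertexai import raise_vertex_import_error

try:
    import vertexai  # noqa: F401  # provided by google-cloud-aiplatform>=1.26.1
except ImportError:
    raise_vertex_import_error()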

tests/integration_tests/chat_models/test_vertexai.py

@@ -12,7 +12,7 @@ from unittest.mock import Mock, patch
 import pytest
 
 from langchain.chat_models import ChatVertexAI
-from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history
+from langchain.chat_models.vertexai import _parse_chat_history
 from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
@@ -43,6 +43,8 @@ def test_vertexai_single_call_with_context() -> None:
 def test_parse_chat_history_correct() -> None:
+    from vertexai.language_models import ChatMessage
+
     text_context = (
         "My name is Ned. You are my personal assistant. My "
         "favorite movies are Lord of the Rings and Hobbit."
@@ -58,22 +60,14 @@ def test_parse_chat_history_correct() -> None:
     )
     answer = AIMessage(content=text_answer)
     history = _parse_chat_history([context, question, answer, question, answer])
-    assert history.system_message == context
-    assert len(history.history) == 2
-    assert history.history[0] == _MessagePair(question=question, answer=answer)
-
-
-def test_parse_chat_history_wrong_sequence() -> None:
-    text_question = (
-        "Hello, could you recommend a good movie for me to watch this evening, please?"
-    )
-    question = HumanMessage(content=text_question)
-    with pytest.raises(ValueError) as exc_info:
-        _ = _parse_chat_history([question, question])
-    assert (
-        str(exc_info.value)
-        == "A human message should follow a bot one, got human, human."
-    )
+    assert history.context == context.content
+    assert len(history.history) == 4
+    assert history.history == [
+        ChatMessage(content=text_question, author="user"),
+        ChatMessage(content=text_answer, author="bot"),
+        ChatMessage(content=text_question, author="user"),
+        ChatMessage(content=text_answer, author="bot"),
+    ]
 
 
 def test_vertexai_single_call_failes_no_message() -> None:
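Taken together, the change means a ChatVertexAI call can now include earlier turns directly in the message list. A rough end-to-end usage sketch (the prompt text is illustrative, the default chat model is used, and Vertex AI credentials are assumed to be configured):

from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage

chat = ChatVertexAI()
messages = [
    SystemMessage(content="You are my personal assistant."),
    HumanMessage(content="Recommend a movie for tonight."),
    AIMessage(content="How about Lord of the Rings?"),
    HumanMessage(content="Something shorter, please."),
]
# Everything before the final HumanMessage is replayed as message_history;
# the leading SystemMessage is passed to start_chat as the context.
response = chat(messages)
print(response.content)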