Added support for examples for VertexAI chat models. (#7636)

#5278

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Leonid Kuligin authored 2023-07-14 08:03:04 +02:00, committed by GitHub
parent 45bb414be2, commit 85e1c9b348
2 changed files with 90 additions and 10 deletions
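For orientation, a minimal usage sketch of the feature this commit adds, based on the new integration test below. The system prompt and example contents are illustrative, and running it requires Vertex AI credentials (`gcloud auth login`) plus `google-cloud-aiplatform>=1.25.0`:

from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage

# Few-shot examples are passed as alternating human/AI messages and forwarded
# to Vertex AI as input/output pairs when the chat session is started.
examples = [
    HumanMessage(content="4+4"),
    AIMessage(content="8"),
]

chat = ChatVertexAI()
response = chat(
    [
        SystemMessage(content="My name is Ned. You are my personal assistant."),
        HumanMessage(content="2+2"),
    ],
    examples=examples,
)
print(response.content)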

@@ -7,12 +7,12 @@ pip install google-cloud-aiplatform>=1.25.0
 Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 import pytest

 from langchain.chat_models import ChatVertexAI
-from langchain.chat_models.vertexai import _parse_chat_history
+from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
 from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
@@ -42,6 +42,20 @@ def test_vertexai_single_call_with_context() -> None:
     assert isinstance(response.content, str)


+def test_vertexai_single_call_with_examples() -> None:
+    model = ChatVertexAI()
+    raw_context = "My name is Ned. You are my personal assistant."
+    question = "2+2"
+    text_question, text_answer = "4+4", "8"
+    inp = HumanMessage(content=text_question)
+    output = AIMessage(content=text_answer)
+    context = SystemMessage(content=raw_context)
+    message = HumanMessage(content=question)
+    response = model([context, message], examples=[inp, output])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
 def test_parse_chat_history_correct() -> None:
     from vertexai.language_models import ChatMessage
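The next hunk reworks the unit-test mock: instead of patching `ChatSession.send_message`, it patches `ChatModel.start_chat`, since the model parameters (and, with this change, the context and examples) are handed over when the chat session is created rather than on every message. A rough sketch of the underlying Vertex AI SDK call shape the mock stands in for (sketch only; the "chat-bison" model name and the literal strings are illustrative):

from vertexai.language_models import ChatModel, InputOutputTextPair

# Start a chat session with context and few-shot examples, then send a message.
model = ChatModel.from_pretrained("chat-bison")
chat = model.start_chat(
    context="My name is Ned. You are my personal assistant.",
    examples=[InputOutputTextPair(input_text="4+4", output_text="8")],
    message_history=[],
)
response = chat.send_message("2+2")
print(response.text)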
@@ -92,17 +106,50 @@ def test_vertexai_args_passed() -> None:
     # Mock the library to ensure the args are passed correctly
     with patch(
-        "vertexai.language_models._language_models.ChatSession.send_message"
-    ) as send_message:
+        "vertexai.language_models._language_models.ChatModel.start_chat"
+    ) as start_chat:
         mock_response = Mock(text=response_text)
-        send_message.return_value = mock_response
+        mock_chat = MagicMock()
+        start_chat.return_value = mock_chat
+        mock_send_message = MagicMock(return_value=mock_response)
+        mock_chat.send_message = mock_send_message

         model = ChatVertexAI(**prompt_params)
         message = HumanMessage(content=user_prompt)
         response = model([message])

         assert response.content == response_text
-        send_message.assert_called_once_with(
-            user_prompt,
-            **prompt_params,
+        mock_send_message.assert_called_once_with(user_prompt)
+        start_chat.assert_called_once_with(
+            context=None, message_history=[], **prompt_params
         )


+def test_parse_examples_correct() -> None:
+    from vertexai.language_models import InputOutputTextPair
+
+    text_question = (
+        "Hello, could you recommend a good movie for me to watch this evening, please?"
+    )
+    question = HumanMessage(content=text_question)
+    text_answer = (
+        "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
+        "(2001): This is the first movie in the Lord of the Rings trilogy."
+    )
+    answer = AIMessage(content=text_answer)
+    examples = _parse_examples([question, answer, question, answer])
+    assert len(examples) == 2
+    assert examples == [
+        InputOutputTextPair(input_text=text_question, output_text=text_answer),
+        InputOutputTextPair(input_text=text_question, output_text=text_answer),
+    ]
+
+
+def test_parse_examples_fails_wrong_sequence() -> None:
+    with pytest.raises(ValueError) as exc_info:
+        _ = _parse_examples([AIMessage(content="a")])
+    print(str(exc_info.value))
+    assert (
+        str(exc_info.value)
+        == "Expect examples to have an even amount of messages, got 1."
+    )
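From the two new unit tests, `_parse_examples` can be read as roughly the following sketch. This is a reconstruction from the asserted behavior, not the actual implementation; the name `_parse_examples_sketch` and the second error message are assumptions for illustration:

from typing import List

from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from vertexai.language_models import InputOutputTextPair


def _parse_examples_sketch(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
    # Examples must come as alternating (HumanMessage, AIMessage) pairs.
    if len(examples) % 2 != 0:
        raise ValueError(
            f"Expect examples to have an even amount of messages, got {len(examples)}."
        )
    pairs = []
    for i in range(0, len(examples), 2):
        human, ai = examples[i], examples[i + 1]
        if not isinstance(human, HumanMessage) or not isinstance(ai, AIMessage):
            # Hypothetical message; the real wording is not shown in the tests.
            raise ValueError("Examples must alternate HumanMessage/AIMessage.")
        pairs.append(
            InputOutputTextPair(input_text=human.content, output_text=ai.content)
        )
    return pairs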