langchain_google_vertexai[patch]: Add support for SystemMessage for Gemini chat model (#15933)

- **Description:** In Google Vertex AI, Gemini chat models currently
do not support SystemMessage. This PR adds support for it
only if a user provides an additional `convert_system_message_to_human` flag
during model initialization (in this case, the SystemMessage is
prepended to the first HumanMessage). **NOTE:** The implementation is
similar to #14824


- **Twitter handle:** rajesh_thallam

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Rajesh Thallam
2024-01-18 10:22:07 -08:00
committed by GitHub
parent 65b231d40b
commit 6bc6d64a12
4 changed files with 151 additions and 13 deletions

View File

@@ -111,7 +111,9 @@ def _is_url(s: str) -> bool:
def _parse_chat_history_gemini(
history: List[BaseMessage], project: Optional[str]
history: List[BaseMessage],
project: Optional[str] = None,
convert_system_message_to_human: Optional[bool] = False,
) -> List[Content]:
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
if isinstance(part, str):
@@ -155,9 +157,25 @@ def _parse_chat_history_gemini(
return [_convert_to_prompt(part) for part in raw_content]
vertex_messages = []
raw_system_message = None
for i, message in enumerate(history):
if i == 0 and isinstance(message, SystemMessage):
raise ValueError("SystemMessages are not yet supported!")
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
raw_function_call = message.additional_kwargs.get("function_call")
role = "model"
@@ -170,6 +188,8 @@ def _parse_chat_history_gemini(
)
gapic_part = GapicPart(function_call=function_call)
parts = [Part._from_gapic(gapic_part)]
else:
parts = _convert_to_parts(message)
elif isinstance(message, HumanMessage):
role = "user"
parts = _convert_to_parts(message)
@@ -188,6 +208,15 @@ def _parse_chat_history_gemini(
f"Unexpected message with type {type(message)} at the position {i}."
)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
parts = _convert_to_parts(raw_system_message) + parts
raw_system_message = None
vertex_message = Content(role=role, parts=parts)
vertex_messages.append(vertex_message)
return vertex_messages
@@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
model_name: str = "chat-bison"
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
@classmethod
def is_lc_serializable(self) -> bool:
@@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
@@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
@@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented

View File

@@ -182,3 +182,36 @@ def test_vertexai_single_call_fails_no_message() -> None:
str(exc_info.value)
== "You should provide at least one message to start the chat!"
)
@pytest.mark.parametrize("model_name", ["gemini-pro"])
def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
model = ChatVertexAI(model_name=model_name)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
with pytest.raises(ValueError):
model([system_message, message1, message2, message3])
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: str) -> None:
if model_name:
model = ChatVertexAI(
model_name=model_name, convert_system_message_to_human=True
)
else:
model = ChatVertexAI()
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)

View File

@@ -13,6 +13,7 @@ from vertexai.language_models import ChatMessage, InputOutputTextPair # type: i
from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_chat_history_gemini,
_parse_examples,
)
@@ -112,6 +113,24 @@ def test_parse_chat_history_correct() -> None:
]
def test_parse_history_gemini() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0].role == "user"
assert history[0].parts[0].text == system_input
assert history[0].parts[1].text == text_question1
assert history[1].role == "model"
assert history[1].parts[0].text == text_answer1
def test_default_params_palm() -> None:
user_prompt = "Hello"