langchain_google_vertexai[patch]: Add support for SystemMessage for Gemini chat model (#15933)

- **Description:** In Google Vertex AI, Gemini chat models currently don't support SystemMessage. This PR adds support for it, but only when the user passes the additional convert_system_message_to_human flag during model initialization (in that case, the SystemMessage is prepended to the first HumanMessage). **NOTE:** The implementation is similar to #14824
- **Twitter handle:** rajesh_thallam

Co-authored-by: Erick Friis <erick@langchain.dev>
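A minimal usage sketch of the new flag (not part of the diff; the model name and messages are illustrative, borrowed from the tests below, and a configured GCP project/credentials are assumed):

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI

# Opt in to the new behavior: the leading SystemMessage is prepended to the
# first HumanMessage instead of raising a ValueError.
llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
response = llm.invoke(
    [
        SystemMessage(content="You're supposed to answer math questions."),
        HumanMessage(content="How much is 2+2?"),
    ]
)
print(response.content)
```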
libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py

@@ -111,7 +111,9 @@ def _is_url(s: str) -> bool:
 def _parse_chat_history_gemini(
-    history: List[BaseMessage], project: Optional[str]
+    history: List[BaseMessage],
+    project: Optional[str] = None,
+    convert_system_message_to_human: Optional[bool] = False,
 ) -> List[Content]:
     def _convert_to_prompt(part: Union[str, Dict]) -> Part:
         if isinstance(part, str):
@@ -155,9 +157,25 @@ def _parse_chat_history_gemini(
         return [_convert_to_prompt(part) for part in raw_content]

     vertex_messages = []
+    raw_system_message = None
     for i, message in enumerate(history):
-        if i == 0 and isinstance(message, SystemMessage):
-            raise ValueError("SystemMessages are not yet supported!")
+        if (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and not convert_system_message_to_human
+        ):
+            raise ValueError(
+                """SystemMessages are not yet supported!
+
+To automatically convert the leading SystemMessage to a HumanMessage,
+set `convert_system_message_to_human` to True. Example:
+
+llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
+"""
+            )
+        elif i == 0 and isinstance(message, SystemMessage):
+            raw_system_message = message
+            continue
         elif isinstance(message, AIMessage):
             raw_function_call = message.additional_kwargs.get("function_call")
             role = "model"
@@ -170,6 +188,8 @@ def _parse_chat_history_gemini(
                 )
                 gapic_part = GapicPart(function_call=function_call)
                 parts = [Part._from_gapic(gapic_part)]
+            else:
+                parts = _convert_to_parts(message)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message)
@@ -188,6 +208,15 @@ def _parse_chat_history_gemini(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )

+        if raw_system_message:
+            if role == "model":
+                raise ValueError(
+                    "SystemMessage should be followed by a HumanMessage and "
+                    "not by AIMessage."
+                )
+            parts = _convert_to_parts(raw_system_message) + parts
+            raw_system_message = None
+
         vertex_message = Content(role=role, parts=parts)
         vertex_messages.append(vertex_message)
     return vertex_messages
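Taken together, these hunks let _parse_chat_history_gemini fold a leading SystemMessage into the first user turn. A sketch of the resulting behavior, mirroring the unit test added below (requires the Vertex AI SDK, since the returned history is a list of its Content objects):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_vertexai.chat_models import _parse_chat_history_gemini

history = _parse_chat_history_gemini(
    [
        SystemMessage(content="You're supposed to answer math questions."),
        HumanMessage(content="How much is 2+2?"),
        AIMessage(content="4"),
    ],
    convert_system_message_to_human=True,
)
# The system text becomes an extra Part on the first "user" Content;
# no separate "system" role is ever sent to Gemini.
assert history[0].role == "user"
assert history[0].parts[0].text == "You're supposed to answer math questions."
assert history[0].parts[1].text == "How much is 2+2?"
assert history[1].role == "model"
```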
@@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     model_name: str = "chat-bison"
     "Underlying model name."
     examples: Optional[List[BaseMessage]] = None
+    convert_system_message_to_human: bool = False
+    """Whether to merge any leading SystemMessage into the following HumanMessage.
+
+    Gemini does not support system messages; any unsupported messages will
+    raise an error."""

     @classmethod
     def is_lc_serializable(self) -> bool:
@@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             msg_params["candidate_count"] = params.pop("candidate_count")

         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            history_gemini = _parse_chat_history_gemini(
+                messages,
+                project=self.project,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
@@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             msg_params["candidate_count"] = params.pop("candidate_count")

         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            history_gemini = _parse_chat_history_gemini(
+                messages,
+                project=self.project,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
             # set param to `functions` until core tool/function calling implemented
@@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            history_gemini = _parse_chat_history_gemini(
+                messages,
+                project=self.project,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
             # set param to `functions` until core tool/function calling implemented
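For contrast, the default path is unchanged: without the flag, a leading SystemMessage still raises a ValueError, now carrying a hint that points at the new option (a sketch; the exact wording is in the hunk above, and the error is raised while parsing history, before any API call):

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-pro")  # convert_system_message_to_human defaults to False
try:
    llm.invoke(
        [
            SystemMessage(content="You're supposed to answer math questions."),
            HumanMessage(content="How much is 2+2?"),
        ]
    )
except ValueError as err:
    # The message suggests setting `convert_system_message_to_human` to True.
    print(err)
```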
libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py

@@ -182,3 +182,36 @@ def test_vertexai_single_call_fails_no_message() -> None:
         str(exc_info.value)
         == "You should provide at least one message to start the chat!"
     )
+
+
+@pytest.mark.parametrize("model_name", ["gemini-pro"])
+def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
+    model = ChatVertexAI(model_name=model_name)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    with pytest.raises(ValueError):
+        model([system_message, message1, message2, message3])
+
+
+@pytest.mark.parametrize("model_name", model_names_to_test)
+def test_chat_vertexai_system_message(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(
+            model_name=model_name, convert_system_message_to_human=True
+        )
+    else:
+        model = ChatVertexAI()
+
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([system_message, message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py

@@ -13,6 +13,7 @@ from vertexai.language_models import ChatMessage, InputOutputTextPair  # type: ignore
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history,
+    _parse_chat_history_gemini,
     _parse_examples,
 )
@@ -112,6 +113,24 @@ def test_parse_chat_history_correct() -> None:
     ]


+def test_parse_history_gemini() -> None:
+    system_input = "You're supposed to answer math questions."
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content=system_input)
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    messages = [system_message, message1, message2, message3]
+    history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
+    assert len(history) == 3
+    assert history[0].role == "user"
+    assert history[0].parts[0].text == system_input
+    assert history[0].parts[1].text == text_question1
+    assert history[1].role == "model"
+    assert history[1].parts[0].text == text_answer1
+
+
 def test_default_params_palm() -> None:
     user_prompt = "Hello"