mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 05:20:39 +00:00
langchain_google_vertexai[patch]: Add support for SystemMessage for Gemini chat model (#15933)
- **Description:** In Google Vertex AI, Gemini Chat models currently doesn't have a support for SystemMessage. This PR adds support for it only if a user provides additional convert_system_message_to_human flag during model initialization (in this case, SystemMessage would be prepended to the first HumanMessage). **NOTE:** The implementation is similar to #14824 - **Twitter handle:** rajesh_thallam --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
65b231d40b
commit
6bc6d64a12
@ -11,7 +11,6 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@ -95,7 +94,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we want to construct a simple chain that takes user specified parameters:"
|
||||
"Gemini doesn't support SystemMessage at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to `True`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -106,7 +105,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' プログラミングが大好きです')"
|
||||
"AIMessage(content=\"J'aime la programmation.\")"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
@ -114,6 +113,40 @@
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"system = \"You are a helpful assistant who translate English to French\"\n",
|
||||
"human = \"Translate this sentence from English to French. I love programming.\"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
|
||||
"\n",
|
||||
"chat = ChatVertexAI(model_name=\"gemini-pro\", convert_system_message_to_human=True)\n",
|
||||
"\n",
|
||||
"chain = prompt | chat\n",
|
||||
"chain.invoke({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we want to construct a simple chain that takes user specified parameters:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' プログラミングが大好きです')"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"system = (\n",
|
||||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
|
||||
@ -121,6 +154,8 @@
|
||||
"human = \"{text}\"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
|
||||
"\n",
|
||||
"chat = ChatVertexAI()\n",
|
||||
"\n",
|
||||
"chain = prompt | chat\n",
|
||||
"\n",
|
||||
"chain.invoke(\n",
|
||||
@ -133,7 +168,6 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
@ -352,7 +386,7 @@
|
||||
"AIMessage(content=' Why do you love programming?')"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -428,8 +462,14 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"environment": {
|
||||
"kernel": "python3",
|
||||
"name": "common-cpu.m108",
|
||||
"type": "gcloud",
|
||||
"uri": "gcr.io/deeplearning-platform-release/base-cpu:m108"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -443,7 +483,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.10"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -111,7 +111,9 @@ def _is_url(s: str) -> bool:
|
||||
|
||||
|
||||
def _parse_chat_history_gemini(
|
||||
history: List[BaseMessage], project: Optional[str]
|
||||
history: List[BaseMessage],
|
||||
project: Optional[str] = None,
|
||||
convert_system_message_to_human: Optional[bool] = False,
|
||||
) -> List[Content]:
|
||||
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
|
||||
if isinstance(part, str):
|
||||
@ -155,9 +157,25 @@ def _parse_chat_history_gemini(
|
||||
return [_convert_to_prompt(part) for part in raw_content]
|
||||
|
||||
vertex_messages = []
|
||||
raw_system_message = None
|
||||
for i, message in enumerate(history):
|
||||
if i == 0 and isinstance(message, SystemMessage):
|
||||
raise ValueError("SystemMessages are not yet supported!")
|
||||
if (
|
||||
i == 0
|
||||
and isinstance(message, SystemMessage)
|
||||
and not convert_system_message_to_human
|
||||
):
|
||||
raise ValueError(
|
||||
"""SystemMessages are not yet supported!
|
||||
|
||||
To automatically convert the leading SystemMessage to a HumanMessage,
|
||||
set `convert_system_message_to_human` to True. Example:
|
||||
|
||||
llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
|
||||
"""
|
||||
)
|
||||
elif i == 0 and isinstance(message, SystemMessage):
|
||||
raw_system_message = message
|
||||
continue
|
||||
elif isinstance(message, AIMessage):
|
||||
raw_function_call = message.additional_kwargs.get("function_call")
|
||||
role = "model"
|
||||
@ -170,6 +188,8 @@ def _parse_chat_history_gemini(
|
||||
)
|
||||
gapic_part = GapicPart(function_call=function_call)
|
||||
parts = [Part._from_gapic(gapic_part)]
|
||||
else:
|
||||
parts = _convert_to_parts(message)
|
||||
elif isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
parts = _convert_to_parts(message)
|
||||
@ -188,6 +208,15 @@ def _parse_chat_history_gemini(
|
||||
f"Unexpected message with type {type(message)} at the position {i}."
|
||||
)
|
||||
|
||||
if raw_system_message:
|
||||
if role == "model":
|
||||
raise ValueError(
|
||||
"SystemMessage should be followed by a HumanMessage and "
|
||||
"not by AIMessage."
|
||||
)
|
||||
parts = _convert_to_parts(raw_system_message) + parts
|
||||
raw_system_message = None
|
||||
|
||||
vertex_message = Content(role=role, parts=parts)
|
||||
vertex_messages.append(vertex_message)
|
||||
return vertex_messages
|
||||
@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
model_name: str = "chat-bison"
|
||||
"Underlying model name."
|
||||
examples: Optional[List[BaseMessage]] = None
|
||||
convert_system_message_to_human: bool = False
|
||||
"""Whether to merge any leading SystemMessage into the following HumanMessage.
|
||||
|
||||
Gemini does not support system messages; any unsupported messages will
|
||||
raise an error."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(self) -> bool:
|
||||
@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
msg_params["candidate_count"] = params.pop("candidate_count")
|
||||
|
||||
if self._is_gemini_model:
|
||||
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
|
||||
history_gemini = _parse_chat_history_gemini(
|
||||
messages,
|
||||
project=self.project,
|
||||
convert_system_message_to_human=self.convert_system_message_to_human,
|
||||
)
|
||||
message = history_gemini.pop()
|
||||
chat = self.client.start_chat(history=history_gemini)
|
||||
|
||||
@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
msg_params["candidate_count"] = params.pop("candidate_count")
|
||||
|
||||
if self._is_gemini_model:
|
||||
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
|
||||
history_gemini = _parse_chat_history_gemini(
|
||||
messages,
|
||||
project=self.project,
|
||||
convert_system_message_to_human=self.convert_system_message_to_human,
|
||||
)
|
||||
message = history_gemini.pop()
|
||||
chat = self.client.start_chat(history=history_gemini)
|
||||
# set param to `functions` until core tool/function calling implemented
|
||||
@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._prepare_params(stop=stop, stream=True, **kwargs)
|
||||
if self._is_gemini_model:
|
||||
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
|
||||
history_gemini = _parse_chat_history_gemini(
|
||||
messages,
|
||||
project=self.project,
|
||||
convert_system_message_to_human=self.convert_system_message_to_human,
|
||||
)
|
||||
message = history_gemini.pop()
|
||||
chat = self.client.start_chat(history=history_gemini)
|
||||
# set param to `functions` until core tool/function calling implemented
|
||||
|
@ -182,3 +182,36 @@ def test_vertexai_single_call_fails_no_message() -> None:
|
||||
str(exc_info.value)
|
||||
== "You should provide at least one message to start the chat!"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["gemini-pro"])
|
||||
def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
|
||||
model = ChatVertexAI(model_name=model_name)
|
||||
text_question1, text_answer1 = "How much is 2+2?", "4"
|
||||
text_question2 = "How much is 3+3?"
|
||||
system_message = SystemMessage(content="You're supposed to answer math questions.")
|
||||
message1 = HumanMessage(content=text_question1)
|
||||
message2 = AIMessage(content=text_answer1)
|
||||
message3 = HumanMessage(content=text_question2)
|
||||
with pytest.raises(ValueError):
|
||||
model([system_message, message1, message2, message3])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", model_names_to_test)
|
||||
def test_chat_vertexai_system_message(model_name: str) -> None:
|
||||
if model_name:
|
||||
model = ChatVertexAI(
|
||||
model_name=model_name, convert_system_message_to_human=True
|
||||
)
|
||||
else:
|
||||
model = ChatVertexAI()
|
||||
|
||||
text_question1, text_answer1 = "How much is 2+2?", "4"
|
||||
text_question2 = "How much is 3+3?"
|
||||
system_message = SystemMessage(content="You're supposed to answer math questions.")
|
||||
message1 = HumanMessage(content=text_question1)
|
||||
message2 = AIMessage(content=text_answer1)
|
||||
message3 = HumanMessage(content=text_question2)
|
||||
response = model([system_message, message1, message2, message3])
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
@ -13,6 +13,7 @@ from vertexai.language_models import ChatMessage, InputOutputTextPair # type: i
|
||||
from langchain_google_vertexai.chat_models import (
|
||||
ChatVertexAI,
|
||||
_parse_chat_history,
|
||||
_parse_chat_history_gemini,
|
||||
_parse_examples,
|
||||
)
|
||||
|
||||
@ -112,6 +113,24 @@ def test_parse_chat_history_correct() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_parse_history_gemini() -> None:
|
||||
system_input = "You're supposed to answer math questions."
|
||||
text_question1, text_answer1 = "How much is 2+2?", "4"
|
||||
text_question2 = "How much is 3+3?"
|
||||
system_message = SystemMessage(content=system_input)
|
||||
message1 = HumanMessage(content=text_question1)
|
||||
message2 = AIMessage(content=text_answer1)
|
||||
message3 = HumanMessage(content=text_question2)
|
||||
messages = [system_message, message1, message2, message3]
|
||||
history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
|
||||
assert len(history) == 3
|
||||
assert history[0].role == "user"
|
||||
assert history[0].parts[0].text == system_input
|
||||
assert history[0].parts[1].text == text_question1
|
||||
assert history[1].role == "model"
|
||||
assert history[1].parts[0].text == text_answer1
|
||||
|
||||
|
||||
def test_default_params_palm() -> None:
|
||||
user_prompt = "Hello"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user