Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-12 14:23:58 +00:00
langchain_google_vertexai[patch]: Add support for SystemMessage for Gemini chat model (#15933)
- **Description:** In Google Vertex AI, Gemini chat models currently don't support SystemMessage. This PR adds support for it, but only when the user passes the additional `convert_system_message_to_human` flag during model initialization (in that case, the SystemMessage is prepended to the first HumanMessage). **NOTE:** The implementation is similar to #14824.
- **Twitter handle:** rajesh_thallam

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
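A minimal usage sketch of the new flag (a hypothetical snippet, assuming `langchain_google_vertexai` is installed and Google Cloud credentials are configured; the model name comes from the notebook example below):

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI

# Without the flag, a leading SystemMessage raises a ValueError for Gemini models.
# With it, the SystemMessage is prepended to the first HumanMessage.
chat = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
response = chat(
    [
        SystemMessage(content="You are a helpful assistant who translates English to French."),
        HumanMessage(content="Translate this sentence from English to French. I love programming."),
    ]
)
print(response.content)
```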
This commit is contained in:
parent 65b231d40b
commit 6bc6d64a12
Example notebook:

```diff
@@ -11,7 +11,6 @@
 ]
 },
 {
-"attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [
@@ -95,7 +94,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"If we want to construct a simple chain that takes user specified parameters:"
+"Gemini doesn't support SystemMessage at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to `True`:"
 ]
 },
 {
@@ -106,7 +105,7 @@
 {
 "data": {
 "text/plain": [
-"AIMessage(content=' プログラミングが大好きです')"
+"AIMessage(content=\"J'aime la programmation.\")"
 ]
 },
 "execution_count": 9,
@@ -114,6 +113,40 @@
 "output_type": "execute_result"
 }
 ],
+"source": [
+"system = \"You are a helpful assistant who translate English to French\"\n",
+"human = \"Translate this sentence from English to French. I love programming.\"\n",
+"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
+"\n",
+"chat = ChatVertexAI(model_name=\"gemini-pro\", convert_system_message_to_human=True)\n",
+"\n",
+"chain = prompt | chat\n",
+"chain.invoke({})"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"If we want to construct a simple chain that takes user specified parameters:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"AIMessage(content=' プログラミングが大好きです')"
+]
+},
+"execution_count": 4,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
 "source": [
 "system = (\n",
 " \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
@@ -121,6 +154,8 @@
 "human = \"{text}\"\n",
 "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
 "\n",
+"chat = ChatVertexAI()\n",
+"\n",
 "chain = prompt | chat\n",
 "\n",
 "chain.invoke(\n",
@@ -133,7 +168,6 @@
 ]
 },
 {
-"attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "execution": {
@@ -352,7 +386,7 @@
 "AIMessage(content=' Why do you love programming?')"
 ]
 },
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -428,8 +462,14 @@
 }
 ],
 "metadata": {
+"environment": {
+"kernel": "python3",
+"name": "common-cpu.m108",
+"type": "gcloud",
+"uri": "gcr.io/deeplearning-platform-release/base-cpu:m108"
+},
 "kernelspec": {
-"display_name": "Python 3 (ipykernel)",
+"display_name": "Python 3",
 "language": "python",
 "name": "python3"
 },
@@ -443,7 +483,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.4"
+"version": "3.10.10"
 },
 "vscode": {
 "interpreter": {
```
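Decoded from the JSON source strings in the hunk above, the newly added notebook cell is equivalent to the following Python; the two import lines are assumptions (the notebook's actual imports live in earlier cells not shown in this diff):

```python
# Assumed imports; the notebook imports these in an earlier cell.
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_vertexai import ChatVertexAI

system = "You are a helpful assistant who translate English to French"
human = "Translate this sentence from English to French. I love programming."
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])

chat = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)

chain = prompt | chat
chain.invoke({})
# Recorded output in the notebook: AIMessage(content="J'aime la programmation.")
```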
langchain_google_vertexai/chat_models.py:

```diff
@@ -111,7 +111,9 @@ def _is_url(s: str) -> bool:
 
 
 def _parse_chat_history_gemini(
-    history: List[BaseMessage], project: Optional[str]
+    history: List[BaseMessage],
+    project: Optional[str] = None,
+    convert_system_message_to_human: Optional[bool] = False,
 ) -> List[Content]:
     def _convert_to_prompt(part: Union[str, Dict]) -> Part:
         if isinstance(part, str):
@@ -155,9 +157,25 @@ def _parse_chat_history_gemini(
         return [_convert_to_prompt(part) for part in raw_content]
 
     vertex_messages = []
+    raw_system_message = None
     for i, message in enumerate(history):
-        if i == 0 and isinstance(message, SystemMessage):
-            raise ValueError("SystemMessages are not yet supported!")
+        if (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and not convert_system_message_to_human
+        ):
+            raise ValueError(
+                """SystemMessages are not yet supported!
+
+To automatically convert the leading SystemMessage to a HumanMessage,
+set `convert_system_message_to_human` to True. Example:
+
+llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
+"""
+            )
+        elif i == 0 and isinstance(message, SystemMessage):
+            raw_system_message = message
+            continue
         elif isinstance(message, AIMessage):
             raw_function_call = message.additional_kwargs.get("function_call")
             role = "model"
@@ -170,6 +188,8 @@ def _parse_chat_history_gemini(
                 )
                 gapic_part = GapicPart(function_call=function_call)
                 parts = [Part._from_gapic(gapic_part)]
+            else:
+                parts = _convert_to_parts(message)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message)
@@ -188,6 +208,15 @@ def _parse_chat_history_gemini(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
 
+        if raw_system_message:
+            if role == "model":
+                raise ValueError(
+                    "SystemMessage should be followed by a HumanMessage and "
+                    "not by AIMessage."
+                )
+            parts = _convert_to_parts(raw_system_message) + parts
+            raw_system_message = None
+
         vertex_message = Content(role=role, parts=parts)
         vertex_messages.append(vertex_message)
     return vertex_messages
```
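Taken together, these hunks hold back a leading SystemMessage and splice its parts onto the next human turn, erroring out if an AI turn comes first. A minimal standalone sketch of that merge logic, using hypothetical simplified types (plain strings instead of the real `Part`/`Content` objects):

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Turn:
    role: str  # "user" or "model"
    parts: List[str]  # simplified stand-in for Vertex AI Part objects


def merge_system_into_first_human(history: List[Tuple[str, str]]) -> List[Turn]:
    """history is a list of (kind, text) with kind in {"system", "human", "ai"}."""
    turns: List[Turn] = []
    pending_system = None
    for i, (kind, text) in enumerate(history):
        if i == 0 and kind == "system":
            pending_system = text  # hold it until the next human turn
            continue
        role = "user" if kind == "human" else "model"
        parts = [text]
        if pending_system is not None:
            if role == "model":
                raise ValueError("SystemMessage should be followed by a HumanMessage.")
            parts = [pending_system] + parts  # prepend, mirroring the real parser
            pending_system = None
        turns.append(Turn(role=role, parts=parts))
    return turns


# Mirrors test_parse_history_gemini below: 4 messages collapse into 3 turns,
# with the system text as the first part of the first user turn.
turns = merge_system_into_first_human(
    [
        ("system", "You're supposed to answer math questions."),
        ("human", "How much is 2+2?"),
        ("ai", "4"),
        ("human", "How much is 3+3?"),
    ]
)
assert len(turns) == 3 and turns[0].parts == [
    "You're supposed to answer math questions.",
    "How much is 2+2?",
]
```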
```diff
@@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     model_name: str = "chat-bison"
     "Underlying model name."
     examples: Optional[List[BaseMessage]] = None
+    convert_system_message_to_human: bool = False
+    """Whether to merge any leading SystemMessage into the following HumanMessage.
+
+    Gemini does not support system messages; any unsupported messages will
+    raise an error."""
 
     @classmethod
     def is_lc_serializable(self) -> bool:
@@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             msg_params["candidate_count"] = params.pop("candidate_count")
 
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            history_gemini = _parse_chat_history_gemini(
+                messages,
+                project=self.project,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
 
@@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             msg_params["candidate_count"] = params.pop("candidate_count")
 
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            history_gemini = _parse_chat_history_gemini(
+                messages,
+                project=self.project,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
             # set param to `functions` until core tool/function calling implemented
@@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            history_gemini = _parse_chat_history_gemini(
+                messages,
+                project=self.project,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
             # set param to `functions` until core tool/function calling implemented
```
Integration tests:

```diff
@@ -182,3 +182,36 @@ def test_vertexai_single_call_fails_no_message() -> None:
         str(exc_info.value)
         == "You should provide at least one message to start the chat!"
     )
+
+
+@pytest.mark.parametrize("model_name", ["gemini-pro"])
+def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
+    model = ChatVertexAI(model_name=model_name)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    with pytest.raises(ValueError):
+        model([system_message, message1, message2, message3])
+
+
+@pytest.mark.parametrize("model_name", model_names_to_test)
+def test_chat_vertexai_system_message(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(
+            model_name=model_name, convert_system_message_to_human=True
+        )
+    else:
+        model = ChatVertexAI()
+
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([system_message, message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
```
Unit tests:

```diff
@@ -13,6 +13,7 @@ from vertexai.language_models import ChatMessage, InputOutputTextPair  # type: i
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history,
+    _parse_chat_history_gemini,
     _parse_examples,
 )
 
@@ -112,6 +113,24 @@ def test_parse_chat_history_correct() -> None:
     ]
 
 
+def test_parse_history_gemini() -> None:
+    system_input = "You're supposed to answer math questions."
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content=system_input)
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    messages = [system_message, message1, message2, message3]
+    history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
+    assert len(history) == 3
+    assert history[0].role == "user"
+    assert history[0].parts[0].text == system_input
+    assert history[0].parts[1].text == text_question1
+    assert history[1].role == "model"
+    assert history[1].parts[0].text == text_answer1
+
+
 def test_default_params_palm() -> None:
     user_prompt = "Hello"
```