added history and support for system_message as param (#14824)

- **Description:** added support for chat history for Google
GenerativeAI, so the integration actually uses the `chat` API. Since
Gemini currently doesn't support `SystemMessage`, one is accepted only if
the user passes the `convert_system_message_to_human` flag during model
initialization; in that case the `SystemMessage` is prepended to the first
`HumanMessage` (see the usage sketch below).
  - **Issue:** #14710 
  - **Twitter handle:** lkuligin
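
A minimal usage sketch of the new behavior (the model name and messages are illustrative; this mirrors the notebook example and integration tests added in this PR):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_google_genai.chat_models import ChatGoogleGenerativeAI

# Without convert_system_message_to_human=True, a leading SystemMessage
# raises a ValueError; with it, the SystemMessage is merged into the
# first HumanMessage before the request is sent to the Gemini chat API.
model = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
response = model(
    [
        SystemMessage(content="Answer only yes or no."),
        HumanMessage(content="Is apple a fruit?"),
        AIMessage(content="yes"),  # earlier turns are passed as chat history
        HumanMessage(content="Is a tomato a fruit?"),
    ]
)
print(response.content)
```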

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
Leonid Kuligin authored on 2023-12-19 03:23:14 +01:00, committed by GitHub
commit 2d0f1cae8c (parent 2861766d0d)
4 changed files with 166 additions and 64 deletions

File: docs notebook for ChatGoogleGenerativeAI

@@ -136,6 +136,32 @@
     "print(result.content)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "9e55d043-bb2f-44e3-9134-c39a1abe3a9e",
+   "metadata": {},
+   "source": [
+    "Gemini doesn't support `SystemMessage` at the moment, but a leading system message can be merged into the first human message. If you want such behavior, set `convert_system_message_to_human` to `True`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7a64b523-9710-4d15-9944-1e3cc567a52b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.schema.messages import HumanMessage, SystemMessage\n",
+    "\n",
+    "model = ChatGoogleGenerativeAI(model=\"gemini-pro\", convert_system_message_to_human=True)\n",
+    "model(\n",
+    "    [\n",
+    "        SystemMessage(content=\"Answer only yes or no.\"),\n",
+    "        HumanMessage(content=\"Is apple a fruit?\"),\n",
+    "    ]\n",
+    ")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "40773fac-b24d-476d-91c8-2da8fed99b53",

View File

@@ -37,6 +37,7 @@ from langchain_core.messages import (
     ChatMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
@@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
 )
-def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
     """
     Executes a chat generation method with retry logic using tenacity.
@@ -139,7 +140,7 @@ def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     return _chat_with_retry(**kwargs)
 
-async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
     """
     Executes a chat generation method with retry logic using tenacity.
@@ -172,26 +173,6 @@ async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     return await _achat_with_retry(**kwargs)
 
-def _get_role(message: BaseMessage) -> str:
-    if isinstance(message, ChatMessage):
-        if message.role not in ("user", "model"):
-            raise ChatGoogleGenerativeAIError(
-                "Gemini only supports user and model roles when"
-                " providing it with Chat messages."
-            )
-        return message.role
-    elif isinstance(message, HumanMessage):
-        return "user"
-    elif isinstance(message, AIMessage):
-        return "model"
-    else:
-        # TODO: Gemini doesn't seem to have a concept of system messages yet.
-        raise ChatGoogleGenerativeAIError(
-            f"Message of '{message.type}' type not supported by Gemini."
-            " Please only provide it with Human or AI (user/assistant) messages."
-        )
 
 def _is_openai_parts_format(part: dict) -> bool:
     return "type" in part
@@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image:
 def _convert_to_parts(
-    content: Sequence[Union[str, dict]],
+    raw_content: Union[str, Sequence[Union[str, dict]]],
 ) -> List[genai.types.PartType]:
     """Converts a list of LangChain messages into a google parts."""
     parts = []
+    content = [raw_content] if isinstance(raw_content, str) else raw_content
     for part in content:
         if isinstance(part, str):
-            parts.append(genai.types.PartDict(text=part, inline_data=None))
+            parts.append(genai.types.PartDict(text=part))
         elif isinstance(part, Mapping):
             # OpenAI Format
             if _is_openai_parts_format(part):
@@ -304,27 +286,49 @@ def _convert_to_parts(
     return parts
 
-def _messages_to_genai_contents(
-    input_messages: Sequence[BaseMessage],
+def _parse_chat_history(
+    input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
 ) -> List[genai.types.ContentDict]:
-    """Converts a list of messages into a Gemini API google content dicts."""
     messages: List[genai.types.MessageDict] = []
+    raw_system_message: Optional[SystemMessage] = None
     for i, message in enumerate(input_messages):
-        role = _get_role(message)
-        if isinstance(message.content, str):
-            parts = [message.content]
+        if (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and not convert_system_message_to_human
+        ):
+            raise ValueError(
+                """SystemMessages are not yet supported!
+To automatically convert the leading SystemMessage to a HumanMessage,
+set `convert_system_message_to_human` to True. Example:
+llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
+"""
+            )
+        elif i == 0 and isinstance(message, SystemMessage):
+            raw_system_message = message
+            continue
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
         else:
-            parts = _convert_to_parts(message.content)
-        messages.append({"role": role, "parts": parts})
-        if i > 0:
-            # Cannot have multiple messages from the same role in a row.
-            if role == messages[-2]["role"]:
-                raise ChatGoogleGenerativeAIError(
-                    "Cannot have multiple messages from the same role in a row."
-                    " Consider merging them into a single message with multiple"
-                    f" parts.\nReceived: {messages}"
-                )
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+
+        parts = _convert_to_parts(message.content)
+        if raw_system_message:
+            if role == "model":
+                raise ValueError(
+                    "SystemMessage should be followed by a HumanMessage and "
+                    "not by AIMessage."
+                )
+            parts = _convert_to_parts(raw_system_message.content) + parts
+            raw_system_message = None
+        messages.append({"role": role, "parts": parts})
     return messages
@@ -457,8 +461,11 @@ Supported examples:
     n: int = Field(default=1, alias="candidate_count")
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
+    convert_system_message_to_human: bool = False
+    """Whether to merge any leading SystemMessage into the following HumanMessage.
+    Gemini does not support system messages; any unsupported messages will
+    raise an error."""
-    _generative_model: Any  #: :meta private:
 
     class Config:
         allow_population_by_field_name = True
@@ -499,7 +506,7 @@ Supported examples:
         if values.get("top_k") is not None and values["top_k"] <= 0:
             raise ValueError("top_k must be positive")
         model = values["model"]
-        values["_generative_model"] = genai.GenerativeModel(model_name=model)
+        values["client"] = genai.GenerativeModel(model_name=model)
         return values
 
     @property
@@ -512,18 +519,9 @@ Supported examples:
             "n": self.n,
         }
 
-    @property
-    def _generation_method(self) -> Callable:
-        return self._generative_model.generate_content
-
-    @property
-    def _async_generation_method(self) -> Callable:
-        return self._generative_model.generate_content_async
-
     def _prepare_params(
-        self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
+        self, stop: Optional[List[str]], **kwargs: Any
     ) -> Dict[str, Any]:
-        contents = _messages_to_genai_contents(messages)
         gen_config = {
             k: v
             for k, v in {
@@ -538,7 +536,7 @@ Supported examples:
         }
         if "generation_config" in kwargs:
             gen_config = {**gen_config, **kwargs.pop("generation_config")}
-        params = {"generation_config": gen_config, "contents": contents, **kwargs}
+        params = {"generation_config": gen_config, **kwargs}
         return params
 
     def _generate(
@@ -548,10 +546,11 @@ Supported examples:
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
             **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
         )
         return _response_to_result(response)
@@ -562,10 +561,11 @@ Supported examples:
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
+            content=message,
             **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
         )
         return _response_to_result(response)
@@ -576,10 +576,11 @@ Supported examples:
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
             **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
             stream=True,
         )
         for chunk in response:
@@ -602,10 +603,11 @@ Supported examples:
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         async for chunk in await _achat_with_retry(
+            content=message,
             **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
             stream=True,
         ):
             _chat_result = _response_to_result(
@@ -619,3 +621,18 @@ Supported examples:
             yield gen
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
+
+    def _prepare_chat(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        params = self._prepare_params(stop, **kwargs)
+        history = _parse_chat_history(
+            messages,
+            convert_system_message_to_human=self.convert_system_message_to_human,
+        )
+        message = history.pop()
+        chat = self.client.start_chat(history=history)
+        return params, chat, message
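
For context, a minimal sketch of the `google.generativeai` flow that `_prepare_chat` wraps: the parsed history (everything except the last message) seeds a `ChatSession`, and the popped final message is what actually gets sent. The model name and questions below are illustrative, taken from the tests further down.

```python
import google.generativeai as genai

# Rough equivalent of what _prepare_chat + _generate now do under the hood:
client = genai.GenerativeModel(model_name="gemini-pro")  # stored as values["client"]
history = [
    {"role": "user", "parts": [{"text": "How much is 2+2?"}]},
    {"role": "model", "parts": [{"text": "4"}]},
]
chat = client.start_chat(history=history)  # every message except the last
# The final content dict is passed to send_message, as in _generate above.
response = chat.send_message({"role": "user", "parts": [{"text": "How much is 3+3?"}]})
print(response.text)
```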

File: integration tests for ChatGoogleGenerativeAI

@@ -1,6 +1,6 @@
 """Test ChatGoogleGenerativeAI chat model."""
 import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
 from langchain_google_genai.chat_models import (
     ChatGoogleGenerativeAI,
@@ -147,3 +147,40 @@ def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
     llm = ChatGoogleGenerativeAI(model=_MODEL)
     with pytest.raises(ChatGoogleGenerativeAIError):
         llm.invoke(messages)
+
+
+def test_chat_google_genai_single_call_with_history() -> None:
+    model = ChatGoogleGenerativeAI(model=_MODEL)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_google_genai_system_message_error() -> None:
+    model = ChatGoogleGenerativeAI(model=_MODEL)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    with pytest.raises(ValueError):
+        model([system_message, message1, message2, message3])
+
+
+def test_chat_google_genai_system_message() -> None:
+    model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([system_message, message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)

File: unit tests for the chat model

@@ -1,8 +1,12 @@
 """Test chat model integration."""
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture
 
-from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+from langchain_google_genai.chat_models import (
+    ChatGoogleGenerativeAI,
+    _parse_chat_history,
+)
 
 
 def test_integration_initialization() -> None:
@@ -36,3 +40,21 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
     captured = capsys.readouterr()
     assert captured.out == "**********"
+
+
+def test_parse_history() -> None:
+    system_input = "You're supposed to answer math questions."
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content=system_input)
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    messages = [system_message, message1, message2, message3]
+    history = _parse_chat_history(messages, convert_system_message_to_human=True)
+    assert len(history) == 3
+    assert history[0] == {
+        "role": "user",
+        "parts": [{"text": system_input}, {"text": text_question1}],
+    }
+    assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}