Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-29 18:08:36 +00:00
added history and support for system_message as param (#14824)
- **Description:** added support for chat history in Google GenerativeAI (so the `chat` API is actually used). Since Gemini currently doesn't support `SystemMessage`, it is handled only when the user passes the additional `convert_system_message_to_human` flag at model initialization; in that case the `SystemMessage` is prepended to the first `HumanMessage`.
- **Issue:** #14710
- **Twitter handle:** lkuligin

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
parent 2861766d0d
commit 2d0f1cae8c
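For orientation, a minimal usage sketch of what this change enables (not part of the diff: it assumes `langchain-google-genai` is installed and `GOOGLE_API_KEY` is set, and mirrors the integration tests added below):

```python
# Sketch of the new behavior (assumes langchain-google-genai is installed and
# GOOGLE_API_KEY is set); the message flow mirrors the integration tests below.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_google_genai.chat_models import ChatGoogleGenerativeAI

# Without convert_system_message_to_human=True, a leading SystemMessage raises a ValueError.
model = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)

# The history is sent through Gemini's chat API; the SystemMessage is merged
# into the first HumanMessage before the request is built.
response = model(
    [
        SystemMessage(content="You're supposed to answer math questions."),
        HumanMessage(content="How much is 2+2?"),
        AIMessage(content="4"),
        HumanMessage(content="How much is 3+3?"),
    ]
)
print(response.content)
```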
@@ -136,6 +136,32 @@
     "print(result.content)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "9e55d043-bb2f-44e3-9134-c39a1abe3a9e",
+   "metadata": {},
+   "source": [
+    "Gemini doesn't support `SystemMessage` at the moment, but it can be merged into the first human message. If you want such behavior, just set `convert_system_message_to_human` to `True`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7a64b523-9710-4d15-9944-1e3cc567a52b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.schema.messages import HumanMessage, SystemMessage\n",
+    "\n",
+    "model = ChatGoogleGenerativeAI(model=\"gemini-pro\", convert_system_message_to_human=True)\n",
+    "model(\n",
+    "    [\n",
+    "        SystemMessage(content=\"Answer only yes or no.\"),\n",
+    "        HumanMessage(content=\"Is apple a fruit?\"),\n",
+    "    ]\n",
+    ")"
+   ]
+  },
  {
   "cell_type": "markdown",
   "id": "40773fac-b24d-476d-91c8-2da8fed99b53",
@@ -37,6 +37,7 @@ from langchain_core.messages import (
     ChatMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
@@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
     )


-def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
     """
     Executes a chat generation method with retry logic using tenacity.

@@ -139,7 +140,7 @@ def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     return _chat_with_retry(**kwargs)


-async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
     """
     Executes a chat generation method with retry logic using tenacity.

@@ -172,26 +173,6 @@ async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     return await _achat_with_retry(**kwargs)


-def _get_role(message: BaseMessage) -> str:
-    if isinstance(message, ChatMessage):
-        if message.role not in ("user", "model"):
-            raise ChatGoogleGenerativeAIError(
-                "Gemini only supports user and model roles when"
-                " providing it with Chat messages."
-            )
-        return message.role
-    elif isinstance(message, HumanMessage):
-        return "user"
-    elif isinstance(message, AIMessage):
-        return "model"
-    else:
-        # TODO: Gemini doesn't seem to have a concept of system messages yet.
-        raise ChatGoogleGenerativeAIError(
-            f"Message of '{message.type}' type not supported by Gemini."
-            " Please only provide it with Human or AI (user/assistant) messages."
-        )
-
-
 def _is_openai_parts_format(part: dict) -> bool:
     return "type" in part

@@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image:


 def _convert_to_parts(
-    content: Sequence[Union[str, dict]],
+    raw_content: Union[str, Sequence[Union[str, dict]]],
 ) -> List[genai.types.PartType]:
     """Converts a list of LangChain messages into a google parts."""
     parts = []
+    content = [raw_content] if isinstance(raw_content, str) else raw_content
     for part in content:
         if isinstance(part, str):
-            parts.append(genai.types.PartDict(text=part, inline_data=None))
+            parts.append(genai.types.PartDict(text=part))
         elif isinstance(part, Mapping):
             # OpenAI Format
             if _is_openai_parts_format(part):
@@ -304,27 +286,49 @@ def _convert_to_parts(
     return parts


-def _messages_to_genai_contents(
-    input_messages: Sequence[BaseMessage],
+def _parse_chat_history(
+    input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
 ) -> List[genai.types.ContentDict]:
     """Converts a list of messages into a Gemini API google content dicts."""
+
     messages: List[genai.types.MessageDict] = []
+
+    raw_system_message: Optional[SystemMessage] = None
     for i, message in enumerate(input_messages):
-        role = _get_role(message)
-        if isinstance(message.content, str):
-            parts = [message.content]
+        if (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and not convert_system_message_to_human
+        ):
+            raise ValueError(
+                """SystemMessages are not yet supported!
+
+To automatically convert the leading SystemMessage to a HumanMessage,
+set `convert_system_message_to_human` to True. Example:
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
+"""
+            )
+        elif i == 0 and isinstance(message, SystemMessage):
+            raw_system_message = message
+            continue
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
         else:
-            parts = _convert_to_parts(message.content)
-        messages.append({"role": role, "parts": parts})
-        if i > 0:
-            # Cannot have multiple messages from the same role in a row.
-            if role == messages[-2]["role"]:
-                raise ChatGoogleGenerativeAIError(
-                    "Cannot have multiple messages from the same role in a row."
-                    " Consider merging them into a single message with multiple"
-                    f" parts.\nReceived: {messages}"
-                )
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+
+        parts = _convert_to_parts(message.content)
+        if raw_system_message:
+            if role == "model":
+                raise ValueError(
+                    "SystemMessage should be followed by a HumanMessage and "
+                    "not by AIMessage."
+                )
+            parts = _convert_to_parts(raw_system_message.content) + parts
+            raw_system_message = None
+        messages.append({"role": role, "parts": parts})
     return messages
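Roughly, the new `_parse_chat_history` helper maps LangChain messages onto Gemini content dicts like this (a sketch derived from the `test_parse_history` unit test further down; the parts are `google.generativeai` part dicts, shown here as plain dicts):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_google_genai.chat_models import _parse_chat_history

messages = [
    SystemMessage(content="You're supposed to answer math questions."),
    HumanMessage(content="How much is 2+2?"),
    AIMessage(content="4"),
    HumanMessage(content="How much is 3+3?"),
]
history = _parse_chat_history(messages, convert_system_message_to_human=True)
# history[0] == {"role": "user",
#                "parts": [{"text": "You're supposed to answer math questions."},
#                          {"text": "How much is 2+2?"}]}
# history[1] == {"role": "model", "parts": [{"text": "4"}]}
# history[2] == {"role": "user", "parts": [{"text": "How much is 3+3?"}]}
```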
@@ -457,8 +461,11 @@ Supported examples:
     n: int = Field(default=1, alias="candidate_count")
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
-
-    _generative_model: Any  #: :meta private:
+    convert_system_message_to_human: bool = False
+    """Whether to merge any leading SystemMessage into the following HumanMessage.
+
+    Gemini does not support system messages; any unsupported messages will
+    raise an error."""

     class Config:
         allow_population_by_field_name = True
@@ -499,7 +506,7 @@ Supported examples:
         if values.get("top_k") is not None and values["top_k"] <= 0:
             raise ValueError("top_k must be positive")
         model = values["model"]
-        values["_generative_model"] = genai.GenerativeModel(model_name=model)
+        values["client"] = genai.GenerativeModel(model_name=model)
         return values

     @property
@@ -512,18 +519,9 @@ Supported examples:
             "n": self.n,
         }

-    @property
-    def _generation_method(self) -> Callable:
-        return self._generative_model.generate_content
-
-    @property
-    def _async_generation_method(self) -> Callable:
-        return self._generative_model.generate_content_async
-
     def _prepare_params(
-        self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
+        self, stop: Optional[List[str]], **kwargs: Any
     ) -> Dict[str, Any]:
-        contents = _messages_to_genai_contents(messages)
         gen_config = {
             k: v
             for k, v in {
@@ -538,7 +536,7 @@ Supported examples:
         }
         if "generation_config" in kwargs:
             gen_config = {**gen_config, **kwargs.pop("generation_config")}
-        params = {"generation_config": gen_config, "contents": contents, **kwargs}
+        params = {"generation_config": gen_config, **kwargs}
         return params

     def _generate(
@@ -548,10 +546,11 @@ Supported examples:
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
             **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
         )
         return _response_to_result(response)

@@ -562,10 +561,11 @@ Supported examples:
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
+            content=message,
             **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
         )
         return _response_to_result(response)

@@ -576,10 +576,11 @@ Supported examples:
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
             **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
             stream=True,
         )
         for chunk in response:
@@ -602,10 +603,11 @@ Supported examples:
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         async for chunk in await _achat_with_retry(
+            content=message,
             **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
             stream=True,
         ):
             _chat_result = _response_to_result(
@@ -619,3 +621,18 @@ Supported examples:
             yield gen
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
+
+    def _prepare_chat(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        params = self._prepare_params(stop, **kwargs)
+        history = _parse_chat_history(
+            messages,
+            convert_system_message_to_human=self.convert_system_message_to_human,
+        )
+        message = history.pop()
+        chat = self.client.start_chat(history=history)
+        return params, chat, message
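The design in `_prepare_chat` is that everything before the final message seeds a `genai.ChatSession` and the final message is what actually gets sent. A rough sketch of the request path (`llm` and `messages` are placeholder names; this mirrors `_generate` above):

```python
# Rough request path after this change (names are placeholders; see _generate above).
params, chat, message = llm._prepare_chat(messages, stop=None)
# `chat` is a genai.ChatSession seeded with all but the last parsed message;
# `message` is the last one, sent through the session with retry handling.
response = _chat_with_retry(
    content=message,
    **params,
    generation_method=chat.send_message,
)
result = _response_to_result(response)
```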
@@ -1,6 +1,6 @@
 """Test ChatGoogleGenerativeAI chat model."""
 import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

 from langchain_google_genai.chat_models import (
     ChatGoogleGenerativeAI,
@@ -147,3 +147,40 @@ def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
     llm = ChatGoogleGenerativeAI(model=_MODEL)
     with pytest.raises(ChatGoogleGenerativeAIError):
         llm.invoke(messages)
+
+
+def test_chat_google_genai_single_call_with_history() -> None:
+    model = ChatGoogleGenerativeAI(model=_MODEL)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_google_genai_system_message_error() -> None:
+    model = ChatGoogleGenerativeAI(model=_MODEL)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    with pytest.raises(ValueError):
+        model([system_message, message1, message2, message3])
+
+
+def test_chat_google_genai_system_message() -> None:
+    model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content="You're supposed to answer math questions.")
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([system_message, message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
@@ -1,8 +1,12 @@
 """Test chat model integration."""
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture

-from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+from langchain_google_genai.chat_models import (
+    ChatGoogleGenerativeAI,
+    _parse_chat_history,
+)


 def test_integration_initialization() -> None:
@@ -36,3 +40,21 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
     captured = capsys.readouterr()

     assert captured.out == "**********"
+
+
+def test_parse_history() -> None:
+    system_input = "You're supposed to answer math questions."
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    system_message = SystemMessage(content=system_input)
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    messages = [system_message, message1, message2, message3]
+    history = _parse_chat_history(messages, convert_system_message_to_human=True)
+    assert len(history) == 3
+    assert history[0] == {
+        "role": "user",
+        "parts": [{"text": system_input}, {"text": text_question1}],
+    }
+    assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}