mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 19:15:44 +00:00
mistral, openai: allow anthropic-style messages in message histories (#20565)
This commit is contained in:
parent
7a7851aa06
commit
2238490069
@ -283,9 +283,16 @@ def _convert_message_to_mistral_chat_message(
|
|||||||
tool_calls.append(chunk)
|
tool_calls.append(chunk)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
if tool_calls and message.content:
|
||||||
|
# Assistant message must have either content or tool_calls, but not both.
|
||||||
|
# Some providers may not support tool_calls in the same message as content.
|
||||||
|
# This is done to ensure compatibility with messages from other providers.
|
||||||
|
content: Any = ""
|
||||||
|
else:
|
||||||
|
content = message.content
|
||||||
return {
|
return {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": message.content,
|
"content": content,
|
||||||
"tool_calls": tool_calls,
|
"tool_calls": tool_calls,
|
||||||
}
|
}
|
||||||
elif isinstance(message, SystemMessage):
|
elif isinstance(message, SystemMessage):
|
||||||
|
@ -148,6 +148,22 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)
|
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_message_content(content: Any) -> Any:
|
||||||
|
"""Format message content."""
|
||||||
|
if content and isinstance(content, list):
|
||||||
|
# Remove unexpected block types
|
||||||
|
formatted_content = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and "type" in block and block["type"] != "text":
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
formatted_content.append(block)
|
||||||
|
else:
|
||||||
|
formatted_content = content
|
||||||
|
|
||||||
|
return formatted_content
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
"""Convert a LangChain message to a dictionary.
|
"""Convert a LangChain message to a dictionary.
|
||||||
|
|
||||||
@ -158,7 +174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
The dictionary.
|
The dictionary.
|
||||||
"""
|
"""
|
||||||
message_dict: Dict[str, Any] = {
|
message_dict: Dict[str, Any] = {
|
||||||
"content": message.content,
|
"content": _format_message_content(message.content),
|
||||||
}
|
}
|
||||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||||
message_dict["name"] = name
|
message_dict["name"] = name
|
||||||
|
@ -117,12 +117,13 @@ class ChatModelIntegrationTests(ABC):
|
|||||||
assert isinstance(result.content, str)
|
assert isinstance(result.content, str)
|
||||||
assert len(result.content) > 0
|
assert len(result.content) > 0
|
||||||
|
|
||||||
def test_tool_message(
|
def test_tool_message_histories(
|
||||||
self,
|
self,
|
||||||
chat_model_class: Type[BaseChatModel],
|
chat_model_class: Type[BaseChatModel],
|
||||||
chat_model_params: dict,
|
chat_model_params: dict,
|
||||||
chat_model_has_tool_calling: bool,
|
chat_model_has_tool_calling: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Test that message histories are compatible across providers."""
|
||||||
if not chat_model_has_tool_calling:
|
if not chat_model_has_tool_calling:
|
||||||
pytest.skip("Test requires tool calling.")
|
pytest.skip("Test requires tool calling.")
|
||||||
model = chat_model_class(**chat_model_params)
|
model = chat_model_class(**chat_model_params)
|
||||||
@ -130,9 +131,15 @@ class ChatModelIntegrationTests(ABC):
|
|||||||
function_name = "my_adder_tool"
|
function_name = "my_adder_tool"
|
||||||
function_args = {"a": "1", "b": "2"}
|
function_args = {"a": "1", "b": "2"}
|
||||||
|
|
||||||
messages = [
|
human_message = HumanMessage(content="What is 1 + 2")
|
||||||
HumanMessage(content="What is 1 + 2"),
|
tool_message = ToolMessage(
|
||||||
AIMessage(
|
name=function_name,
|
||||||
|
content=json.dumps({"result": 3}),
|
||||||
|
tool_call_id="abc123",
|
||||||
|
)
|
||||||
|
|
||||||
|
# String content (e.g., OpenAI)
|
||||||
|
string_content_msg = AIMessage(
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
@ -141,13 +148,38 @@ class ChatModelIntegrationTests(ABC):
|
|||||||
"id": "abc123",
|
"id": "abc123",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
),
|
)
|
||||||
ToolMessage(
|
messages = [
|
||||||
name=function_name,
|
human_message,
|
||||||
content=json.dumps({"result": 3}),
|
string_content_msg,
|
||||||
tool_call_id="abc123",
|
tool_message,
|
||||||
),
|
]
|
||||||
|
result = model_with_tools.invoke(messages)
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
|
||||||
|
# List content (e.g., Anthropic)
|
||||||
|
list_content_msg = AIMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "some text"},
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "abc123",
|
||||||
|
"name": function_name,
|
||||||
|
"input": function_args,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": function_name,
|
||||||
|
"args": function_args,
|
||||||
|
"id": "abc123",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
human_message,
|
||||||
|
list_content_msg,
|
||||||
|
tool_message,
|
||||||
]
|
]
|
||||||
|
|
||||||
result = model_with_tools.invoke(messages)
|
result = model_with_tools.invoke(messages)
|
||||||
assert isinstance(result, AIMessage)
|
assert isinstance(result, AIMessage)
|
||||||
|
Loading…
Reference in New Issue
Block a user