addressing comments

This commit is contained in:
Mason Daugherty 2025-08-06 18:19:07 -04:00
parent 73c49d31d6
commit adff6f48db
No known key found for this signature in database
2 changed files with 59 additions and 108 deletions

View File

@ -483,12 +483,46 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
def _convert_from_v0_to_v1(message: BaseMessage) -> MessageV1: def _convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
"""Convert a v0 message to a v1 message."""
if isinstance(message, HumanMessage): # Checking for v0 HumanMessage
return HumanMessageV1(message.content, id=message.id, name=message.name) # type: ignore[arg-type]
if isinstance(message, AIMessage): # Checking for v0 AIMessage
return AIMessageV1(
content=message.content, # type: ignore[arg-type]
id=message.id,
name=message.name,
lc_version="v1",
response_metadata=message.response_metadata, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
tool_calls=message.tool_calls,
invalid_tool_calls=message.invalid_tool_calls,
)
if isinstance(message, SystemMessage): # Checking for v0 SystemMessage
return SystemMessageV1(
message.content, # type: ignore[arg-type]
id=message.id,
name=message.name,
)
if isinstance(message, ToolMessage): # Checking for v0 ToolMessage
return ToolMessageV1(
message.content, # type: ignore[arg-type]
message.tool_call_id,
id=message.id,
name=message.name,
artifact=message.artifact,
status=message.status,
)
msg = f"Unsupported v0 message type for conversion to v1: {type(message)}"
raise NotImplementedError(msg)
def _safe_convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
"""Convert a v0 message to a v1 message.""" """Convert a v0 message to a v1 message."""
from langchain_core.messages.content_blocks import create_text_block from langchain_core.messages.content_blocks import create_text_block
if isinstance(message, HumanMessage): # Checking for v0 HumanMessage if isinstance(message, HumanMessage): # Checking for v0 HumanMessage
content: list[ContentBlock] = [create_text_block(str(message.content))] content: list[ContentBlock] = [create_text_block(str(message.content))]
return HumanMessageV1(content, name=message.name) return HumanMessageV1(content, id=message.id, name=message.name)
if isinstance(message, AIMessage): # Checking for v0 AIMessage if isinstance(message, AIMessage): # Checking for v0 AIMessage
content = [create_text_block(str(message.content))] content = [create_text_block(str(message.content))]
@ -497,20 +531,25 @@ def _convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
response_metadata = cast("ResponseMetadata", message.response_metadata or {}) response_metadata = cast("ResponseMetadata", message.response_metadata or {})
return AIMessageV1( return AIMessageV1(
content=content, content=content,
id=message.id,
name=message.name, name=message.name,
usage_metadata=message.usage_metadata, lc_version="v1",
response_metadata=response_metadata, response_metadata=response_metadata,
usage_metadata=message.usage_metadata,
tool_calls=message.tool_calls, tool_calls=message.tool_calls,
invalid_tool_calls=message.invalid_tool_calls,
) )
if isinstance(message, SystemMessage): # Checking for v0 SystemMessage if isinstance(message, SystemMessage): # Checking for v0 SystemMessage
content = [create_text_block(str(message.content))] content = [create_text_block(str(message.content))]
return SystemMessageV1(content=content, name=message.name) return SystemMessageV1(content=content, id=message.id, name=message.name)
if isinstance(message, ToolMessage): # Checking for v0 ToolMessage if isinstance(message, ToolMessage): # Checking for v0 ToolMessage
content = [create_text_block(str(message.content))] content = [create_text_block(str(message.content))]
return ToolMessageV1( return ToolMessageV1(
content=content, content,
message.tool_call_id,
id=message.id,
name=message.name, name=message.name,
tool_call_id=message.tool_call_id, artifact=message.artifact,
status=message.status, status=message.status,
) )
msg = f"Unsupported v0 message type for conversion to v1: {type(message)}" msg = f"Unsupported v0 message type for conversion to v1: {type(message)}"
@ -553,49 +592,23 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
message_type_str, template = message # type: ignore[misc] message_type_str, template = message # type: ignore[misc]
message_ = _create_message_from_message_type_v1(message_type_str, template) message_ = _create_message_from_message_type_v1(message_type_str, template)
elif isinstance(message, dict): elif isinstance(message, dict):
# Handle ToolCall content blocks passed as messages msg_kwargs = message.copy()
if ( try:
message.get("type") == "tool_call"
and "name" in message
and "args" in message
and "id" in message
and "content" not in message
):
# Convert ToolCall content block to an AIMessage with the tool call
from langchain_core.messages.content_blocks import (
ToolCall,
create_text_block,
)
from langchain_core.v1.messages import AIMessage
tool_call = ToolCall(
type="tool_call",
name=message["name"],
args=message["args"],
id=message["id"],
)
message_ = AIMessage(content=[create_text_block(""), tool_call])
else:
msg_kwargs = message.copy()
try: try:
try: msg_type = msg_kwargs.pop("role")
msg_type = msg_kwargs.pop("role") except KeyError:
except KeyError: msg_type = msg_kwargs.pop("type")
msg_type = msg_kwargs.pop("type") # None msg content is not allowed
# None msg content is not allowed msg_content = msg_kwargs.pop("content") or ""
msg_content = msg_kwargs.pop("content") or "" except KeyError as e:
except KeyError as e: msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
msg = ( msg = create_message(
"Message dict must contain 'role' and 'content' " message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
f"keys, got {message}"
)
msg = create_message(
message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
)
raise ValueError(msg) from e
message_ = _create_message_from_message_type_v1(
msg_type, msg_content, **msg_kwargs
) )
raise ValueError(msg) from e
message_ = _create_message_from_message_type_v1(
msg_type, msg_content, **msg_kwargs
)
else: else:
msg = f"Unsupported message type: {type(message)}" msg = f"Unsupported message type: {type(message)}"
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE) msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)

View File

@ -1176,68 +1176,6 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
) # TODO: do we need to handle if args is str? # noqa: E501 ) # TODO: do we need to handle if args is str? # noqa: E501
assert is_tool_call_block(tool_call) assert is_tool_call_block(tool_call)
def test_tool_message_histories_string_content(
self, model: BaseChatModel, my_adder_tool: BaseTool
) -> None:
"""Test that message histories are compatible with string tool contents
(e.g. OpenAI format). If a model passes this test, it should be compatible
with messages generated from providers following OpenAI format.
This test should be skipped if the model does not support tool calling
(see Configuration below).
.. dropdown:: Configuration
To disable tool calling tests, set ``has_tool_calling`` to False in your
test class:
.. code-block:: python
class TestMyV1ChatModelIntegration(ChatModelV1IntegrationTests):
@property
def has_tool_calling(self) -> bool:
return False
.. dropdown:: Troubleshooting
TODO: verify this!
If this test fails, check that:
1. The model can correctly handle message histories that include ``AIMessage`` objects with ``""`` ``TextContentBlock``s.
2. The ``tool_calls`` attribute on ``AIMessage`` objects is correctly handled and passed to the model in an appropriate format.
3. The model can correctly handle ``ToolMessage`` objects with string content and arbitrary string values for ``tool_call_id``.
You can ``xfail`` the test if tool calling is implemented but this format
is not supported.
.. code-block:: python
@pytest.mark.xfail(reason=("Not implemented."))
def test_tool_message_histories_string_content(self, *args: Any) -> None:
super().test_tool_message_histories_string_content(*args)
""" # noqa: E501
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool"
function_args = {"a": "1", "b": "2"}
messages_string_content = [
HumanMessage("What is 1 + 2"),
# String content (e.g. OpenAI)
create_tool_call(function_name, function_args, id="abc123"),
ToolMessage(
json.dumps({"result": 3}), tool_call_id="abc123", status="success"
),
]
result_string_content = model_with_tools.invoke(
messages_string_content # type: ignore[arg-type]
) # TODO
assert isinstance(result_string_content, AIMessage)
def test_tool_message_histories_list_content( def test_tool_message_histories_list_content(
self, self,
model: BaseChatModel, model: BaseChatModel,