diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index b5dea2b0f87..7a11985dc79 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -3452,6 +3452,10 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict: # string, assume image_url output = {"type": "input_image", "image_url": message.content} computer_call_output["output"] = output + if "acknowledged_safety_checks" in message.additional_kwargs: + computer_call_output["acknowledged_safety_checks"] = message.additional_kwargs[ + "acknowledged_safety_checks" + ] return computer_call_output diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 4e8f57ce7e0..d387ee92170 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -63,6 +63,7 @@ from langchain_openai.chat_models.base import ( _create_usage_metadata_responses, _format_message_content, _get_last_messages, + _make_computer_call_output_from_message, _oai_structured_outputs_parser, ) @@ -2454,3 +2455,76 @@ def test_get_request_payload_use_previous_response_id() -> None: payload = llm._get_request_payload(messages) assert "previous_response_id" not in payload assert len(payload["input"]) == 1 + + +def test_make_computer_call_output_from_message() -> None: + # List content + tool_message = ToolMessage( + content=[ + {"type": "input_image", "image_url": "data:image/png;base64,"} + ], + tool_call_id="call_abc123", + additional_kwargs={"type": "computer_call_output"}, + ) + result = _make_computer_call_output_from_message(tool_message) + + assert result == { + "type": "computer_call_output", + "call_id": "call_abc123", + "output": { + "type": "input_image", + "image_url": "data:image/png;base64,", + }, + } + + # String content + tool_message = ToolMessage( + content="data:image/png;base64,", + tool_call_id="call_abc123", + additional_kwargs={"type": "computer_call_output"}, + ) + result = _make_computer_call_output_from_message(tool_message) + + assert result == { + "type": "computer_call_output", + "call_id": "call_abc123", + "output": { + "type": "input_image", + "image_url": "data:image/png;base64,", + }, + } + + # Safety checks + tool_message = ToolMessage( + content=[ + {"type": "input_image", "image_url": "data:image/png;base64,"} + ], + tool_call_id="call_abc123", + additional_kwargs={ + "type": "computer_call_output", + "acknowledged_safety_checks": [ + { + "id": "cu_sc_abc234", + "code": "malicious_instructions", + "message": "Malicious instructions detected.", + } + ], + }, + ) + result = _make_computer_call_output_from_message(tool_message) + + assert result == { + "type": "computer_call_output", + "call_id": "call_abc123", + "output": { + "type": "input_image", + "image_url": "data:image/png;base64,", + }, + "acknowledged_safety_checks": [ + { + "id": "cu_sc_abc234", + "code": "malicious_instructions", + "message": "Malicious instructions detected.", + } + ], + }