From e9286723067feaba19b2a2cf5788f41df4578cc4 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 10 Jul 2025 13:42:00 -0400 Subject: [PATCH] fix image generation --- .../langchain_openai/chat_models/_compat.py | 21 +++- .../langchain_openai/chat_models/base.py | 18 +++ .../chat_models/test_responses_api.py | 105 +++++++++++++----- 3 files changed, 113 insertions(+), 31 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index 2402fb8063d..5fcc7cba160 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -291,8 +291,8 @@ def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage: for tool_call in message.tool_calls: if id_ := tool_call.get("id"): - tool_callblock: ToolCallContentBlock = {"type": "tool_call", "id": id_} - message.content.append(tool_callblock) + tool_call_block: ToolCallContentBlock = {"type": "tool_call", "id": id_} + message.content.append(tool_call_block) if "tool_calls" in message.additional_kwargs: _ = message.additional_kwargs.pop("tool_calls") @@ -413,7 +413,18 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: "source_type": "base64", "data": result, } - for extra_key in ("id", "status"): + if output_format := block.get("output_format"): + new_block["mime_type"] = f"image/{output_format}" + for extra_key in ( + "id", + "index", + "status", + "background", + "output_format", + "quality", + "revised_prompt", + "size", + ): if extra_key in block: new_block[extra_key] = block[extra_key] yield new_block @@ -421,11 +432,11 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: elif block_type == "function_call": new_block: ToolCallContentBlock = { "type": "tool_call", - "id": block["call_id"], + "id": block.get("call_id", ""), } if "id" in block: new_block["item_id"] = block["id"] - for extra_key in 
("arguments", "name"): +                for extra_key in ("arguments", "name", "index"): if extra_key in block: new_block[extra_key] = block[extra_key] yield new_block diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index cb742b80bf1..0b1f3ef9604 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -3793,6 +3793,24 @@ def _construct_lc_result_from_responses_api( message = _convert_to_v03_ai_message(message) elif output_version == "v1": message = _convert_to_v1_from_responses(message) + if response.tools and any( + tool.type == "image_generation" for tool in response.tools + ): + # Get mime_type from tool definition and add to image generations + # if missing (primarily for tracing purposes). + image_generation_call = next( + tool for tool in response.tools if tool.type == "image_generation" + ) + if image_generation_call.output_format: + mime_type = f"image/{image_generation_call.output_format}" + for block in message.content: + # OK to mutate output message + if ( + block.get("type") == "image" + and block["source_type"] == "base64" + and "mime_type" not in block + ): + block["mime_type"] = mime_type else: pass return ChatResult(generations=[ChatGeneration(message=message)]) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index 6b4c9d093fa..fac526c382b 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -569,10 +569,14 @@ def test_mcp_builtin_zdr() -> None: _ = llm_with_tools.invoke([input_message, full, approval_message]) -@pytest.mark.vcr() -def test_image_generation_streaming() -> None: +@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz") 
+@pytest.mark.vcr +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_image_generation_streaming(output_version: str) -> None: """Test image generation streaming.""" - llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True) + llm = ChatOpenAI( + model="gpt-4.1", use_responses_api=True, output_version=output_version + ) tool = { "type": "image_generation", # For testing purposes let's keep the quality low, so the test runs faster. @@ -619,15 +623,35 @@ def test_image_generation_streaming() -> None: # At the moment, the streaming API does not pick up annotations fully. # So the following check is commented out. # _check_response(complete_ai_message) - tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0] - assert set(tool_output.keys()).issubset(expected_keys) + if output_version == "v0": + assert complete_ai_message.additional_kwargs["tool_outputs"] + tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0] + assert set(tool_output.keys()).issubset(expected_keys) + elif output_version == "responses/v1": + tool_output = next( + block + for block in complete_ai_message.content + if block["type"] == "image_generation_call" + ) + assert set(tool_output.keys()).issubset(expected_keys) + else: + # v1 + standard_keys = {"type", "source_type", "data", "id", "status", "index"} + tool_output = next( + block for block in complete_ai_message.content if block["type"] == "image" + ) + assert set(standard_keys).issubset(tool_output.keys()) -@pytest.mark.vcr() -def test_image_generation_multi_turn() -> None: +@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz") +@pytest.mark.vcr +@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) +def test_image_generation_multi_turn(output_version: str) -> None: """Test multi-turn editing of image generation by passing in history.""" # Test multi-turn - llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True) + llm = ChatOpenAI( + 
model="gpt-4.1", use_responses_api=True, output_version=output_version + ) # Test invocation tool = { "type": "image_generation", @@ -644,9 +668,37 @@ def test_image_generation_multi_turn() -> None: ] ai_message = llm_with_tools.invoke(chat_history) _check_response(ai_message) - tool_output = ai_message.additional_kwargs["tool_outputs"][0] - # Example tool output for an image + expected_keys = { + "id", + "background", + "output_format", + "quality", + "result", + "revised_prompt", + "size", + "status", + "type", + } + + if output_version == "v0": + tool_output = ai_message.additional_kwargs["tool_outputs"][0] + assert set(tool_output.keys()).issubset(expected_keys) + elif output_version == "responses/v1": + tool_output = next( + block + for block in ai_message.content + if block["type"] == "image_generation_call" + ) + assert set(tool_output.keys()).issubset(expected_keys) + else: + standard_keys = {"type", "source_type", "data", "id", "status"} + tool_output = next( + block for block in ai_message.content if block["type"] == "image" + ) + assert set(standard_keys).issubset(tool_output.keys()) + + # Example tool output for an image (v0) # { # "background": "opaque", # "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8", @@ -662,20 +714,6 @@ def test_image_generation_multi_turn() -> None: # "result": # base64 encode image data # } - expected_keys = { - "id", - "background", - "output_format", - "quality", - "result", - "revised_prompt", - "size", - "status", - "type", - } - - assert set(tool_output.keys()).issubset(expected_keys) - chat_history.extend( [ # AI message with tool output @@ -693,5 +731,20 @@ def test_image_generation_multi_turn() -> None: ai_message2 = llm_with_tools.invoke(chat_history) _check_response(ai_message2) - tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0] - assert set(tool_output2.keys()).issubset(expected_keys) + + if output_version == "v0": + tool_output = ai_message2.additional_kwargs["tool_outputs"][0] + assert 
set(tool_output.keys()).issubset(expected_keys) + elif output_version == "responses/v1": + tool_output = next( + block + for block in ai_message2.content + if block["type"] == "image_generation_call" + ) + assert set(tool_output.keys()).issubset(expected_keys) + else: + standard_keys = {"type", "source_type", "data", "id", "status"} + tool_output = next( + block for block in ai_message2.content if block["type"] == "image" + ) + assert set(standard_keys).issubset(tool_output.keys())