fix image generation

This commit is contained in:
Chester Curme 2025-07-10 13:42:00 -04:00
parent 67fc58011a
commit e928672306
3 changed files with 113 additions and 31 deletions

View File

@ -291,8 +291,8 @@ def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage:
for tool_call in message.tool_calls: for tool_call in message.tool_calls:
if id_ := tool_call.get("id"): if id_ := tool_call.get("id"):
tool_callblock: ToolCallContentBlock = {"type": "tool_call", "id": id_} tool_call_block: ToolCallContentBlock = {"type": "tool_call", "id": id_}
message.content.append(tool_callblock) message.content.append(tool_call_block)
if "tool_calls" in message.additional_kwargs: if "tool_calls" in message.additional_kwargs:
_ = message.additional_kwargs.pop("tool_calls") _ = message.additional_kwargs.pop("tool_calls")
@ -413,7 +413,18 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
"source_type": "base64", "source_type": "base64",
"data": result, "data": result,
} }
for extra_key in ("id", "status"): if output_format := block.get("output_format"):
new_block["mime_type"] = f"image/{output_format}"
for extra_key in (
"id",
"index",
"status",
"background",
"output_format",
"quality",
"revised_prompt",
"size",
):
if extra_key in block: if extra_key in block:
new_block[extra_key] = block[extra_key] new_block[extra_key] = block[extra_key]
yield new_block yield new_block
@ -421,11 +432,11 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
elif block_type == "function_call": elif block_type == "function_call":
new_block: ToolCallContentBlock = { new_block: ToolCallContentBlock = {
"type": "tool_call", "type": "tool_call",
"id": block["call_id"], "id": block.get("call_id", ""),
} }
if "id" in block: if "id" in block:
new_block["item_id"] = block["id"] new_block["item_id"] = block["id"]
for extra_key in ("arguments", "name"): for extra_key in ("arguments", "name", "index"):
if extra_key in block: if extra_key in block:
new_block[extra_key] = block[extra_key] new_block[extra_key] = block[extra_key]
yield new_block yield new_block

View File

@ -3793,6 +3793,24 @@ def _construct_lc_result_from_responses_api(
message = _convert_to_v03_ai_message(message) message = _convert_to_v03_ai_message(message)
elif output_version == "v1": elif output_version == "v1":
message = _convert_to_v1_from_responses(message) message = _convert_to_v1_from_responses(message)
if response.tools and any(
tool.type == "image_generation" for tool in response.tools
):
# Get mime_time from tool definition and add to image generations
# if missing (primarily for tracing purposes).
image_generation_call = next(
tool for tool in response.tools if tool.type == "image_generation"
)
if image_generation_call.output_format:
mime_type = f"image/{image_generation_call.output_format}"
for block in message.content:
# OK to mutate output message
if (
block.get("type") == "image"
and block["source_type"] == "base64"
and "mime_type" not in block
):
block["mime_type"] = mime_type
else: else:
pass pass
return ChatResult(generations=[ChatGeneration(message=message)]) return ChatResult(generations=[ChatGeneration(message=message)])

View File

@ -569,10 +569,14 @@ def test_mcp_builtin_zdr() -> None:
_ = llm_with_tools.invoke([input_message, full, approval_message]) _ = llm_with_tools.invoke([input_message, full, approval_message])
@pytest.mark.vcr() @pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
def test_image_generation_streaming() -> None: @pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_image_generation_streaming(output_version: str) -> None:
"""Test image generation streaming.""" """Test image generation streaming."""
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True) llm = ChatOpenAI(
model="gpt-4.1", use_responses_api=True, output_version=output_version
)
tool = { tool = {
"type": "image_generation", "type": "image_generation",
# For testing purposes let's keep the quality low, so the test runs faster. # For testing purposes let's keep the quality low, so the test runs faster.
@ -619,15 +623,35 @@ def test_image_generation_streaming() -> None:
# At the moment, the streaming API does not pick up annotations fully. # At the moment, the streaming API does not pick up annotations fully.
# So the following check is commented out. # So the following check is commented out.
# _check_response(complete_ai_message) # _check_response(complete_ai_message)
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0] if output_version == "v0":
assert set(tool_output.keys()).issubset(expected_keys) assert complete_ai_message.additional_kwargs["tool_outputs"]
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
elif output_version == "responses/v1":
tool_output = next(
block
for block in complete_ai_message.content
if block["type"] == "image_generation_call"
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
# v1
standard_keys = {"type", "source_type", "data", "id", "status", "index"}
tool_output = next(
block for block in complete_ai_message.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
@pytest.mark.vcr() @pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
def test_image_generation_multi_turn() -> None: @pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_image_generation_multi_turn(output_version: str) -> None:
"""Test multi-turn editing of image generation by passing in history.""" """Test multi-turn editing of image generation by passing in history."""
# Test multi-turn # Test multi-turn
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True) llm = ChatOpenAI(
model="gpt-4.1", use_responses_api=True, output_version=output_version
)
# Test invocation # Test invocation
tool = { tool = {
"type": "image_generation", "type": "image_generation",
@ -644,9 +668,37 @@ def test_image_generation_multi_turn() -> None:
] ]
ai_message = llm_with_tools.invoke(chat_history) ai_message = llm_with_tools.invoke(chat_history)
_check_response(ai_message) _check_response(ai_message)
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
# Example tool output for an image expected_keys = {
"id",
"background",
"output_format",
"quality",
"result",
"revised_prompt",
"size",
"status",
"type",
}
if output_version == "v0":
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
elif output_version == "responses/v1":
tool_output = next(
block
for block in ai_message.content
if block["type"] == "image_generation_call"
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
standard_keys = {"type", "source_type", "data", "id", "status"}
tool_output = next(
block for block in ai_message.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
# Example tool output for an image (v0)
# { # {
# "background": "opaque", # "background": "opaque",
# "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8", # "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
@ -662,20 +714,6 @@ def test_image_generation_multi_turn() -> None:
# "result": # base64 encode image data # "result": # base64 encode image data
# } # }
expected_keys = {
"id",
"background",
"output_format",
"quality",
"result",
"revised_prompt",
"size",
"status",
"type",
}
assert set(tool_output.keys()).issubset(expected_keys)
chat_history.extend( chat_history.extend(
[ [
# AI message with tool output # AI message with tool output
@ -693,5 +731,20 @@ def test_image_generation_multi_turn() -> None:
ai_message2 = llm_with_tools.invoke(chat_history) ai_message2 = llm_with_tools.invoke(chat_history)
_check_response(ai_message2) _check_response(ai_message2)
tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
assert set(tool_output2.keys()).issubset(expected_keys) if output_version == "v0":
tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
elif output_version == "responses/v1":
tool_output = next(
block
for block in ai_message2.content
if block["type"] == "image_generation_call"
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
standard_keys = {"type", "source_type", "data", "id", "status"}
tool_output = next(
block for block in ai_message2.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())