fix image generation

This commit is contained in:
Chester Curme 2025-07-10 13:42:00 -04:00
parent 67fc58011a
commit e928672306
3 changed files with 113 additions and 31 deletions

View File

@ -291,8 +291,8 @@ def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage:
for tool_call in message.tool_calls:
if id_ := tool_call.get("id"):
tool_callblock: ToolCallContentBlock = {"type": "tool_call", "id": id_}
message.content.append(tool_callblock)
tool_call_block: ToolCallContentBlock = {"type": "tool_call", "id": id_}
message.content.append(tool_call_block)
if "tool_calls" in message.additional_kwargs:
_ = message.additional_kwargs.pop("tool_calls")
@ -413,7 +413,18 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
"source_type": "base64",
"data": result,
}
for extra_key in ("id", "status"):
if output_format := block.get("output_format"):
new_block["mime_type"] = f"image/{output_format}"
for extra_key in (
"id",
"index",
"status",
"background",
"output_format",
"quality",
"revised_prompt",
"size",
):
if extra_key in block:
new_block[extra_key] = block[extra_key]
yield new_block
@ -421,11 +432,11 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
elif block_type == "function_call":
new_block: ToolCallContentBlock = {
"type": "tool_call",
"id": block["call_id"],
"id": block.get("call_id", ""),
}
if "id" in block:
new_block["item_id"] = block["id"]
for extra_key in ("arguments", "name"):
for extra_key in ("arguments", "name", "index"):
if extra_key in block:
new_block[extra_key] = block[extra_key]
yield new_block

View File

@ -3793,6 +3793,24 @@ def _construct_lc_result_from_responses_api(
message = _convert_to_v03_ai_message(message)
elif output_version == "v1":
message = _convert_to_v1_from_responses(message)
if response.tools and any(
tool.type == "image_generation" for tool in response.tools
):
# Get mime_time from tool definition and add to image generations
# if missing (primarily for tracing purposes).
image_generation_call = next(
tool for tool in response.tools if tool.type == "image_generation"
)
if image_generation_call.output_format:
mime_type = f"image/{image_generation_call.output_format}"
for block in message.content:
# OK to mutate output message
if (
block.get("type") == "image"
and block["source_type"] == "base64"
and "mime_type" not in block
):
block["mime_type"] = mime_type
else:
pass
return ChatResult(generations=[ChatGeneration(message=message)])

View File

@ -569,10 +569,14 @@ def test_mcp_builtin_zdr() -> None:
_ = llm_with_tools.invoke([input_message, full, approval_message])
@pytest.mark.vcr()
def test_image_generation_streaming() -> None:
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_image_generation_streaming(output_version: str) -> None:
"""Test image generation streaming."""
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
llm = ChatOpenAI(
model="gpt-4.1", use_responses_api=True, output_version=output_version
)
tool = {
"type": "image_generation",
# For testing purposes let's keep the quality low, so the test runs faster.
@ -619,15 +623,35 @@ def test_image_generation_streaming() -> None:
# At the moment, the streaming API does not pick up annotations fully.
# So the following check is commented out.
# _check_response(complete_ai_message)
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
if output_version == "v0":
assert complete_ai_message.additional_kwargs["tool_outputs"]
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
elif output_version == "responses/v1":
tool_output = next(
block
for block in complete_ai_message.content
if block["type"] == "image_generation_call"
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
# v1
standard_keys = {"type", "source_type", "data", "id", "status", "index"}
tool_output = next(
block for block in complete_ai_message.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
@pytest.mark.vcr()
def test_image_generation_multi_turn() -> None:
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_image_generation_multi_turn(output_version: str) -> None:
"""Test multi-turn editing of image generation by passing in history."""
# Test multi-turn
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
llm = ChatOpenAI(
model="gpt-4.1", use_responses_api=True, output_version=output_version
)
# Test invocation
tool = {
"type": "image_generation",
@ -644,9 +668,37 @@ def test_image_generation_multi_turn() -> None:
]
ai_message = llm_with_tools.invoke(chat_history)
_check_response(ai_message)
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
# Example tool output for an image
expected_keys = {
"id",
"background",
"output_format",
"quality",
"result",
"revised_prompt",
"size",
"status",
"type",
}
if output_version == "v0":
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
elif output_version == "responses/v1":
tool_output = next(
block
for block in ai_message.content
if block["type"] == "image_generation_call"
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
standard_keys = {"type", "source_type", "data", "id", "status"}
tool_output = next(
block for block in ai_message.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
# Example tool output for an image (v0)
# {
# "background": "opaque",
# "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
@ -662,20 +714,6 @@ def test_image_generation_multi_turn() -> None:
# "result": # base64 encode image data
# }
expected_keys = {
"id",
"background",
"output_format",
"quality",
"result",
"revised_prompt",
"size",
"status",
"type",
}
assert set(tool_output.keys()).issubset(expected_keys)
chat_history.extend(
[
# AI message with tool output
@ -693,5 +731,20 @@ def test_image_generation_multi_turn() -> None:
ai_message2 = llm_with_tools.invoke(chat_history)
_check_response(ai_message2)
tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
assert set(tool_output2.keys()).issubset(expected_keys)
if output_version == "v0":
tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
assert set(tool_output.keys()).issubset(expected_keys)
elif output_version == "responses/v1":
tool_output = next(
block
for block in ai_message2.content
if block["type"] == "image_generation_call"
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
standard_keys = {"type", "source_type", "data", "id", "status"}
tool_output = next(
block for block in ai_message2.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())