mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 13:55:03 +00:00
fix image generation
This commit is contained in:
parent
67fc58011a
commit
e928672306
@ -291,8 +291,8 @@ def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage:
|
|||||||
|
|
||||||
for tool_call in message.tool_calls:
|
for tool_call in message.tool_calls:
|
||||||
if id_ := tool_call.get("id"):
|
if id_ := tool_call.get("id"):
|
||||||
tool_callblock: ToolCallContentBlock = {"type": "tool_call", "id": id_}
|
tool_call_block: ToolCallContentBlock = {"type": "tool_call", "id": id_}
|
||||||
message.content.append(tool_callblock)
|
message.content.append(tool_call_block)
|
||||||
|
|
||||||
if "tool_calls" in message.additional_kwargs:
|
if "tool_calls" in message.additional_kwargs:
|
||||||
_ = message.additional_kwargs.pop("tool_calls")
|
_ = message.additional_kwargs.pop("tool_calls")
|
||||||
@ -413,7 +413,18 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
|||||||
"source_type": "base64",
|
"source_type": "base64",
|
||||||
"data": result,
|
"data": result,
|
||||||
}
|
}
|
||||||
for extra_key in ("id", "status"):
|
if output_format := block.get("output_format"):
|
||||||
|
new_block["mime_type"] = f"image/{output_format}"
|
||||||
|
for extra_key in (
|
||||||
|
"id",
|
||||||
|
"index",
|
||||||
|
"status",
|
||||||
|
"background",
|
||||||
|
"output_format",
|
||||||
|
"quality",
|
||||||
|
"revised_prompt",
|
||||||
|
"size",
|
||||||
|
):
|
||||||
if extra_key in block:
|
if extra_key in block:
|
||||||
new_block[extra_key] = block[extra_key]
|
new_block[extra_key] = block[extra_key]
|
||||||
yield new_block
|
yield new_block
|
||||||
@ -421,11 +432,11 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
|||||||
elif block_type == "function_call":
|
elif block_type == "function_call":
|
||||||
new_block: ToolCallContentBlock = {
|
new_block: ToolCallContentBlock = {
|
||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
"id": block["call_id"],
|
"id": block.get("call_id", ""),
|
||||||
}
|
}
|
||||||
if "id" in block:
|
if "id" in block:
|
||||||
new_block["item_id"] = block["id"]
|
new_block["item_id"] = block["id"]
|
||||||
for extra_key in ("arguments", "name"):
|
for extra_key in ("arguments", "name", "index"):
|
||||||
if extra_key in block:
|
if extra_key in block:
|
||||||
new_block[extra_key] = block[extra_key]
|
new_block[extra_key] = block[extra_key]
|
||||||
yield new_block
|
yield new_block
|
||||||
|
@ -3793,6 +3793,24 @@ def _construct_lc_result_from_responses_api(
|
|||||||
message = _convert_to_v03_ai_message(message)
|
message = _convert_to_v03_ai_message(message)
|
||||||
elif output_version == "v1":
|
elif output_version == "v1":
|
||||||
message = _convert_to_v1_from_responses(message)
|
message = _convert_to_v1_from_responses(message)
|
||||||
|
if response.tools and any(
|
||||||
|
tool.type == "image_generation" for tool in response.tools
|
||||||
|
):
|
||||||
|
# Get mime_time from tool definition and add to image generations
|
||||||
|
# if missing (primarily for tracing purposes).
|
||||||
|
image_generation_call = next(
|
||||||
|
tool for tool in response.tools if tool.type == "image_generation"
|
||||||
|
)
|
||||||
|
if image_generation_call.output_format:
|
||||||
|
mime_type = f"image/{image_generation_call.output_format}"
|
||||||
|
for block in message.content:
|
||||||
|
# OK to mutate output message
|
||||||
|
if (
|
||||||
|
block.get("type") == "image"
|
||||||
|
and block["source_type"] == "base64"
|
||||||
|
and "mime_type" not in block
|
||||||
|
):
|
||||||
|
block["mime_type"] = mime_type
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
@ -569,10 +569,14 @@ def test_mcp_builtin_zdr() -> None:
|
|||||||
_ = llm_with_tools.invoke([input_message, full, approval_message])
|
_ = llm_with_tools.invoke([input_message, full, approval_message])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr()
|
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
|
||||||
def test_image_generation_streaming() -> None:
|
@pytest.mark.vcr
|
||||||
|
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||||
|
def test_image_generation_streaming(output_version: str) -> None:
|
||||||
"""Test image generation streaming."""
|
"""Test image generation streaming."""
|
||||||
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
|
llm = ChatOpenAI(
|
||||||
|
model="gpt-4.1", use_responses_api=True, output_version=output_version
|
||||||
|
)
|
||||||
tool = {
|
tool = {
|
||||||
"type": "image_generation",
|
"type": "image_generation",
|
||||||
# For testing purposes let's keep the quality low, so the test runs faster.
|
# For testing purposes let's keep the quality low, so the test runs faster.
|
||||||
@ -619,15 +623,35 @@ def test_image_generation_streaming() -> None:
|
|||||||
# At the moment, the streaming API does not pick up annotations fully.
|
# At the moment, the streaming API does not pick up annotations fully.
|
||||||
# So the following check is commented out.
|
# So the following check is commented out.
|
||||||
# _check_response(complete_ai_message)
|
# _check_response(complete_ai_message)
|
||||||
|
if output_version == "v0":
|
||||||
|
assert complete_ai_message.additional_kwargs["tool_outputs"]
|
||||||
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
|
tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
|
||||||
assert set(tool_output.keys()).issubset(expected_keys)
|
assert set(tool_output.keys()).issubset(expected_keys)
|
||||||
|
elif output_version == "responses/v1":
|
||||||
|
tool_output = next(
|
||||||
|
block
|
||||||
|
for block in complete_ai_message.content
|
||||||
|
if block["type"] == "image_generation_call"
|
||||||
|
)
|
||||||
|
assert set(tool_output.keys()).issubset(expected_keys)
|
||||||
|
else:
|
||||||
|
# v1
|
||||||
|
standard_keys = {"type", "source_type", "data", "id", "status", "index"}
|
||||||
|
tool_output = next(
|
||||||
|
block for block in complete_ai_message.content if block["type"] == "image"
|
||||||
|
)
|
||||||
|
assert set(standard_keys).issubset(tool_output.keys())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr()
|
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
|
||||||
def test_image_generation_multi_turn() -> None:
|
@pytest.mark.vcr
|
||||||
|
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||||
|
def test_image_generation_multi_turn(output_version: str) -> None:
|
||||||
"""Test multi-turn editing of image generation by passing in history."""
|
"""Test multi-turn editing of image generation by passing in history."""
|
||||||
# Test multi-turn
|
# Test multi-turn
|
||||||
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
|
llm = ChatOpenAI(
|
||||||
|
model="gpt-4.1", use_responses_api=True, output_version=output_version
|
||||||
|
)
|
||||||
# Test invocation
|
# Test invocation
|
||||||
tool = {
|
tool = {
|
||||||
"type": "image_generation",
|
"type": "image_generation",
|
||||||
@ -644,9 +668,37 @@ def test_image_generation_multi_turn() -> None:
|
|||||||
]
|
]
|
||||||
ai_message = llm_with_tools.invoke(chat_history)
|
ai_message = llm_with_tools.invoke(chat_history)
|
||||||
_check_response(ai_message)
|
_check_response(ai_message)
|
||||||
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
|
|
||||||
|
|
||||||
# Example tool output for an image
|
expected_keys = {
|
||||||
|
"id",
|
||||||
|
"background",
|
||||||
|
"output_format",
|
||||||
|
"quality",
|
||||||
|
"result",
|
||||||
|
"revised_prompt",
|
||||||
|
"size",
|
||||||
|
"status",
|
||||||
|
"type",
|
||||||
|
}
|
||||||
|
|
||||||
|
if output_version == "v0":
|
||||||
|
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
|
||||||
|
assert set(tool_output.keys()).issubset(expected_keys)
|
||||||
|
elif output_version == "responses/v1":
|
||||||
|
tool_output = next(
|
||||||
|
block
|
||||||
|
for block in ai_message.content
|
||||||
|
if block["type"] == "image_generation_call"
|
||||||
|
)
|
||||||
|
assert set(tool_output.keys()).issubset(expected_keys)
|
||||||
|
else:
|
||||||
|
standard_keys = {"type", "source_type", "data", "id", "status"}
|
||||||
|
tool_output = next(
|
||||||
|
block for block in ai_message.content if block["type"] == "image"
|
||||||
|
)
|
||||||
|
assert set(standard_keys).issubset(tool_output.keys())
|
||||||
|
|
||||||
|
# Example tool output for an image (v0)
|
||||||
# {
|
# {
|
||||||
# "background": "opaque",
|
# "background": "opaque",
|
||||||
# "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
|
# "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
|
||||||
@ -662,20 +714,6 @@ def test_image_generation_multi_turn() -> None:
|
|||||||
# "result": # base64 encode image data
|
# "result": # base64 encode image data
|
||||||
# }
|
# }
|
||||||
|
|
||||||
expected_keys = {
|
|
||||||
"id",
|
|
||||||
"background",
|
|
||||||
"output_format",
|
|
||||||
"quality",
|
|
||||||
"result",
|
|
||||||
"revised_prompt",
|
|
||||||
"size",
|
|
||||||
"status",
|
|
||||||
"type",
|
|
||||||
}
|
|
||||||
|
|
||||||
assert set(tool_output.keys()).issubset(expected_keys)
|
|
||||||
|
|
||||||
chat_history.extend(
|
chat_history.extend(
|
||||||
[
|
[
|
||||||
# AI message with tool output
|
# AI message with tool output
|
||||||
@ -693,5 +731,20 @@ def test_image_generation_multi_turn() -> None:
|
|||||||
|
|
||||||
ai_message2 = llm_with_tools.invoke(chat_history)
|
ai_message2 = llm_with_tools.invoke(chat_history)
|
||||||
_check_response(ai_message2)
|
_check_response(ai_message2)
|
||||||
tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
|
|
||||||
assert set(tool_output2.keys()).issubset(expected_keys)
|
if output_version == "v0":
|
||||||
|
tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
|
||||||
|
assert set(tool_output.keys()).issubset(expected_keys)
|
||||||
|
elif output_version == "responses/v1":
|
||||||
|
tool_output = next(
|
||||||
|
block
|
||||||
|
for block in ai_message2.content
|
||||||
|
if block["type"] == "image_generation_call"
|
||||||
|
)
|
||||||
|
assert set(tool_output.keys()).issubset(expected_keys)
|
||||||
|
else:
|
||||||
|
standard_keys = {"type", "source_type", "data", "id", "status"}
|
||||||
|
tool_output = next(
|
||||||
|
block for block in ai_message2.content if block["type"] == "image"
|
||||||
|
)
|
||||||
|
assert set(standard_keys).issubset(tool_output.keys())
|
||||||
|
Loading…
Reference in New Issue
Block a user