Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-26 16:43:35 +00:00
openai[minor]: add image generation to responses api (#31424)
Does not support partial images during generation at the moment. Before doing that I'd like to figure out how to specify the aggregation logic without requiring changes in core.

Co-authored-by: Chester Curme <chester.curme@gmail.com>
commit 17f34baa88 (parent 9a78246d29)
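For orientation before the diff, here is a minimal usage sketch of the new capability, distilled from the integration tests below. It is a sketch only: the prompt and tool parameter values are illustrative, and it assumes OPENAI_API_KEY is set for an account with Responses API image generation access.

from langchain_openai import ChatOpenAI

# Bind the built-in image_generation tool (Responses API only). The
# parameter values here are illustrative, not library defaults.
llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
llm_with_tools = llm.bind_tools(
    [{"type": "image_generation", "quality": "low", "size": "1024x1024"}]
)

ai_message = llm_with_tools.invoke("Draw a random short word in green font.")
# The completed image_generation_call, including the base64 "result",
# surfaces under additional_kwargs["tool_outputs"].
tool_output = ai_message.additional_kwargs["tool_outputs"][0]
print(tool_output["id"], tool_output["status"])

The hunks below wire this up: a WellKnownTools registry, tool_choice support, payload validation, multi-turn input reconstruction, and streaming handling.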
@@ -118,6 +118,15 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
 _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
 
+WellKnownTools = (
+    "file_search",
+    "web_search_preview",
+    "computer_use_preview",
+    "code_interpreter",
+    "mcp",
+    "image_generation",
+)
+
 
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
@@ -1487,13 +1496,7 @@ class BaseChatOpenAI(BaseChatModel):
                 "type": "function",
                 "function": {"name": tool_choice},
             }
-        elif tool_choice in (
-            "file_search",
-            "web_search_preview",
-            "computer_use_preview",
-            "code_interpreter",
-            "mcp",
-        ):
+        elif tool_choice in WellKnownTools:
             tool_choice = {"type": tool_choice}
         # 'any' is not natively supported by OpenAI API.
         # We support 'any' since other models use this instead of 'required'.
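Because "image_generation" is now in WellKnownTools, the bare string form of tool_choice resolves to {"type": "image_generation"} via the branch above. A brief sketch (tool parameters illustrative):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
# Force the built-in tool by name; bind_tools converts the bare string
# into {"type": "image_generation"} per the hunk above.
llm_forced = llm.bind_tools(
    [{"type": "image_generation", "quality": "low"}],
    tool_choice="image_generation",
)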
@@ -3050,6 +3053,13 @@ def _construct_responses_api_payload(
                 new_tools.append({"type": "function", **tool["function"]})
             else:
                 new_tools.append(tool)
+
+            if tool["type"] == "image_generation" and "partial_images" in tool:
+                raise NotImplementedError(
+                    "Partial image generation is not yet supported "
+                    "via the LangChain ChatOpenAI client. Please "
+                    "drop the 'partial_images' key from the image_generation tool."
+                )
         payload["tools"] = new_tools
     if tool_choice := payload.pop("tool_choice", None):
         # chat api: {"type": "function", "function": {"name": "..."}}
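A consequence of this guard, sketched below: a request that includes partial_images fails fast during payload construction, assuming (as this module's flow suggests) the payload is built before any HTTP call is made. The API key value is a placeholder, never sent.

import pytest
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-4.1", use_responses_api=True, api_key="sk-placeholder"
)
llm_with_tools = llm.bind_tools(
    [{"type": "image_generation", "partial_images": 2}]
)
# Raised by _construct_responses_api_payload before the request is sent.
with pytest.raises(NotImplementedError):
    llm_with_tools.invoke("Draw a cat.")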
@@ -3139,6 +3149,7 @@ def _pop_summary_index_from_reasoning(reasoning: dict) -> dict:
 
 
 def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
+    """Construct the input for the OpenAI Responses API."""
     input_ = []
     for lc_msg in messages:
         msg = _convert_message_to_dict(lc_msg)
@@ -3191,6 +3202,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
             computer_calls = []
             code_interpreter_calls = []
             mcp_calls = []
+            image_generation_calls = []
             tool_outputs = lc_msg.additional_kwargs.get("tool_outputs", [])
             for tool_output in tool_outputs:
                 if tool_output.get("type") == "computer_call":
@@ -3199,10 +3211,22 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                     code_interpreter_calls.append(tool_output)
                 elif tool_output.get("type") == "mcp_call":
                     mcp_calls.append(tool_output)
+                elif tool_output.get("type") == "image_generation_call":
+                    image_generation_calls.append(tool_output)
                 else:
                     pass
             input_.extend(code_interpreter_calls)
             input_.extend(mcp_calls)
+
+            # A previous image generation call can be referenced by ID
+
+            input_.extend(
+                [
+                    {"type": "image_generation_call", "id": image_generation_call["id"]}
+                    for image_generation_call in image_generation_calls
+                ]
+            )
+
             msg["content"] = msg.get("content") or []
             if lc_msg.additional_kwargs.get("refusal"):
                 if isinstance(msg["content"], str):
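To make the effect of the new branch concrete: when chat history contains an AIMessage whose tool_outputs hold a completed image_generation_call, the follow-up request forwards only a lightweight ID reference instead of re-sending the base64 image. A sketch against the private helper (the ID is illustrative, and private internals may shift):

from langchain_core.messages import AIMessage
from langchain_openai.chat_models.base import _construct_responses_api_input

# Hypothetical prior turn carrying a completed image generation call.
msg = AIMessage(
    content="",
    additional_kwargs={
        "tool_outputs": [
            {"type": "image_generation_call", "id": "ig_123", "result": "<base64>"}
        ]
    },
)
input_ = _construct_responses_api_input([msg])
# The bulky "result" is dropped; only the ID reference is forwarded.
assert {"type": "image_generation_call", "id": "ig_123"} in input_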
@@ -3489,6 +3513,7 @@ def _convert_responses_chunk_to_generation_chunk(
         "mcp_call",
         "mcp_list_tools",
         "mcp_approval_request",
+        "image_generation_call",
     ):
         additional_kwargs["tool_outputs"] = [
             chunk.item.model_dump(exclude_none=True, mode="json")
@@ -3516,6 +3541,9 @@ def _convert_responses_chunk_to_generation_chunk(
                 {"index": chunk.summary_index, "type": "summary_text", "text": ""}
             ]
         }
+    elif chunk.type == "response.image_generation_call.partial_image":
+        # Partial images are not supported yet.
+        pass
     elif chunk.type == "response.reasoning_summary_text.delta":
         additional_kwargs["reasoning"] = {
             "summary": [
@@ -7,7 +7,7 @@ authors = []
 license = { text = "MIT" }
 requires-python = ">=3.9"
 dependencies = [
-    "langchain-core<1.0.0,>=0.3.61",
+    "langchain-core<1.0.0,>=0.3.63",
     "openai<2.0.0,>=1.68.2",
     "tiktoken<1,>=0.7",
 ]
Binary file not shown.
Binary file not shown.
@@ -12,6 +12,7 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    MessageLikeRepresentation,
 )
 from pydantic import BaseModel
 from typing_extensions import TypedDict
@@ -452,3 +453,130 @@ def test_mcp_builtin() -> None:
     _ = llm_with_tools.invoke(
         [approval_message], previous_response_id=response.response_metadata["id"]
     )
+
+
+@pytest.mark.vcr()
+def test_image_generation_streaming() -> None:
+    """Test image generation streaming."""
+    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
+    tool = {
+        "type": "image_generation",
+        # For testing purposes let's keep the quality low, so the test runs faster.
+        "quality": "low",
+        "output_format": "jpeg",
+        "output_compression": 100,
+        "size": "1024x1024",
+    }
+
+    # Example tool output for an image
+    # {
+    #     "background": "opaque",
+    #     "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
+    #     "output_format": "png",
+    #     "quality": "high",
+    #     "revised_prompt": "A fluffy, fuzzy cat sitting calmly, with soft fur, bright "
+    #     "eyes, and a cute, friendly expression. The background is "
+    #     "simple and light to emphasize the cat's texture and "
+    #     "fluffiness.",
+    #     "size": "1024x1024",
+    #     "status": "completed",
+    #     "type": "image_generation_call",
+    #     "result": # base64 encoded image data
+    # }
+
+    expected_keys = {
+        "id",
+        "background",
+        "output_format",
+        "quality",
+        "result",
+        "revised_prompt",
+        "size",
+        "status",
+        "type",
+    }
+
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    complete_ai_message = cast(AIMessageChunk, full)
+    # At the moment, the streaming API does not pick up annotations fully.
+    # So the following check is commented out.
+    # _check_response(complete_ai_message)
+    tool_output = complete_ai_message.additional_kwargs["tool_outputs"][0]
+    assert set(tool_output.keys()).issubset(expected_keys)
+
+
+@pytest.mark.vcr()
+def test_image_generation_multi_turn() -> None:
+    """Test multi-turn editing of image generation by passing in history."""
+    # Test multi-turn
+    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
+    # Test invocation
+    tool = {
+        "type": "image_generation",
+        # For testing purposes let's keep the quality low, so the test runs faster.
+        "quality": "low",
+        "output_format": "jpeg",
+        "output_compression": 100,
+        "size": "1024x1024",
+    }
+    llm_with_tools = llm.bind_tools([tool])
+
+    chat_history: list[MessageLikeRepresentation] = [
+        {"role": "user", "content": "Draw a random short word in green font."}
+    ]
+    ai_message = llm_with_tools.invoke(chat_history)
+    _check_response(ai_message)
+    tool_output = ai_message.additional_kwargs["tool_outputs"][0]
+
+    # Example tool output for an image
+    # {
+    #     "background": "opaque",
+    #     "id": "ig_683716a8ddf0819888572b20621c7ae4029ec8c11f8dacf8",
+    #     "output_format": "png",
+    #     "quality": "high",
+    #     "revised_prompt": "A fluffy, fuzzy cat sitting calmly, with soft fur, bright "
+    #     "eyes, and a cute, friendly expression. The background is "
+    #     "simple and light to emphasize the cat's texture and "
+    #     "fluffiness.",
+    #     "size": "1024x1024",
+    #     "status": "completed",
+    #     "type": "image_generation_call",
+    #     "result": # base64 encoded image data
+    # }
+
+    expected_keys = {
+        "id",
+        "background",
+        "output_format",
+        "quality",
+        "result",
+        "revised_prompt",
+        "size",
+        "status",
+        "type",
+    }
+
+    assert set(tool_output.keys()).issubset(expected_keys)
+
+    chat_history.extend(
+        [
+            # AI message with tool output
+            ai_message,
+            # New request
+            {
+                "role": "user",
+                "content": (
+                    "Now, change the font to blue. Keep the word and everything else "
+                    "the same."
+                ),
+            },
+        ]
+    )
+
+    ai_message2 = llm_with_tools.invoke(chat_history)
+    _check_response(ai_message2)
+    tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
+    assert set(tool_output2.keys()).issubset(expected_keys)
File diff suppressed because it is too large