openai: support web search and code interpreter content blocks

This commit is contained in:
Chester Curme 2025-07-22 16:58:43 -04:00
parent b1a02f971b
commit e1f034c795
3 changed files with 171 additions and 26 deletions

View File

@ -67,8 +67,8 @@ formats. The functions are used internally by ChatOpenAI.
""" # noqa: E501
import json
from collections.abc import Iterable
from typing import Any, Union, cast
from collections.abc import Iterable, Iterator
from typing import Any, Literal, Union, cast
from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
@ -391,7 +391,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
elif block_type == "image_generation_call" and (
result := block.get("result")
):
new_block = {"type": "image", "source_type": "base64", "data": result}
new_block = {"type": "image", "base64": result}
if output_format := block.get("output_format"):
new_block["mime_type"] = f"image/{output_format}"
for extra_key in (
@ -417,6 +417,68 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
new_block[extra_key] = block[extra_key]
yield new_block
elif block_type == "web_search_call":
web_search_call = {"type": "web_search_call", "id": block["id"]}
if "index" in block:
web_search_call["index"] = block["index"]
if (
"action" in block
and isinstance(block["action"], dict)
and block["action"].get("type") == "search"
and "query" in block["action"]
):
web_search_call["query"] = block["action"]["query"]
for key in block:
if key not in ("type", "id"):
web_search_call[key] = block[key]
web_search_result = {"type": "web_search_result", "id": block["id"]}
if "index" in block:
web_search_result["index"] = block["index"] + 1
yield web_search_call
yield web_search_result
elif block_type == "code_interpreter_call":
code_interpreter_call = {
"type": "code_interpreter_call",
"id": block["id"],
}
if "code" in block:
code_interpreter_call["code"] = block["code"]
if "container_id" in block:
code_interpreter_call["container_id"] = block["container_id"]
if "index" in block:
code_interpreter_call["index"] = block["index"]
code_interpreter_result = {
"type": "code_interpreter_result",
"id": block["id"],
}
if "outputs" in block:
code_interpreter_result["outputs"] = block["outputs"]
for output in block["outputs"]:
if (
isinstance(output, dict)
and (output_type := output.get("type"))
and output_type == "logs"
):
if "output" not in code_interpreter_result:
code_interpreter_result["output"] = []
code_interpreter_result["output"].append(
{
"type": "code_interpreter_output",
"stdout": output.get("logs", ""),
}
)
if "status" in block:
code_interpreter_result["status"] = block["status"]
if "index" in block:
code_interpreter_result["index"] = block["index"] + 1
yield code_interpreter_call
yield code_interpreter_result
else:
new_block = {"type": "non_standard", "value": block}
if "index" in new_block["value"]:
@ -496,6 +558,69 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str
yield merged
def _consolidate_calls(
items: Iterable[dict[str, Any]],
call_name: Literal["web_search_call", "code_interpreter_call"],
result_name: Literal["web_search_result", "code_interpreter_result"],
) -> Iterator[dict[str, Any]]:
"""
Generator that walks through *items* and, whenever it meets the pair
{"type": "web_search_call", "id": X, ...}
{"type": "web_search_result", "id": X}
merges them into
{"id": X,
"action": ,
"status": ,
"type": "web_search_call"}
keeping every other element untouched.
"""
items = iter(items) # make sure we have a true iterator
for current in items:
# Only a call can start a pair worth collapsing
if current.get("type") != call_name:
yield current
continue
try:
nxt = next(items) # look-ahead one element
except StopIteration: # no “result” just yield the call back
yield current
break
# If this really is the matching “result” collapse
if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
if call_name == "web_search_call":
collapsed = {
"id": current["id"],
"status": current["status"],
"type": "web_search_call",
}
if "action" in current:
collapsed["action"] = current["action"]
if call_name == "code_interpreter_call":
collapsed = {"id": current["id"]}
for key in ("code", "container_id"):
if key in current:
collapsed[key] = current[key]
for key in ("outputs", "status"):
if key in nxt:
collapsed[key] = nxt[key]
collapsed["type"] = "code_interpreter_call"
yield collapsed
else:
# Not a matching pair emit both, in original order
yield current
yield nxt
def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
if not isinstance(message.content, list):
return message
@ -530,9 +655,9 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
elif (
is_data_content_block(block)
and block["type"] == "image"
and block["source_type"] == "base64"
and "base64" in block
):
new_block = {"type": "image_generation_call", "result": block["data"]}
new_block = {"type": "image_generation_call", "result": block["base64"]}
for extra_key in ("id", "status"):
if extra_key in block:
new_block[extra_key] = block[extra_key]
@ -545,5 +670,13 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
new_content.append(block)
new_content = list(_implode_reasoning_blocks(new_content))
new_content = list(
_consolidate_calls(new_content, "web_search_call", "web_search_result")
)
new_content = list(
_consolidate_calls(
new_content, "code_interpreter_call", "code_interpreter_result"
)
)
return message.model_copy(update={"content": new_content})

View File

@ -3803,11 +3803,11 @@ def _construct_lc_result_from_responses_api(
)
if image_generation_call.output_format:
mime_type = f"image/{image_generation_call.output_format}"
for block in message.beta_content: # type: ignore[assignment]
for block in message.content:
# OK to mutate output message
if (
block.get("type") == "image"
and block["source_type"] == "base64"
and "base64" in block
and "mime_type" not in block
):
block["mime_type"] = mime_type
@ -4051,6 +4051,10 @@ def _convert_responses_chunk_to_generation_chunk(
)
elif output_version == "v1":
message = cast(AIMessageChunk, _convert_to_v1_from_responses(message))
for block in message.content:
if block.get("index", -1) > current_index:
# blocks were added for v1
current_index = block["index"]
else:
pass
return (

View File

@ -115,7 +115,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
if output_version == "responses/v1":
assert block_types == ["web_search_call", "text"]
else:
assert block_types == ["non_standard", "text"]
assert block_types == ["web_search_call", "web_search_result", "text"]
@pytest.mark.flaky(retries=3, delay=1)
@ -489,11 +489,17 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
else:
# v1
tool_outputs = [
item["value"]
for item in response.beta_content
if item["type"] == "non_standard"
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
assert tool_outputs[0]["type"] == "code_interpreter_call"
code_interpreter_result = next(
item
for item in response.content
if item["type"] == "code_interpreter_result"
)
assert tool_outputs
assert code_interpreter_result
assert len(tool_outputs) == 1
# Test streaming
@ -521,12 +527,16 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
else:
tool_outputs = [
item["value"]
for item in response.beta_content
if item["type"] == "non_standard"
]
assert tool_outputs[0]["type"] == "code_interpreter_call"
code_interpreter_call = next(
item for item in response.content if item["type"] == "code_interpreter_call"
)
code_interpreter_result = next(
item
for item in response.content
if item["type"] == "code_interpreter_result"
)
assert code_interpreter_call
assert code_interpreter_result
assert tool_outputs
# Test we can pass back in
@ -689,11 +699,9 @@ def test_image_generation_streaming(output_version: str) -> None:
assert set(tool_output.keys()).issubset(expected_keys)
else:
# v1
standard_keys = {"type", "source_type", "data", "id", "status", "index"}
standard_keys = {"type", "base64", "id", "status", "index"}
tool_output = next(
block
for block in complete_ai_message.beta_content
if block["type"] == "image"
block for block in complete_ai_message.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
@ -748,9 +756,9 @@ def test_image_generation_multi_turn(output_version: str) -> None:
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
standard_keys = {"type", "source_type", "data", "id", "status"}
standard_keys = {"type", "base64", "id", "status"}
tool_output = next(
block for block in ai_message.beta_content if block["type"] == "image"
block for block in ai_message.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
@ -800,8 +808,8 @@ def test_image_generation_multi_turn(output_version: str) -> None:
)
assert set(tool_output.keys()).issubset(expected_keys)
else:
standard_keys = {"type", "source_type", "data", "id", "status"}
standard_keys = {"type", "base64", "id", "status"}
tool_output = next(
block for block in ai_message2.beta_content if block["type"] == "image"
block for block in ai_message2.content if block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())