mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 12:31:49 +00:00
openai: support web search and code interpreter content blocks
This commit is contained in:
parent
b1a02f971b
commit
e1f034c795
@ -67,8 +67,8 @@ formats. The functions are used internally by ChatOpenAI.
|
||||
""" # noqa: E501
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Union, cast
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Any, Literal, Union, cast
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
|
||||
|
||||
@ -391,7 +391,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
||||
elif block_type == "image_generation_call" and (
|
||||
result := block.get("result")
|
||||
):
|
||||
new_block = {"type": "image", "source_type": "base64", "data": result}
|
||||
new_block = {"type": "image", "base64": result}
|
||||
if output_format := block.get("output_format"):
|
||||
new_block["mime_type"] = f"image/{output_format}"
|
||||
for extra_key in (
|
||||
@ -417,6 +417,68 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
||||
new_block[extra_key] = block[extra_key]
|
||||
yield new_block
|
||||
|
||||
elif block_type == "web_search_call":
|
||||
web_search_call = {"type": "web_search_call", "id": block["id"]}
|
||||
if "index" in block:
|
||||
web_search_call["index"] = block["index"]
|
||||
if (
|
||||
"action" in block
|
||||
and isinstance(block["action"], dict)
|
||||
and block["action"].get("type") == "search"
|
||||
and "query" in block["action"]
|
||||
):
|
||||
web_search_call["query"] = block["action"]["query"]
|
||||
for key in block:
|
||||
if key not in ("type", "id"):
|
||||
web_search_call[key] = block[key]
|
||||
|
||||
web_search_result = {"type": "web_search_result", "id": block["id"]}
|
||||
if "index" in block:
|
||||
web_search_result["index"] = block["index"] + 1
|
||||
yield web_search_call
|
||||
yield web_search_result
|
||||
|
||||
elif block_type == "code_interpreter_call":
|
||||
code_interpreter_call = {
|
||||
"type": "code_interpreter_call",
|
||||
"id": block["id"],
|
||||
}
|
||||
if "code" in block:
|
||||
code_interpreter_call["code"] = block["code"]
|
||||
if "container_id" in block:
|
||||
code_interpreter_call["container_id"] = block["container_id"]
|
||||
if "index" in block:
|
||||
code_interpreter_call["index"] = block["index"]
|
||||
|
||||
code_interpreter_result = {
|
||||
"type": "code_interpreter_result",
|
||||
"id": block["id"],
|
||||
}
|
||||
if "outputs" in block:
|
||||
code_interpreter_result["outputs"] = block["outputs"]
|
||||
for output in block["outputs"]:
|
||||
if (
|
||||
isinstance(output, dict)
|
||||
and (output_type := output.get("type"))
|
||||
and output_type == "logs"
|
||||
):
|
||||
if "output" not in code_interpreter_result:
|
||||
code_interpreter_result["output"] = []
|
||||
code_interpreter_result["output"].append(
|
||||
{
|
||||
"type": "code_interpreter_output",
|
||||
"stdout": output.get("logs", ""),
|
||||
}
|
||||
)
|
||||
|
||||
if "status" in block:
|
||||
code_interpreter_result["status"] = block["status"]
|
||||
if "index" in block:
|
||||
code_interpreter_result["index"] = block["index"] + 1
|
||||
|
||||
yield code_interpreter_call
|
||||
yield code_interpreter_result
|
||||
|
||||
else:
|
||||
new_block = {"type": "non_standard", "value": block}
|
||||
if "index" in new_block["value"]:
|
||||
@ -496,6 +558,69 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str
|
||||
yield merged
|
||||
|
||||
|
||||
def _consolidate_calls(
|
||||
items: Iterable[dict[str, Any]],
|
||||
call_name: Literal["web_search_call", "code_interpreter_call"],
|
||||
result_name: Literal["web_search_result", "code_interpreter_result"],
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Generator that walks through *items* and, whenever it meets the pair
|
||||
|
||||
{"type": "web_search_call", "id": X, ...}
|
||||
{"type": "web_search_result", "id": X}
|
||||
|
||||
merges them into
|
||||
|
||||
{"id": X,
|
||||
"action": …,
|
||||
"status": …,
|
||||
"type": "web_search_call"}
|
||||
|
||||
keeping every other element untouched.
|
||||
"""
|
||||
items = iter(items) # make sure we have a true iterator
|
||||
for current in items:
|
||||
# Only a call can start a pair worth collapsing
|
||||
if current.get("type") != call_name:
|
||||
yield current
|
||||
continue
|
||||
|
||||
try:
|
||||
nxt = next(items) # look-ahead one element
|
||||
except StopIteration: # no “result” – just yield the call back
|
||||
yield current
|
||||
break
|
||||
|
||||
# If this really is the matching “result” – collapse
|
||||
if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
|
||||
if call_name == "web_search_call":
|
||||
collapsed = {
|
||||
"id": current["id"],
|
||||
"status": current["status"],
|
||||
"type": "web_search_call",
|
||||
}
|
||||
if "action" in current:
|
||||
collapsed["action"] = current["action"]
|
||||
|
||||
if call_name == "code_interpreter_call":
|
||||
collapsed = {"id": current["id"]}
|
||||
for key in ("code", "container_id"):
|
||||
if key in current:
|
||||
collapsed[key] = current[key]
|
||||
|
||||
for key in ("outputs", "status"):
|
||||
if key in nxt:
|
||||
collapsed[key] = nxt[key]
|
||||
collapsed["type"] = "code_interpreter_call"
|
||||
|
||||
yield collapsed
|
||||
|
||||
else:
|
||||
# Not a matching pair – emit both, in original order
|
||||
yield current
|
||||
yield nxt
|
||||
|
||||
|
||||
def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
@ -530,9 +655,9 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
|
||||
elif (
|
||||
is_data_content_block(block)
|
||||
and block["type"] == "image"
|
||||
and block["source_type"] == "base64"
|
||||
and "base64" in block
|
||||
):
|
||||
new_block = {"type": "image_generation_call", "result": block["data"]}
|
||||
new_block = {"type": "image_generation_call", "result": block["base64"]}
|
||||
for extra_key in ("id", "status"):
|
||||
if extra_key in block:
|
||||
new_block[extra_key] = block[extra_key]
|
||||
@ -545,5 +670,13 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
|
||||
new_content.append(block)
|
||||
|
||||
new_content = list(_implode_reasoning_blocks(new_content))
|
||||
new_content = list(
|
||||
_consolidate_calls(new_content, "web_search_call", "web_search_result")
|
||||
)
|
||||
new_content = list(
|
||||
_consolidate_calls(
|
||||
new_content, "code_interpreter_call", "code_interpreter_result"
|
||||
)
|
||||
)
|
||||
|
||||
return message.model_copy(update={"content": new_content})
|
||||
|
@ -3803,11 +3803,11 @@ def _construct_lc_result_from_responses_api(
|
||||
)
|
||||
if image_generation_call.output_format:
|
||||
mime_type = f"image/{image_generation_call.output_format}"
|
||||
for block in message.beta_content: # type: ignore[assignment]
|
||||
for block in message.content:
|
||||
# OK to mutate output message
|
||||
if (
|
||||
block.get("type") == "image"
|
||||
and block["source_type"] == "base64"
|
||||
and "base64" in block
|
||||
and "mime_type" not in block
|
||||
):
|
||||
block["mime_type"] = mime_type
|
||||
@ -4051,6 +4051,10 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
)
|
||||
elif output_version == "v1":
|
||||
message = cast(AIMessageChunk, _convert_to_v1_from_responses(message))
|
||||
for block in message.content:
|
||||
if block.get("index", -1) > current_index:
|
||||
# blocks were added for v1
|
||||
current_index = block["index"]
|
||||
else:
|
||||
pass
|
||||
return (
|
||||
|
@ -115,7 +115,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
|
||||
if output_version == "responses/v1":
|
||||
assert block_types == ["web_search_call", "text"]
|
||||
else:
|
||||
assert block_types == ["non_standard", "text"]
|
||||
assert block_types == ["web_search_call", "web_search_result", "text"]
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@ -489,11 +489,17 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
|
||||
else:
|
||||
# v1
|
||||
tool_outputs = [
|
||||
item["value"]
|
||||
for item in response.beta_content
|
||||
if item["type"] == "non_standard"
|
||||
item
|
||||
for item in response.content
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
|
||||
]
|
||||
assert tool_outputs[0]["type"] == "code_interpreter_call"
|
||||
code_interpreter_result = next(
|
||||
item
|
||||
for item in response.content
|
||||
if item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert tool_outputs
|
||||
assert code_interpreter_result
|
||||
assert len(tool_outputs) == 1
|
||||
|
||||
# Test streaming
|
||||
@ -521,12 +527,16 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
|
||||
]
|
||||
else:
|
||||
tool_outputs = [
|
||||
item["value"]
|
||||
for item in response.beta_content
|
||||
if item["type"] == "non_standard"
|
||||
]
|
||||
assert tool_outputs[0]["type"] == "code_interpreter_call"
|
||||
code_interpreter_call = next(
|
||||
item for item in response.content if item["type"] == "code_interpreter_call"
|
||||
)
|
||||
code_interpreter_result = next(
|
||||
item
|
||||
for item in response.content
|
||||
if item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert code_interpreter_call
|
||||
assert code_interpreter_result
|
||||
assert tool_outputs
|
||||
|
||||
# Test we can pass back in
|
||||
@ -689,11 +699,9 @@ def test_image_generation_streaming(output_version: str) -> None:
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
else:
|
||||
# v1
|
||||
standard_keys = {"type", "source_type", "data", "id", "status", "index"}
|
||||
standard_keys = {"type", "base64", "id", "status", "index"}
|
||||
tool_output = next(
|
||||
block
|
||||
for block in complete_ai_message.beta_content
|
||||
if block["type"] == "image"
|
||||
block for block in complete_ai_message.content if block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
||||
@ -748,9 +756,9 @@ def test_image_generation_multi_turn(output_version: str) -> None:
|
||||
)
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
else:
|
||||
standard_keys = {"type", "source_type", "data", "id", "status"}
|
||||
standard_keys = {"type", "base64", "id", "status"}
|
||||
tool_output = next(
|
||||
block for block in ai_message.beta_content if block["type"] == "image"
|
||||
block for block in ai_message.content if block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
||||
@ -800,8 +808,8 @@ def test_image_generation_multi_turn(output_version: str) -> None:
|
||||
)
|
||||
assert set(tool_output.keys()).issubset(expected_keys)
|
||||
else:
|
||||
standard_keys = {"type", "source_type", "data", "id", "status"}
|
||||
standard_keys = {"type", "base64", "id", "status"}
|
||||
tool_output = next(
|
||||
block for block in ai_message2.beta_content if block["type"] == "image"
|
||||
block for block in ai_message2.content if block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
Loading…
Reference in New Issue
Block a user