diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index e6dff6336b0..47d73106069 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -67,8 +67,8 @@ formats. The functions are used internally by ChatOpenAI. """ # noqa: E501 import json -from collections.abc import Iterable -from typing import Any, Union, cast +from collections.abc import Iterable, Iterator +from typing import Any, Literal, Union, cast from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block @@ -391,7 +391,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: elif block_type == "image_generation_call" and ( result := block.get("result") ): - new_block = {"type": "image", "source_type": "base64", "data": result} + new_block = {"type": "image", "base64": result} if output_format := block.get("output_format"): new_block["mime_type"] = f"image/{output_format}" for extra_key in ( @@ -417,6 +417,68 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: new_block[extra_key] = block[extra_key] yield new_block + elif block_type == "web_search_call": + web_search_call = {"type": "web_search_call", "id": block["id"]} + if "index" in block: + web_search_call["index"] = block["index"] + if ( + "action" in block + and isinstance(block["action"], dict) + and block["action"].get("type") == "search" + and "query" in block["action"] + ): + web_search_call["query"] = block["action"]["query"] + for key in block: + if key not in ("type", "id"): + web_search_call[key] = block[key] + + web_search_result = {"type": "web_search_result", "id": block["id"]} + if "index" in block: + web_search_result["index"] = block["index"] + 1 + yield web_search_call + yield web_search_result + + elif block_type == "code_interpreter_call": + code_interpreter_call = { + "type": "code_interpreter_call", + "id": block["id"], + } + if "code" in block: + code_interpreter_call["code"] = block["code"] + if "container_id" in block: + code_interpreter_call["container_id"] = block["container_id"] + if "index" in block: + code_interpreter_call["index"] = block["index"] + + code_interpreter_result = { + "type": "code_interpreter_result", + "id": block["id"], + } + if "outputs" in block: + code_interpreter_result["outputs"] = block["outputs"] + for output in block["outputs"]: + if ( + isinstance(output, dict) + and (output_type := output.get("type")) + and output_type == "logs" + ): + if "output" not in code_interpreter_result: + code_interpreter_result["output"] = [] + code_interpreter_result["output"].append( + { + "type": "code_interpreter_output", + "stdout": output.get("logs", ""), + } + ) + + if "status" in block: + code_interpreter_result["status"] = block["status"] + if "index" in block: + code_interpreter_result["index"] = block["index"] + 1 + + yield code_interpreter_call + yield code_interpreter_result + else: new_block = {"type": "non_standard", "value": block} if "index" in new_block["value"]: @@ -496,6 +558,69 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str yield merged +def _consolidate_calls( + items: Iterable[dict[str, Any]], + call_name: Literal["web_search_call", "code_interpreter_call"], + result_name: Literal["web_search_result", "code_interpreter_result"], +) -> Iterator[dict[str, Any]]: + """ + Generator that walks through *items* and, whenever it meets the pair + + {"type": "web_search_call", "id": X, ...} + {"type": "web_search_result", "id": X} + + merges them into + + {"id": X, + "action": …, + "status": …, + "type": "web_search_call"} + + keeping every other element untouched. + """ + items = iter(items) # make sure we have a true iterator + for current in items: + # Only a call can start a pair worth collapsing + if current.get("type") != call_name: + yield current + continue + + try: + nxt = next(items) # look-ahead one element + except StopIteration: # no “result” – just yield the call back + yield current + break + + # If this really is the matching “result” – collapse + if nxt.get("type") == result_name and nxt.get("id") == current.get("id"): + if call_name == "web_search_call": + collapsed = { + "id": current["id"], + "status": current["status"], + "type": "web_search_call", + } + if "action" in current: + collapsed["action"] = current["action"] + + if call_name == "code_interpreter_call": + collapsed = {"id": current["id"]} + for key in ("code", "container_id"): + if key in current: + collapsed[key] = current[key] + + for key in ("outputs", "status"): + if key in nxt: + collapsed[key] = nxt[key] + collapsed["type"] = "code_interpreter_call" + + yield collapsed + + else: + # Not a matching pair – emit both, in original order + yield current + yield nxt + + def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: if not isinstance(message.content, list): return message @@ -530,9 +655,9 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: elif ( is_data_content_block(block) and block["type"] == "image" - and block["source_type"] == "base64" + and "base64" in block ): - new_block = {"type": "image_generation_call", "result": block["data"]} + new_block = {"type": "image_generation_call", "result": block["base64"]} for extra_key in ("id", "status"): if extra_key in block: new_block[extra_key] = block[extra_key] @@ -545,5 +670,13 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: new_content.append(block) new_content = list(_implode_reasoning_blocks(new_content)) + new_content = list( + _consolidate_calls(new_content, "web_search_call", "web_search_result") + ) + new_content = list( + _consolidate_calls( + new_content, "code_interpreter_call", "code_interpreter_result" + ) + ) return message.model_copy(update={"content": new_content}) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 39a2a6b2d91..ab094c1e943 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -3803,11 +3803,11 @@ def _construct_lc_result_from_responses_api( ) if image_generation_call.output_format: mime_type = f"image/{image_generation_call.output_format}" - for block in message.beta_content: # type: ignore[assignment] + for block in message.content: # OK to mutate output message if ( block.get("type") == "image" - and block["source_type"] == "base64" + and "base64" in block and "mime_type" not in block ): block["mime_type"] = mime_type @@ -4051,6 +4051,10 @@ def _convert_responses_chunk_to_generation_chunk( ) elif output_version == "v1": message = cast(AIMessageChunk, _convert_to_v1_from_responses(message)) + for block in message.content: + if block.get("index", -1) > current_index: + # blocks were added for v1 + current_index = block["index"] else: pass return ( diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index a20c7a46113..525b8b292cc 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -115,7 +115,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None: if output_version == "responses/v1": assert block_types == ["web_search_call", "text"] else: - assert block_types == ["non_standard", "text"] + assert block_types == ["web_search_call", "web_search_result", "text"] @pytest.mark.flaky(retries=3, delay=1) @@ -489,11 +489,17 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - else: # v1 tool_outputs = [ - item["value"] - for item in response.beta_content - if item["type"] == "non_standard" + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" ] - assert tool_outputs[0]["type"] == "code_interpreter_call" + code_interpreter_result = next( + item + for item in response.content + if item["type"] == "code_interpreter_result" + ) + assert tool_outputs + assert code_interpreter_result assert len(tool_outputs) == 1 # Test streaming @@ -521,12 +527,16 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - if isinstance(item, dict) and item["type"] == "code_interpreter_call" ] else: - tool_outputs = [ - item["value"] - for item in response.beta_content - if item["type"] == "non_standard" - ] - assert tool_outputs[0]["type"] == "code_interpreter_call" + code_interpreter_call = next( + item for item in response.content if item["type"] == "code_interpreter_call" + ) + code_interpreter_result = next( + item + for item in response.content + if item["type"] == "code_interpreter_result" + ) + assert code_interpreter_call + assert code_interpreter_result assert tool_outputs # Test we can pass back in @@ -689,11 +699,9 @@ def test_image_generation_streaming(output_version: str) -> None: assert set(tool_output.keys()).issubset(expected_keys) else: # v1 - standard_keys = {"type", "source_type", "data", "id", "status", "index"} + standard_keys = {"type", "base64", "id", "status", "index"} tool_output = next( - block - for block in complete_ai_message.beta_content - if block["type"] == "image" + block for block in complete_ai_message.content if block["type"] == "image" ) assert set(standard_keys).issubset(tool_output.keys()) @@ -748,9 +756,9 @@ def test_image_generation_multi_turn(output_version: str) -> None: ) assert set(tool_output.keys()).issubset(expected_keys) else: - standard_keys = {"type", "source_type", "data", "id", "status"} + standard_keys = {"type", "base64", "id", "status"} tool_output = next( - block for block in ai_message.beta_content if block["type"] == "image" + block for block in ai_message.content if block["type"] == "image" ) assert set(standard_keys).issubset(tool_output.keys()) @@ -800,8 +808,8 @@ def test_image_generation_multi_turn(output_version: str) -> None: ) assert set(tool_output.keys()).issubset(expected_keys) else: - standard_keys = {"type", "source_type", "data", "id", "status"} + standard_keys = {"type", "base64", "id", "status"} tool_output = next( - block for block in ai_message2.beta_content if block["type"] == "image" + block for block in ai_message2.content if block["type"] == "image" ) assert set(standard_keys).issubset(tool_output.keys())