diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py
index 55c210dba09..4d314936033 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_base.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py
@@ -113,6 +113,7 @@ def test_configurable() -> None:
         "openai_api_base": None,
         "openai_organization": None,
         "openai_proxy": None,
+        "output_version": "v0",
         "request_timeout": None,
         "max_retries": None,
         "presence_penalty": None,
diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py
index a0db5df9b4c..68fa2c310cb 100644
--- a/libs/partners/openai/langchain_openai/chat_models/_compat.py
+++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py
@@ -128,6 +128,8 @@ def _convert_to_v03_ai_message(
             else:
                 new_content.append(block)
         message.content = new_content
+        if isinstance(message.id, str) and message.id.startswith("resp_"):
+            message.id = None
     else:
         pass
 
@@ -137,13 +139,29 @@
 def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
     """Convert an old-style v0.3 AIMessage into the new content-block format."""
     # Only update ChatOpenAI v0.3 AIMessages
-    if not (
+    # TODO: structure provenance into AIMessage
+    is_chatopenai_v03 = (
         isinstance(message.content, list)
         and all(isinstance(b, dict) for b in message.content)
-    ) or not any(
-        item in message.additional_kwargs
-        for item in ["reasoning", "tool_outputs", "refusal", _FUNCTION_CALL_IDS_MAP_KEY]
-    ):
+    ) and (
+        any(
+            item in message.additional_kwargs
+            for item in [
+                "reasoning",
+                "tool_outputs",
+                "refusal",
+                _FUNCTION_CALL_IDS_MAP_KEY,
+            ]
+        )
+        or (
+            isinstance(message.id, str)
+            and message.id.startswith("msg_")
+            and (response_id := message.response_metadata.get("id"))
+            and isinstance(response_id, str)
+            and response_id.startswith("resp_")
+        )
+    )
+    if not is_chatopenai_v03:
         return message
     content_order = [
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index d2585b37fac..9b6f66aa831 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -649,6 +649,25 @@ class BaseChatOpenAI(BaseChatModel):
         .. versionadded:: 0.3.9
     """
 
+    output_version: Literal["v0", "responses/v1"] = "v0"
+    """Version of AIMessage output format to use.
+
+    This field is used to roll out new output formats for chat model AIMessages
+    in a backwards-compatible way.
+
+    Supported values:
+
+    - ``"v0"``: AIMessage format as of langchain-openai 0.3.x.
+    - ``"responses/v1"``: Formats Responses API output
+      items into AIMessage content blocks.
+
+    Currently only impacts the Responses API. ``output_version="responses/v1"`` is
+    recommended.
+
+    .. versionadded:: 0.3.25
+
+    """
+
     model_config = ConfigDict(populate_by_name=True)
 
     @model_validator(mode="before")
@@ -903,6 +922,7 @@ class BaseChatOpenAI(BaseChatModel):
                     schema=original_schema_obj,
                     metadata=metadata,
                     has_reasoning=has_reasoning,
+                    output_version=self.output_version,
                 )
                 if generation_chunk:
                     if run_manager:
@@ -957,6 +977,7 @@ class BaseChatOpenAI(BaseChatModel):
                     schema=original_schema_obj,
                     metadata=metadata,
                     has_reasoning=has_reasoning,
+                    output_version=self.output_version,
                 )
                 if generation_chunk:
                     if run_manager:
@@ -1096,7 +1117,10 @@ class BaseChatOpenAI(BaseChatModel):
             else:
                 response = self.root_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(
-                response, schema=original_schema_obj, metadata=generation_info
+                response,
+                schema=original_schema_obj,
+                metadata=generation_info,
+                output_version=self.output_version,
             )
         elif self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
@@ -1109,6 +1133,8 @@ class BaseChatOpenAI(BaseChatModel):
     def _use_responses_api(self, payload: dict) -> bool:
         if isinstance(self.use_responses_api, bool):
             return self.use_responses_api
+        elif self.output_version == "responses/v1":
+            return True
         elif self.include is not None:
             return True
         elif self.reasoning is not None:
@@ -1327,7 +1353,10 @@ class BaseChatOpenAI(BaseChatModel):
             else:
                 response = await self.root_async_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(
-                response, schema=original_schema_obj, metadata=generation_info
+                response,
+                schema=original_schema_obj,
+                metadata=generation_info,
+                output_version=self.output_version,
             )
         elif self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
@@ -3540,6 +3569,7 @@ def _construct_lc_result_from_responses_api(
     response: Response,
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
+    output_version: Literal["v0", "responses/v1"] = "v0",
 ) -> ChatResult:
     """Construct ChatResponse from OpenAI Response API response."""
     if response.error:
@@ -3676,7 +3706,10 @@ def _construct_lc_result_from_responses_api(
         tool_calls=tool_calls,
         invalid_tool_calls=invalid_tool_calls,
     )
-    message = _convert_to_v03_ai_message(message)
+    if output_version == "v0":
+        message = _convert_to_v03_ai_message(message)
+    else:
+        pass
     return ChatResult(generations=[ChatGeneration(message=message)])
 
 
@@ -3688,6 +3721,7 @@ def _convert_responses_chunk_to_generation_chunk(
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
     has_reasoning: bool = False,
+    output_version: Literal["v0", "responses/v1"] = "v0",
 ) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
     def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
         """Advance indexes tracked during streaming.
@@ -3756,12 +3790,15 @@ def _convert_responses_chunk_to_generation_chunk(
     elif chunk.type == "response.output_text.done":
         content.append({"id": chunk.item_id, "index": current_index})
     elif chunk.type == "response.created":
-        response_metadata["id"] = chunk.response.id
+        id = chunk.response.id
+        response_metadata["id"] = chunk.response.id  # Backwards compatibility
     elif chunk.type == "response.completed":
         msg = cast(
             AIMessage,
             (
-                _construct_lc_result_from_responses_api(chunk.response, schema=schema)
+                _construct_lc_result_from_responses_api(
+                    chunk.response, schema=schema, output_version=output_version
+                )
                 .generations[0]
                 .message
             ),
@@ -3773,7 +3810,10 @@ def _convert_responses_chunk_to_generation_chunk(
             k: v for k, v in msg.response_metadata.items() if k != "id"
         }
     elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
-        id = chunk.item.id
+        if output_version == "v0":
+            id = chunk.item.id
+        else:
+            pass
    elif (
         chunk.type == "response.output_item.added"
         and chunk.item.type == "function_call"
@@ -3868,9 +3908,13 @@ def _convert_responses_chunk_to_generation_chunk(
         additional_kwargs=additional_kwargs,
         id=id,
     )
-    message = cast(
-        AIMessageChunk, _convert_to_v03_ai_message(message, has_reasoning=has_reasoning)
-    )
+    if output_version == "v0":
+        message = cast(
+            AIMessageChunk,
+            _convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
+        )
+    else:
+        pass
     return (
         current_index,
         current_output_index,
diff --git a/libs/partners/openai/tests/cassettes/test_mcp_builtin_zdr.yaml.gz b/libs/partners/openai/tests/cassettes/test_mcp_builtin_zdr.yaml.gz
new file mode 100644
index 00000000000..f6cf24dee10
Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_mcp_builtin_zdr.yaml.gz differ
diff --git a/libs/partners/openai/tests/cassettes/test_reasoning.yaml.gz b/libs/partners/openai/tests/cassettes/test_reasoning.yaml.gz
new file mode 100644
index 00000000000..d598966c99a
Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_reasoning.yaml.gz differ
diff --git a/libs/partners/openai/tests/cassettes/test_stream_reasoning_summary.yaml.gz b/libs/partners/openai/tests/cassettes/test_stream_reasoning_summary.yaml.gz
new file mode 100644
index 00000000000..ac89d7580cd
Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_stream_reasoning_summary.yaml.gz differ
diff --git a/libs/partners/openai/tests/cassettes/test_web_search.yaml.gz b/libs/partners/openai/tests/cassettes/test_web_search.yaml.gz
index d63dc1f1668..e99f1c2e13a 100644
Binary files a/libs/partners/openai/tests/cassettes/test_web_search.yaml.gz and b/libs/partners/openai/tests/cassettes/test_web_search.yaml.gz differ
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
index 622c83f0f7f..0e23d0e3f06 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
@@ -2,7 +2,7 @@
 import json
 import os
 
-from typing import Annotated, Any, Optional, cast
+from typing import Annotated, Any, Literal, Optional, cast
 
 import openai
 import pytest
@@ -50,15 +50,11 @@ def _check_response(response: Optional[BaseMessage]) -> None:
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
     assert response.response_metadata["service_tier"]
-    for tool_output in response.additional_kwargs["tool_outputs"]:
-        assert tool_output["id"]
-        assert tool_output["status"]
-        assert tool_output["type"]
 
 
 @pytest.mark.vcr
 def test_web_search() -> None:
-    llm = ChatOpenAI(model=MODEL_NAME)
+    llm = ChatOpenAI(model=MODEL_NAME, output_version="responses/v1")
     first_response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -111,6 +107,11 @@ def test_web_search() -> None:
     )
     _check_response(response)
 
+    for msg in [first_response, full, response]:
+        assert isinstance(msg, AIMessage)
+        block_types = [block["type"] for block in msg.content]  # type: ignore[index]
+        assert block_types == ["web_search_call", "text"]
+
 
 @pytest.mark.flaky(retries=3, delay=1)
 async def test_web_search_async() -> None:
@@ -133,6 +134,12 @@ async def test_web_search_async() -> None:
     assert isinstance(full, AIMessageChunk)
     _check_response(full)
 
+    for msg in [response, full]:
+        assert msg.additional_kwargs["tool_outputs"]
+        assert len(msg.additional_kwargs["tool_outputs"]) == 1
+        tool_output = msg.additional_kwargs["tool_outputs"][0]
+        assert tool_output["type"] == "web_search_call"
+
 
 @pytest.mark.flaky(retries=3, delay=1)
 def test_function_calling() -> None:
@@ -288,20 +295,32 @@ def test_function_calling_and_structured_output() -> None:
     assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
 
 
-def test_reasoning() -> None:
-    llm = ChatOpenAI(model="o3-mini", use_responses_api=True)
+@pytest.mark.default_cassette("test_reasoning.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
+    llm = ChatOpenAI(
+        model="o4-mini", use_responses_api=True, output_version=output_version
+    )
     response = llm.invoke("Hello", reasoning={"effort": "low"})
     assert isinstance(response, AIMessage)
-    assert response.additional_kwargs["reasoning"]
 
     # Test init params + streaming
-    llm = ChatOpenAI(model="o3-mini", reasoning_effort="low", use_responses_api=True)
+    llm = ChatOpenAI(
+        model="o4-mini", reasoning={"effort": "low"}, output_version=output_version
+    )
     full: Optional[BaseMessageChunk] = None
     for chunk in llm.stream("Hello"):
         assert isinstance(chunk, AIMessageChunk)
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessage)
-    assert full.additional_kwargs["reasoning"]
+
+    for msg in [response, full]:
+        if output_version == "v0":
+            assert msg.additional_kwargs["reasoning"]
+        else:
+            block_types = [block["type"] for block in msg.content]
+            assert block_types == ["reasoning", "text"]
 
 
 def test_stateful_api() -> None:
@@ -355,20 +374,37 @@ def test_file_search() -> None:
     _check_response(full)
 
 
-def test_stream_reasoning_summary() -> None:
+@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_stream_reasoning_summary(
+    output_version: Literal["v0", "responses/v1"],
+) -> None:
     llm = ChatOpenAI(
         model="o4-mini",
         # Routes to Responses API if `reasoning` is set.
reasoning={"effort": "medium", "summary": "auto"}, + output_version=output_version, ) - message_1 = {"role": "user", "content": "What is 3^3?"} + message_1 = { + "role": "user", + "content": "What was the third tallest buliding in the year 2000?", + } response_1: Optional[BaseMessageChunk] = None for chunk in llm.stream([message_1]): assert isinstance(chunk, AIMessageChunk) response_1 = chunk if response_1 is None else response_1 + chunk assert isinstance(response_1, AIMessageChunk) - reasoning = response_1.additional_kwargs["reasoning"] - assert set(reasoning.keys()) == {"id", "type", "summary"} + if output_version == "v0": + reasoning = response_1.additional_kwargs["reasoning"] + assert set(reasoning.keys()) == {"id", "type", "summary"} + else: + reasoning = next( + block + for block in response_1.content + if block["type"] == "reasoning" # type: ignore[index] + ) + assert set(reasoning.keys()) == {"id", "type", "summary", "index"} summary = reasoning["summary"] assert isinstance(summary, list) for block in summary: @@ -462,11 +498,11 @@ def test_mcp_builtin() -> None: ) -@pytest.mark.skip +@pytest.mark.vcr def test_mcp_builtin_zdr() -> None: llm = ChatOpenAI( model="o4-mini", - use_responses_api=True, + output_version="responses/v1", store=False, include=["reasoning.encrypted_content"], ) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr index 2060512958a..ddadd6fc09b 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr +++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr @@ -24,6 +24,7 @@ }), 'openai_api_type': 'azure', 'openai_api_version': '2021-10-01', + 'output_version': 'v0', 'request_timeout': 60.0, 'stop': list([ ]), diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr index e7307c6158f..1a74f4978a7 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr +++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr @@ -18,6 +18,7 @@ 'lc': 1, 'type': 'secret', }), + 'output_version': 'v0', 'request_timeout': 60.0, 'stop': list([ ]), diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr index 88a49a27502..10d1355af4e 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr +++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr @@ -18,6 +18,7 @@ 'lc': 1, 'type': 'secret', }), + 'output_version': 'v0', 'request_timeout': 60.0, 'stop': list([ ]), diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 37eaf500e74..3bd822187c4 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -1192,6 +1192,7 @@ def test__construct_lc_result_from_responses_api_basic_text_response() -> None: ), ) + # v0 result = _construct_lc_result_from_responses_api(response) assert isinstance(result, ChatResult) @@ -1209,6 +1210,16 @@ def 
test__construct_lc_result_from_responses_api_basic_text_response() -> None: assert result.generations[0].message.response_metadata["id"] == "resp_123" assert result.generations[0].message.response_metadata["model_name"] == "gpt-4o" + # responses/v1 + result = _construct_lc_result_from_responses_api( + response, output_version="responses/v1" + ) + assert result.generations[0].message.content == [ + {"type": "text", "text": "Hello, world!", "annotations": [], "id": "msg_123"} + ] + assert result.generations[0].message.id == "resp_123" + assert result.generations[0].message.response_metadata["id"] == "resp_123" + def test__construct_lc_result_from_responses_api_multiple_text_blocks() -> None: """Test a response with multiple text blocks.""" @@ -1284,6 +1295,7 @@ def test__construct_lc_result_from_responses_api_multiple_messages() -> None: ], ) + # v0 result = _construct_lc_result_from_responses_api(response) assert result.generations[0].message.content == [ @@ -1297,6 +1309,23 @@ def test__construct_lc_result_from_responses_api_multiple_messages() -> None: "id": "rs_123", } } + assert result.generations[0].message.id == "msg_234" + + # responses/v1 + result = _construct_lc_result_from_responses_api( + response, output_version="responses/v1" + ) + + assert result.generations[0].message.content == [ + {"type": "text", "text": "foo", "annotations": [], "id": "msg_123"}, + { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "reasoning foo"}], + "id": "rs_123", + }, + {"type": "text", "text": "bar", "annotations": [], "id": "msg_234"}, + ] + assert result.generations[0].message.id == "resp_123" def test__construct_lc_result_from_responses_api_refusal_response() -> None: @@ -1324,12 +1353,25 @@ def test__construct_lc_result_from_responses_api_refusal_response() -> None: ], ) + # v0 result = _construct_lc_result_from_responses_api(response) assert result.generations[0].message.additional_kwargs["refusal"] == ( "I cannot assist with that request." 
     )
 
+    # responses/v1
+    result = _construct_lc_result_from_responses_api(
+        response, output_version="responses/v1"
+    )
+    assert result.generations[0].message.content == [
+        {
+            "type": "refusal",
+            "refusal": "I cannot assist with that request.",
+            "id": "msg_123",
+        }
+    ]
+
 
 def test__construct_lc_result_from_responses_api_function_call_valid_json() -> None:
     """Test a response with a valid function call."""
@@ -1352,6 +1394,7 @@ def test__construct_lc_result_from_responses_api_function_call_valid_json() -> N
         ],
     )
 
+    # v0
     result = _construct_lc_result_from_responses_api(response)
 
     msg: AIMessage = cast(AIMessage, result.generations[0].message)
@@ -1368,6 +1411,22 @@ def test__construct_lc_result_from_responses_api_function_call_valid_json() -> N
         == "func_123"
     )
 
+    # responses/v1
+    result = _construct_lc_result_from_responses_api(
+        response, output_version="responses/v1"
+    )
+    msg = cast(AIMessage, result.generations[0].message)
+    assert msg.tool_calls
+    assert msg.content == [
+        {
+            "type": "function_call",
+            "id": "func_123",
+            "name": "get_weather",
+            "arguments": '{"location": "New York", "unit": "celsius"}',
+            "call_id": "call_123",
+        }
+    ]
+
 
 def test__construct_lc_result_from_responses_api_function_call_invalid_json() -> None:
     """Test a response with an invalid JSON function call."""
@@ -1444,6 +1503,7 @@ def test__construct_lc_result_from_responses_api_complex_response() -> None:
         user="user_123",
     )
 
+    # v0
     result = _construct_lc_result_from_responses_api(response)
 
     # Check message content
@@ -1472,6 +1532,28 @@ def test__construct_lc_result_from_responses_api_complex_response() -> None:
     assert result.generations[0].message.response_metadata["status"] == "completed"
     assert result.generations[0].message.response_metadata["user"] == "user_123"
 
+    # responses/v1
+    result = _construct_lc_result_from_responses_api(
+        response, output_version="responses/v1"
+    )
+    msg = cast(AIMessage, result.generations[0].message)
+    assert msg.response_metadata["metadata"] == {"key1": "value1", "key2": "value2"}
+    assert msg.content == [
+        {
+            "type": "text",
+            "text": "Here's the information you requested:",
+            "annotations": [],
+            "id": "msg_123",
+        },
+        {
+            "type": "function_call",
+            "id": "func_123",
+            "call_id": "call_123",
+            "name": "get_weather",
+            "arguments": '{"location": "New York"}',
+        },
+    ]
+
 
 def test__construct_lc_result_from_responses_api_no_usage_metadata() -> None:
     """Test a response without usage metadata."""
@@ -1525,6 +1607,7 @@ def test__construct_lc_result_from_responses_api_web_search_response() -> None:
         ],
     )
 
+    # v0
     result = _construct_lc_result_from_responses_api(response)
 
     assert "tool_outputs" in result.generations[0].message.additional_kwargs
@@ -1542,6 +1625,14 @@ def test__construct_lc_result_from_responses_api_web_search_response() -> None:
         == "completed"
     )
 
+    # responses/v1
+    result = _construct_lc_result_from_responses_api(
+        response, output_version="responses/v1"
+    )
+    assert result.generations[0].message.content == [
+        {"type": "web_search_call", "id": "websearch_123", "status": "completed"}
+    ]
+
 
 def test__construct_lc_result_from_responses_api_file_search_response() -> None:
     """Test a response with file search output."""
@@ -1572,6 +1663,7 @@ def test__construct_lc_result_from_responses_api_file_search_response() -> None:
         ],
     )
 
+    # v0
     result = _construct_lc_result_from_responses_api(response)
 
     assert "tool_outputs" in result.generations[0].message.additional_kwargs
@@ -1612,6 +1704,28 @@ def test__construct_lc_result_from_responses_api_file_search_response() -> None:
         == 0.95
     )
 
+    # responses/v1
+    result = _construct_lc_result_from_responses_api(
+        response, output_version="responses/v1"
+    )
+    assert result.generations[0].message.content == [
+        {
+            "type": "file_search_call",
+            "id": "filesearch_123",
+            "status": "completed",
+            "queries": ["python code", "langchain"],
+            "results": [
+                {
+                    "file_id": "file_123",
+                    "filename": "example.py",
+                    "score": 0.95,
+                    "text": "def hello_world() -> None:\n    print('Hello, world!')",
+                    "attributes": {"language": "python", "size": 42},
+                }
+            ],
+        }
+    ]
+
 
 def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
     """Test a response with both web search and file search outputs."""
@@ -1656,6 +1770,7 @@ def test__construct_lc_result_from_responses_api_mixed_search_responses() -> Non
         ],
     )
 
+    # v0
     result = _construct_lc_result_from_responses_api(response)
 
     # Check message content
@@ -1686,6 +1801,34 @@ def test__construct_lc_result_from_responses_api_mixed_search_responses() -> Non
     assert file_search["queries"] == ["python code"]
     assert file_search["results"][0]["filename"] == "example.py"
 
+    # responses/v1
+    result = _construct_lc_result_from_responses_api(
+        response, output_version="responses/v1"
+    )
+    assert result.generations[0].message.content == [
+        {
+            "type": "text",
+            "text": "Here's what I found:",
+            "annotations": [],
+            "id": "msg_123",
+        },
+        {"type": "web_search_call", "id": "websearch_123", "status": "completed"},
+        {
+            "type": "file_search_call",
+            "id": "filesearch_123",
+            "queries": ["python code"],
+            "results": [
+                {
+                    "file_id": "file_123",
+                    "filename": "example.py",
+                    "score": 0.95,
+                    "text": "def hello_world() -> None:\n    print('Hello, world!')",
+                }
+            ],
+            "status": "completed",
+        },
+    ]
+
 
 def test__construct_responses_api_input_human_message_with_text_blocks_conversion() -> (
     None
@@ -1706,7 +1849,29 @@ def test__construct_responses_api_input_multiple_message_components() -> None:
     """Test that human messages with text blocks are properly converted."""
 
-    messages: list = [
+    # v0
+    messages = [
+        AIMessage(
+            content=[{"type": "text", "text": "foo"}, {"type": "text", "text": "bar"}],
+            id="msg_123",
+            response_metadata={"id": "resp_123"},
+        )
+    ]
+    result = _construct_responses_api_input(messages)
+    assert result == [
+        {
+            "type": "message",
+            "role": "assistant",
+            "content": [
+                {"type": "output_text", "text": "foo", "annotations": []},
+                {"type": "output_text", "text": "bar", "annotations": []},
+            ],
+            "id": "msg_123",
+        }
+    ]
+
+    # responses/v1
+    messages = [
         AIMessage(
             content=[
                 {"type": "text", "text": "foo", "id": "msg_123"},
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py
index 6e91ff9a3fd..370adcd1f1a 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py
@@ -1,7 +1,6 @@
 from typing import Any, Optional
 from unittest.mock import MagicMock, patch
 
-import pytest
 from langchain_core.messages import AIMessageChunk, BaseMessageChunk
 from openai.types.responses import (
     ResponseCompletedEvent,
@@ -601,9 +600,18 @@ responses_stream = [
 ]
 
 
-@pytest.mark.xfail(reason="Will be fixed with output format flags.")
+def _strip_none(obj: Any) -> Any:
+    """Recursively strip None values from dictionaries and lists."""
+    if isinstance(obj, dict):
+        return {k: _strip_none(v) for k, v in obj.items() if v is not None}
+    elif isinstance(obj, list):
+        return [_strip_none(v) for v in obj]
+    else:
+        return obj
+
+
 def test_responses_stream() -> None:
-    llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
+    llm = ChatOpenAI(model="o4-mini", output_version="responses/v1")
     mock_client = MagicMock()
 
     def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
@@ -644,3 +652,20 @@ def test_responses_stream() -> None:
     ]
     assert full.content == expected_content
     assert full.additional_kwargs == {}
+    assert full.id == "resp_123"
+
+    # Test reconstruction
+    payload = llm._get_request_payload([full])
+    completed = [
+        item
+        for item in responses_stream
+        if item.type == "response.completed"  # type: ignore[attr-defined]
+    ]
+    assert len(completed) == 1
+    response = completed[0].response  # type: ignore[attr-defined]
+
+    assert len(response.output) == len(payload["input"])
+    for idx, item in enumerate(response.output):
+        dumped = _strip_none(item.model_dump())
+        _ = dumped.pop("status", None)
+        assert dumped == payload["input"][idx]
diff --git a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
index 4cd1261555c..e61b99508aa 100644
--- a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
+++ b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
@@ -10,6 +10,7 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'grok-beta',
+    'output_version': 'v0',
     'request_timeout': 60.0,
     'stop': list([
     ]),
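Usage sketch: a minimal illustration of the new output_version flag, distilled from the tests in this patch. The model name and prompt are illustrative only, not prescribed by the change.

    from langchain_openai import ChatOpenAI

    # Default "v0" keeps the langchain-openai 0.3.x message shape: reasoning
    # and built-in tool outputs surface via AIMessage.additional_kwargs.
    llm_v0 = ChatOpenAI(model="o4-mini", use_responses_api=True)

    # "responses/v1" keeps Responses API output items as AIMessage content
    # blocks; setting it also routes requests to the Responses API
    # (see _use_responses_api above).
    llm_v1 = ChatOpenAI(model="o4-mini", output_version="responses/v1")

    response = llm_v1.invoke("Hello", reasoning={"effort": "low"})
    # Blocks arrive in provider order, e.g. a reasoning block then a text
    # block, mirroring the assertion in test_reasoning:
    print([block["type"] for block in response.content])  # ["reasoning", "text"]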