diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 6d41c83f331..89d9eaec98c 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -1163,6 +1163,7 @@ class BaseChatOpenAI(BaseChatModel): current_output_index = -1 current_sub_index = -1 has_reasoning = False + item_phase_cache: dict[str, str] = {} for chunk in response: metadata = headers if is_first_chunk else {} ( @@ -1179,6 +1180,7 @@ class BaseChatOpenAI(BaseChatModel): metadata=metadata, has_reasoning=has_reasoning, output_version=self.output_version, + item_phase_cache=item_phase_cache, ) if generation_chunk: if run_manager: @@ -1218,6 +1220,7 @@ class BaseChatOpenAI(BaseChatModel): current_output_index = -1 current_sub_index = -1 has_reasoning = False + item_phase_cache: dict[str, str] = {} async for chunk in response: metadata = headers if is_first_chunk else {} ( @@ -1234,6 +1237,7 @@ class BaseChatOpenAI(BaseChatModel): metadata=metadata, has_reasoning=has_reasoning, output_version=self.output_version, + item_phase_cache=item_phase_cache, ) if generation_chunk: if run_manager: @@ -4086,14 +4090,15 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list: break else: # If no block with this ID, create a new one - input_.append( - { - "type": "message", - "content": [new_block], - "role": "assistant", - "id": msg_id, - } - ) + new_item: dict = { + "type": "message", + "content": [new_block], + "role": "assistant", + "id": msg_id, + } + if phase := block.get("phase"): + new_item["phase"] = phase + input_.append(new_item) elif block_type in ( "reasoning", "web_search_call", @@ -4242,6 +4247,7 @@ def _construct_lc_result_from_responses_api( additional_kwargs: dict = {} for output in response.output: if output.type == "message": + phase = getattr(output, "phase", None) for content in output.content: if content.type == "output_text": block = { @@ -4255,13 +4261,20 @@ def _construct_lc_result_from_responses_api( else [], "id": output.id, } + if phase is not None: + block["phase"] = phase content_blocks.append(block) if hasattr(content, "parsed"): additional_kwargs["parsed"] = content.parsed if content.type == "refusal": - content_blocks.append( - {"type": "refusal", "refusal": content.refusal, "id": output.id} - ) + refusal_block = { + "type": "refusal", + "refusal": content.refusal, + "id": output.id, + } + if phase is not None: + refusal_block["phase"] = phase + content_blocks.append(refusal_block) elif output.type == "function_call": content_blocks.append(output.model_dump(exclude_none=True, mode="json")) try: @@ -4368,6 +4381,7 @@ def _convert_responses_chunk_to_generation_chunk( metadata: dict | None = None, has_reasoning: bool = False, output_version: str | None = None, + item_phase_cache: dict[str, str] | None = None, ) -> tuple[int, int, int, ChatGenerationChunk | None]: def _advance(output_idx: int, sub_idx: int | None = None) -> None: """Advance indexes tracked during streaming. @@ -4447,14 +4461,15 @@ def _convert_responses_chunk_to_generation_chunk( ) elif chunk.type == "response.output_text.done": _advance(chunk.output_index, chunk.content_index) - content.append( - { - "type": "text", - "text": "", - "id": chunk.item_id, - "index": current_index, - } - ) + block = { + "type": "text", + "text": "", + "id": chunk.item_id, + "index": current_index, + } + if item_phase_cache and (phase := item_phase_cache.get(chunk.item_id)): + block["phase"] = phase + content.append(block) elif chunk.type == "response.created": id = chunk.response.id response_metadata["id"] = chunk.response.id # Backwards compatibility @@ -4479,8 +4494,10 @@ def _convert_responses_chunk_to_generation_chunk( elif chunk.type == "response.output_item.added" and chunk.item.type == "message": if output_version == "v0": id = chunk.item.id - else: - pass + if item_phase_cache is not None and ( + phase := getattr(chunk.item, "phase", None) + ): + item_phase_cache[chunk.item.id] = phase elif ( chunk.type == "response.output_item.added" and chunk.item.type == "function_call" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 79fbc0222f8..1a275323570 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -3226,3 +3226,66 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None: pass except ValueError as e: pytest.fail(f"This is a wrong behavior. Error details: {e}") + + +def test__construct_responses_api_input_preserves_phase() -> None: + """Test that phase is preserved on assistant message items during roundtrip.""" + messages: list = [ + AIMessage( + content=[ + { + "type": "text", + "text": "thinking...", + "id": "msg_001", + "phase": "commentary", + }, + { + "type": "text", + "text": "final answer", + "id": "msg_002", + "phase": "final_answer", + }, + ] + ) + ] + result = _construct_responses_api_input(messages) + assert result[0]["phase"] == "commentary" + assert result[1]["phase"] == "final_answer" + + +def test__construct_responses_api_input_no_phase_when_absent() -> None: + """Test that phase is not added to assistant message items when not present.""" + messages: list = [ + AIMessage( + content=[ + {"type": "text", "text": "hello", "id": "msg_123"}, + ] + ) + ] + result = _construct_responses_api_input(messages) + assert "phase" not in result[0] + + +def test__construct_lc_result_from_responses_api_captures_phase() -> None: + """Test that phase from output message is stored on content blocks.""" + output_item = MagicMock() + output_item.type = "message" + output_item.phase = "commentary" + output_item.id = "msg_001" + content_block = MagicMock() + content_block.type = "output_text" + content_block.text = "thinking" + content_block.annotations = [] + output_item.content = [content_block] + response = MagicMock() + response.error = None + response.id = "resp_001" + response.output = [output_item] + response.usage = None + response.model_dump.return_value = {} + + result = _construct_lc_result_from_responses_api(response) + msg = result.generations[0].message + assert isinstance(msg.content, list) + block = cast(dict, msg.content[0]) + assert block["phase"] == "commentary"