Add `phase` field propagation for the OpenAI Responses API (streaming and non-streaming)

This commit is contained in:
Nick Huang
2026-03-20 14:36:22 -04:00
parent 72571185a8
commit 6d1d2b0363
2 changed files with 101 additions and 21 deletions

View File

@@ -1163,6 +1163,7 @@ class BaseChatOpenAI(BaseChatModel):
current_output_index = -1
current_sub_index = -1
has_reasoning = False
item_phase_cache: dict[str, str] = {}
for chunk in response:
metadata = headers if is_first_chunk else {}
(
@@ -1179,6 +1180,7 @@ class BaseChatOpenAI(BaseChatModel):
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
item_phase_cache=item_phase_cache,
)
if generation_chunk:
if run_manager:
@@ -1218,6 +1220,7 @@ class BaseChatOpenAI(BaseChatModel):
current_output_index = -1
current_sub_index = -1
has_reasoning = False
item_phase_cache: dict[str, str] = {}
async for chunk in response:
metadata = headers if is_first_chunk else {}
(
@@ -1234,6 +1237,7 @@ class BaseChatOpenAI(BaseChatModel):
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
item_phase_cache=item_phase_cache,
)
if generation_chunk:
if run_manager:
@@ -4086,14 +4090,15 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
break
else:
# If no block with this ID, create a new one
input_.append(
{
"type": "message",
"content": [new_block],
"role": "assistant",
"id": msg_id,
}
)
new_item: dict = {
"type": "message",
"content": [new_block],
"role": "assistant",
"id": msg_id,
}
if phase := block.get("phase"):
new_item["phase"] = phase
input_.append(new_item)
elif block_type in (
"reasoning",
"web_search_call",
@@ -4242,6 +4247,7 @@ def _construct_lc_result_from_responses_api(
additional_kwargs: dict = {}
for output in response.output:
if output.type == "message":
phase = getattr(output, "phase", None)
for content in output.content:
if content.type == "output_text":
block = {
@@ -4255,13 +4261,20 @@ def _construct_lc_result_from_responses_api(
else [],
"id": output.id,
}
if phase is not None:
block["phase"] = phase
content_blocks.append(block)
if hasattr(content, "parsed"):
additional_kwargs["parsed"] = content.parsed
if content.type == "refusal":
content_blocks.append(
{"type": "refusal", "refusal": content.refusal, "id": output.id}
)
refusal_block = {
"type": "refusal",
"refusal": content.refusal,
"id": output.id,
}
if phase is not None:
refusal_block["phase"] = phase
content_blocks.append(refusal_block)
elif output.type == "function_call":
content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
try:
@@ -4368,6 +4381,7 @@ def _convert_responses_chunk_to_generation_chunk(
metadata: dict | None = None,
has_reasoning: bool = False,
output_version: str | None = None,
item_phase_cache: dict[str, str] | None = None,
) -> tuple[int, int, int, ChatGenerationChunk | None]:
def _advance(output_idx: int, sub_idx: int | None = None) -> None:
"""Advance indexes tracked during streaming.
@@ -4447,14 +4461,15 @@ def _convert_responses_chunk_to_generation_chunk(
)
elif chunk.type == "response.output_text.done":
_advance(chunk.output_index, chunk.content_index)
content.append(
{
"type": "text",
"text": "",
"id": chunk.item_id,
"index": current_index,
}
)
block = {
"type": "text",
"text": "",
"id": chunk.item_id,
"index": current_index,
}
if item_phase_cache and (phase := item_phase_cache.get(chunk.item_id)):
block["phase"] = phase
content.append(block)
elif chunk.type == "response.created":
id = chunk.response.id
response_metadata["id"] = chunk.response.id # Backwards compatibility
@@ -4479,8 +4494,10 @@ def _convert_responses_chunk_to_generation_chunk(
elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
if output_version == "v0":
id = chunk.item.id
else:
pass
if item_phase_cache is not None and (
phase := getattr(chunk.item, "phase", None)
):
item_phase_cache[chunk.item.id] = phase
elif (
chunk.type == "response.output_item.added"
and chunk.item.type == "function_call"

View File

@@ -3226,3 +3226,66 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None:
pass
except ValueError as e:
pytest.fail(f"This is a wrong behavior. Error details: {e}")
def test__construct_responses_api_input_preserves_phase() -> None:
    """Verify that each block's ``phase`` survives the message -> API-input conversion."""
    blocks: list[dict] = [
        {
            "type": "text",
            "text": "thinking...",
            "id": "msg_001",
            "phase": "commentary",
        },
        {
            "type": "text",
            "text": "final answer",
            "id": "msg_002",
            "phase": "final_answer",
        },
    ]
    payload = _construct_responses_api_input([AIMessage(content=blocks)])
    # Each input item must carry the phase from the originating content block.
    for item, expected_phase in zip(payload, ["commentary", "final_answer"]):
        assert item["phase"] == expected_phase
def test__construct_responses_api_input_no_phase_when_absent() -> None:
    """A ``phase`` key must not be invented for blocks that never carried one."""
    message = AIMessage(
        content=[{"type": "text", "text": "hello", "id": "msg_123"}]
    )
    items = _construct_responses_api_input([message])
    # The converted item should omit "phase" entirely, not set it to None.
    assert "phase" not in items[0]
def test__construct_lc_result_from_responses_api_captures_phase() -> None:
    """The output item's ``phase`` should be copied onto each resulting content block."""
    # Mock a single output_text part inside a "message" output item.
    text_part = MagicMock()
    text_part.type = "output_text"
    text_part.text = "thinking"
    text_part.annotations = []

    message_item = MagicMock()
    message_item.type = "message"
    message_item.phase = "commentary"
    message_item.id = "msg_001"
    message_item.content = [text_part]

    # Minimal Responses API response wrapper around that one item.
    api_response = MagicMock()
    api_response.error = None
    api_response.id = "resp_001"
    api_response.output = [message_item]
    api_response.usage = None
    api_response.model_dump.return_value = {}

    chat_result = _construct_lc_result_from_responses_api(api_response)
    ai_message = chat_result.generations[0].message
    assert isinstance(ai_message.content, list)
    first_block = cast(dict, ai_message.content[0])
    assert first_block["phase"] == "commentary"