mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-12 23:42:51 +00:00
add phase for openai
This commit is contained in:
@@ -1163,6 +1163,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
current_output_index = -1
|
||||
current_sub_index = -1
|
||||
has_reasoning = False
|
||||
item_phase_cache: dict[str, str] = {}
|
||||
for chunk in response:
|
||||
metadata = headers if is_first_chunk else {}
|
||||
(
|
||||
@@ -1179,6 +1180,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
metadata=metadata,
|
||||
has_reasoning=has_reasoning,
|
||||
output_version=self.output_version,
|
||||
item_phase_cache=item_phase_cache,
|
||||
)
|
||||
if generation_chunk:
|
||||
if run_manager:
|
||||
@@ -1218,6 +1220,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
current_output_index = -1
|
||||
current_sub_index = -1
|
||||
has_reasoning = False
|
||||
item_phase_cache: dict[str, str] = {}
|
||||
async for chunk in response:
|
||||
metadata = headers if is_first_chunk else {}
|
||||
(
|
||||
@@ -1234,6 +1237,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
metadata=metadata,
|
||||
has_reasoning=has_reasoning,
|
||||
output_version=self.output_version,
|
||||
item_phase_cache=item_phase_cache,
|
||||
)
|
||||
if generation_chunk:
|
||||
if run_manager:
|
||||
@@ -4086,14 +4090,15 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
break
|
||||
else:
|
||||
# If no block with this ID, create a new one
|
||||
input_.append(
|
||||
{
|
||||
"type": "message",
|
||||
"content": [new_block],
|
||||
"role": "assistant",
|
||||
"id": msg_id,
|
||||
}
|
||||
)
|
||||
new_item: dict = {
|
||||
"type": "message",
|
||||
"content": [new_block],
|
||||
"role": "assistant",
|
||||
"id": msg_id,
|
||||
}
|
||||
if phase := block.get("phase"):
|
||||
new_item["phase"] = phase
|
||||
input_.append(new_item)
|
||||
elif block_type in (
|
||||
"reasoning",
|
||||
"web_search_call",
|
||||
@@ -4242,6 +4247,7 @@ def _construct_lc_result_from_responses_api(
|
||||
additional_kwargs: dict = {}
|
||||
for output in response.output:
|
||||
if output.type == "message":
|
||||
phase = getattr(output, "phase", None)
|
||||
for content in output.content:
|
||||
if content.type == "output_text":
|
||||
block = {
|
||||
@@ -4255,13 +4261,20 @@ def _construct_lc_result_from_responses_api(
|
||||
else [],
|
||||
"id": output.id,
|
||||
}
|
||||
if phase is not None:
|
||||
block["phase"] = phase
|
||||
content_blocks.append(block)
|
||||
if hasattr(content, "parsed"):
|
||||
additional_kwargs["parsed"] = content.parsed
|
||||
if content.type == "refusal":
|
||||
content_blocks.append(
|
||||
{"type": "refusal", "refusal": content.refusal, "id": output.id}
|
||||
)
|
||||
refusal_block = {
|
||||
"type": "refusal",
|
||||
"refusal": content.refusal,
|
||||
"id": output.id,
|
||||
}
|
||||
if phase is not None:
|
||||
refusal_block["phase"] = phase
|
||||
content_blocks.append(refusal_block)
|
||||
elif output.type == "function_call":
|
||||
content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
|
||||
try:
|
||||
@@ -4368,6 +4381,7 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
metadata: dict | None = None,
|
||||
has_reasoning: bool = False,
|
||||
output_version: str | None = None,
|
||||
item_phase_cache: dict[str, str] | None = None,
|
||||
) -> tuple[int, int, int, ChatGenerationChunk | None]:
|
||||
def _advance(output_idx: int, sub_idx: int | None = None) -> None:
|
||||
"""Advance indexes tracked during streaming.
|
||||
@@ -4447,14 +4461,15 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
)
|
||||
elif chunk.type == "response.output_text.done":
|
||||
_advance(chunk.output_index, chunk.content_index)
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
"id": chunk.item_id,
|
||||
"index": current_index,
|
||||
}
|
||||
)
|
||||
block = {
|
||||
"type": "text",
|
||||
"text": "",
|
||||
"id": chunk.item_id,
|
||||
"index": current_index,
|
||||
}
|
||||
if item_phase_cache and (phase := item_phase_cache.get(chunk.item_id)):
|
||||
block["phase"] = phase
|
||||
content.append(block)
|
||||
elif chunk.type == "response.created":
|
||||
id = chunk.response.id
|
||||
response_metadata["id"] = chunk.response.id # Backwards compatibility
|
||||
@@ -4479,8 +4494,10 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
|
||||
if output_version == "v0":
|
||||
id = chunk.item.id
|
||||
else:
|
||||
pass
|
||||
if item_phase_cache is not None and (
|
||||
phase := getattr(chunk.item, "phase", None)
|
||||
):
|
||||
item_phase_cache[chunk.item.id] = phase
|
||||
elif (
|
||||
chunk.type == "response.output_item.added"
|
||||
and chunk.item.type == "function_call"
|
||||
|
||||
@@ -3226,3 +3226,66 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None:
|
||||
pass
|
||||
except ValueError as e:
|
||||
pytest.fail(f"This is a wrong behavior. Error details: {e}")
|
||||
|
||||
|
||||
def test__construct_responses_api_input_preserves_phase() -> None:
|
||||
"""Test that phase is preserved on assistant message items during roundtrip."""
|
||||
messages: list = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "thinking...",
|
||||
"id": "msg_001",
|
||||
"phase": "commentary",
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "final answer",
|
||||
"id": "msg_002",
|
||||
"phase": "final_answer",
|
||||
},
|
||||
]
|
||||
)
|
||||
]
|
||||
result = _construct_responses_api_input(messages)
|
||||
assert result[0]["phase"] == "commentary"
|
||||
assert result[1]["phase"] == "final_answer"
|
||||
|
||||
|
||||
def test__construct_responses_api_input_no_phase_when_absent() -> None:
|
||||
"""Test that phase is not added to assistant message items when not present."""
|
||||
messages: list = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "hello", "id": "msg_123"},
|
||||
]
|
||||
)
|
||||
]
|
||||
result = _construct_responses_api_input(messages)
|
||||
assert "phase" not in result[0]
|
||||
|
||||
|
||||
def test__construct_lc_result_from_responses_api_captures_phase() -> None:
|
||||
"""Test that phase from output message is stored on content blocks."""
|
||||
output_item = MagicMock()
|
||||
output_item.type = "message"
|
||||
output_item.phase = "commentary"
|
||||
output_item.id = "msg_001"
|
||||
content_block = MagicMock()
|
||||
content_block.type = "output_text"
|
||||
content_block.text = "thinking"
|
||||
content_block.annotations = []
|
||||
output_item.content = [content_block]
|
||||
response = MagicMock()
|
||||
response.error = None
|
||||
response.id = "resp_001"
|
||||
response.output = [output_item]
|
||||
response.usage = None
|
||||
response.model_dump.return_value = {}
|
||||
|
||||
result = _construct_lc_result_from_responses_api(response)
|
||||
msg = result.generations[0].message
|
||||
assert isinstance(msg.content, list)
|
||||
block = cast(dict, msg.content[0])
|
||||
assert block["phase"] == "commentary"
|
||||
|
||||
Reference in New Issue
Block a user